Skip to content

Commit

Permalink
Fix comparison of async MRs with different underlying pools. (#965)
Browse files Browse the repository at this point in the history
Fixes #899

Adds a test that comparison of async MRs with  different underlying cudaMempool handles returns false, and implements the correct behavior.

Authors:
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Jake Hemstad (https://github.com/jrhemstad)
  - Rong Ou (https://github.com/rongou)

URL: #965
  • Loading branch information
harrism authored Feb 4, 2022
1 parent dc2a544 commit 6e2821a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
9 changes: 7 additions & 2 deletions include/rmm/mr/device/cuda_async_memory_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -187,7 +187,12 @@ class cuda_async_memory_resource final : public device_memory_resource {
*/
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
return dynamic_cast<cuda_async_memory_resource const*>(&other) != nullptr;
auto const* async_mr = dynamic_cast<cuda_async_memory_resource const*>(&other);
#ifdef RMM_CUDA_MALLOC_ASYNC_SUPPORT
return (async_mr != nullptr) && (this->pool_handle() == async_mr->pool_handle());
#else
return async_mr != nullptr;
#endif
}

/**
Expand Down
17 changes: 13 additions & 4 deletions tests/mr/device/cuda_async_mr_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,7 +24,7 @@ namespace {

using cuda_async_mr = rmm::mr::cuda_async_memory_resource;

TEST(PoolTest, ThrowIfNotSupported)
TEST(AsyncMRTest, ThrowIfNotSupported)
{
auto construct_mr = []() { cuda_async_mr mr; };
#ifndef RMM_CUDA_MALLOC_ASYNC_SUPPORT
Expand All @@ -35,7 +35,7 @@ TEST(PoolTest, ThrowIfNotSupported)
}

#if defined(RMM_CUDA_MALLOC_ASYNC_SUPPORT)
TEST(PoolTest, ExplicitInitialPoolSize)
TEST(AsyncMRTest, ExplicitInitialPoolSize)
{
const auto pool_init_size{100};
cuda_async_mr mr{pool_init_size};
Expand All @@ -44,7 +44,7 @@ TEST(PoolTest, ExplicitInitialPoolSize)
RMM_CUDA_TRY(cudaDeviceSynchronize());
}

TEST(PoolTest, ExplicitReleaseThreshold)
TEST(AsyncMRTest, ExplicitReleaseThreshold)
{
const auto pool_init_size{100};
const auto pool_release_threshold{1000};
Expand All @@ -54,6 +54,15 @@ TEST(PoolTest, ExplicitReleaseThreshold)
RMM_CUDA_TRY(cudaDeviceSynchronize());
}

TEST(AsyncMRTest, DifferentPoolsUnequal)
{
const auto pool_init_size{100};
const auto pool_release_threshold{1000};
cuda_async_mr mr1{pool_init_size, pool_release_threshold};
cuda_async_mr mr2{pool_init_size, pool_release_threshold};
EXPECT_FALSE(mr1.is_equal(mr2));
}

#endif

} // namespace
Expand Down

0 comments on commit 6e2821a

Please sign in to comment.