Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a linear accessor to RMM cuda stream pool #696

Merged
25 changes: 25 additions & 0 deletions include/rmm/cuda_stream_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/error.hpp>

#include <atomic>
#include <vector>
Expand Down Expand Up @@ -61,6 +62,30 @@ class cuda_stream_pool {
return streams_[(next_stream++) % streams_.size()].view();
}

/**
* @brief Get a `cuda_stream_view` of the stream associated with `stream_id`.
* Equivalent values of `stream_id` return a stream_view to the same underlying stream.
*
* This function is thread safe with respect to other calls to the same function.
*
* @param stream_id Unique identifier for the desired stream
*
* @return rmm::cuda_stream_view
*/
rmm::cuda_stream_view get_stream(std::size_t stream_id) const
{
return streams_[stream_id % streams_.size()].view();
}

/**
* @brief Get the number of streams in the pool.
*
* This function is thread safe with respect to other calls to the same function.
*
* @return the number of streams in the pool
*/
size_t get_pool_size() const noexcept { return streams_.size(); }

private:
std::vector<rmm::cuda_stream> streams_;
mutable std::atomic_size_t next_stream{};
Expand Down
18 changes: 17 additions & 1 deletion tests/cuda_stream_pool_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,23 @@ TEST_F(CudaStreamPoolTest, ValidStreams)
RMM_CUDA_TRY(cudaMemsetAsync(v.data(), 0xcc, 100, stream_a.value()));
stream_a.synchronize();

auto v2 = rmm::device_uvector<uint8_t>{v, stream_b};
auto v2 = rmm::device_uvector<std::uint8_t>{v, stream_b};
auto x = v2.front_element(stream_b);
EXPECT_EQ(x, 0xcc);
}

TEST_F(CudaStreamPoolTest, PoolSize) { EXPECT_GE(this->pool.get_pool_size(), 1); }

TEST_F(CudaStreamPoolTest, OutOfBoundLinearAccess)
{
auto const stream_a = this->pool.get_stream(0);
auto const stream_b = this->pool.get_stream(this->pool.get_pool_size());
EXPECT_EQ(stream_a, stream_b);
}

TEST_F(CudaStreamPoolTest, ValidLinearAccess)
{
auto const stream_a = this->pool.get_stream(0);
auto const stream_b = this->pool.get_stream(1);
EXPECT_NE(stream_a, stream_b);
}