Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a linear accessor to RMM cuda stream pool #696

Merged
28 changes: 28 additions & 0 deletions include/rmm/cuda_stream_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/error.hpp>

#include <atomic>
#include <vector>
Expand Down Expand Up @@ -61,6 +62,33 @@ class cuda_stream_pool {
return streams_[(next_stream++) % streams_.size()].view();
}

/**
* @brief Get a `cuda_stream_view` of the stream at stream_id index the pool.
afender marked this conversation as resolved.
Show resolved Hide resolved
afender marked this conversation as resolved.
Show resolved Hide resolved
*
* This function is thread safe with respect to other calls to the same function.
*
* @throws rmm::out_of_range exception if `stream_index >= size()`
*
* @param stream_id The index of the stream in the pool
afender marked this conversation as resolved.
Show resolved Hide resolved
*
* @return rmm::cuda_stream_view
*/
rmm::cuda_stream_view get_stream(std::size_t stream_index) const
{
RMM_EXPECTS(
stream_index < streams_.size(), rmm::out_of_range, "Attempt to access out of bounds stream.");
afender marked this conversation as resolved.
Show resolved Hide resolved
return streams_[stream_index].view();
}

/**
* @brief Get the number of streams in the pool.
*
* This function is thread safe with respect to other calls to the same function.
*
* @return the number of streams in the pool
*/
size_t get_pool_size() const noexcept { return streams_.size(); }

private:
std::vector<rmm::cuda_stream> streams_;
mutable std::atomic_size_t next_stream{};
Expand Down
21 changes: 20 additions & 1 deletion tests/cuda_stream_pool_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,26 @@ TEST_F(CudaStreamPoolTest, ValidStreams)
RMM_CUDA_TRY(cudaMemsetAsync(v.data(), 0xcc, 100, stream_a.value()));
stream_a.synchronize();

auto v2 = rmm::device_uvector<uint8_t>{v, stream_b};
auto v2 = rmm::device_uvector<std::uint8_t>{v, stream_b};
auto x = v2.front_element(stream_b);
EXPECT_EQ(x, 0xcc);
}

TEST_F(CudaStreamPoolTest, PoolSize) { EXPECT_GE(this->pool.get_pool_size(), 1); }

TEST_F(CudaStreamPoolTest, OutOfBoundLinearAccess)
{
EXPECT_NO_THROW(this->pool.get_stream(this->pool.get_pool_size() - 1));
EXPECT_THROW(this->pool.get_stream(this->pool.get_pool_size()), rmm::out_of_range);
}

TEST_F(CudaStreamPoolTest, ValidLinearAccess)
{
auto const stream_a = this->pool.get_stream(0);
auto const stream_b = this->pool.get_stream(1);
EXPECT_NE(stream_a, stream_b);
EXPECT_FALSE(stream_a.is_default());
EXPECT_FALSE(stream_a.is_per_thread_default());
EXPECT_FALSE(stream_b.is_default());
EXPECT_FALSE(stream_b.is_per_thread_default());
harrism marked this conversation as resolved.
Show resolved Hide resolved
}