Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] add repeat, sparsity, eval_n_elements APIs to bitset #2439

Open
wants to merge 1 commit into
base: branch-24.10
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <raft/util/device_atomics.cuh>
#include <raft/util/popc.cuh>

#include <rmm/device_scalar.hpp>

#include <thrust/for_each.h>

namespace raft::core {
Expand Down Expand Up @@ -60,6 +62,102 @@ _RAFT_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index
}
}

template <typename bitset_t, typename index_t>
struct bitset_copy_functor {
const bitset_t* bitset_ptr;
bitset_t* output_device_ptr;
index_t valid_bits;
index_t bits_per_element;
index_t total_bits;

bitset_copy_functor(const bitset_t* _bitset_ptr,
bitset_t* _output_device_ptr,
index_t _valid_bits,
index_t _bits_per_element,
index_t _total_bits)
: bitset_ptr(_bitset_ptr),
output_device_ptr(_output_device_ptr),
valid_bits(_valid_bits),
bits_per_element(_bits_per_element),
total_bits(_total_bits)
{
}

__device__ void operator()(index_t i)
{
if (i < total_bits) {
index_t src_bit_index = i % valid_bits;
index_t dst_bit_index = i;

index_t src_element_index = src_bit_index / bits_per_element;
index_t src_bit_offset = src_bit_index % bits_per_element;

index_t dst_element_index = dst_bit_index / bits_per_element;
index_t dst_bit_offset = dst_bit_index % bits_per_element;

bitset_t src_element = bitset_ptr[src_element_index];
bitset_t src_bit = (src_element >> src_bit_offset) & 1;

if (src_bit) {
atomicOr(output_device_ptr + dst_element_index, bitset_t(1) << dst_bit_offset);
} else {
atomicAnd(output_device_ptr + dst_element_index, ~(bitset_t(1) << dst_bit_offset));
}
}
}
};

template <typename bitset_t, typename index_t>
void bitset_view<bitset_t, index_t>::repeat(const raft::resources& res,
index_t times,
bitset_t* output_device_ptr) const
{
auto thrust_policy = raft::resource::get_thrust_policy(res);
constexpr index_t bits_per_element = sizeof(bitset_t) * 8;

if (bitset_len_ % bits_per_element == 0) {
index_t num_elements_to_copy = bitset_len_ / bits_per_element;

for (index_t i = 0; i < times; ++i) {
raft::copy(output_device_ptr + i * num_elements_to_copy,
bitset_ptr_,
num_elements_to_copy,
raft::resource::get_cuda_stream(res));
}
} else {
index_t valid_bits = bitset_len_;
index_t total_bits = valid_bits * times;
index_t output_row_elements = (total_bits + bits_per_element - 1) / bits_per_element;
thrust::for_each_n(thrust_policy,
thrust::counting_iterator<index_t>(0),
total_bits,
bitset_copy_functor<bitset_t, index_t>(
bitset_ptr_, output_device_ptr, valid_bits, bits_per_element, total_bits));
}
}

template <typename bitset_t, typename index_t>
double bitset_view<bitset_t, index_t>::sparsity(const raft::resources& res) const
{
index_t nnz_h = 0;
index_t size_h = this->size();
auto stream = raft::resource::get_cuda_stream(res);

if (0 == size_h) { return static_cast<double>(1.0); }

rmm::device_scalar<index_t> nnz(0, stream);

auto vector_view = raft::make_device_vector_view<const bitset_t, index_t>(data(), n_elements());
auto nnz_view = raft::make_device_scalar_view<index_t>(nnz.data());
auto size_view = raft::make_host_scalar_view<index_t>(&size_h);

raft::popc(res, vector_view, size_view, nnz_view);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The count() method can also be implemented in bitset_view, it will be useful for users.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also noticing that if the number of bits is not a multiple of the bitset element size the result might be wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also noticing that if the number of bits is not a multiple of the bitset element size the result might be wrong.

Should not, because the size is the actual number of bits. Maybe I miss something?

raft::copy(&nnz_h, nnz.data(), 1, stream);

raft::resource::sync_stream(res, stream);
return static_cast<double>((1.0 * (size_h - nnz_h)) / (1.0 * size_h));
}

template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
Expand Down
53 changes: 53 additions & 0 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

#include <cmath>

namespace raft::core {
/**
* @defgroup bitset Bitset
Expand Down Expand Up @@ -104,6 +106,57 @@ struct bitset_view {
return raft::make_device_vector_view<const bitset_t, index_t>(bitset_ptr_, n_elements());
}

/**
* @brief Repeats the bitset data and copies it to the output device pointer.
*
* This function takes the original bitset data stored in the device memory
* and repeats it a specified number of times into a new location in the device memory.
* The bits are copied bit-by-bit to ensure that even if the number of bits (bitset_len_)
* is not a multiple of the bitset element size (e.g., 32 for uint32_t), the bits are
* tightly packed without any gaps between rows.
*
* @param res RAFT resources for managing CUDA streams and execution policies.
* @param times Number of times the bitset data should be repeated in the output.
* @param output_device_ptr Device pointer where the repeated bitset data will be stored.
*
* The caller must ensure that the output device pointer has enough memory allocated
* to hold `times * bitset_len` bits, where `bitset_len` is the number of bits in the original
* bitset. This function uses Thrust parallel algorithms to efficiently perform the operation on
* the GPU.
*/
void repeat(const raft::resources& res, index_t times, bitset_t* output_device_ptr) const;

/**
* @brief Calculate the sparsity (fraction of 0s) of the bitset.
*
* This function computes the sparsity of the bitset, defined as the ratio of unset bits (0s)
* to the total number of bits in the set. If the total number of bits is zero, the function
* returns 1.0, indicating the set is fully sparse.
*
* @param res RAFT resources for managing CUDA streams and execution policies.
* @return double The sparsity of the bitset, i.e., the fraction of unset bits.
*
* This API will synchronize on the stream of `res`.
*/
double sparsity(const raft::resources& res) const;

/**
* @brief Calculates the number of `bitset_t` elements required to store a bitset.
*
* This function computes the number of `bitset_t` elements needed to store a bitset, ensuring
* that all bits are accounted for. If the bitset length is not a multiple of the `bitset_t` size
* (in bits), the calculation rounds up to include the remaining bits in an additional `bitset_t`
* element.
*
* @param bitset_len The total length of the bitset in bits.
* @return size_t The number of `bitset_t` elements required to store the bitset.
*/
static inline size_t eval_n_elements(size_t bitset_len)
{
const size_t bits_per_element = sizeof(bitset_t) * 8;
return (bitset_len + bits_per_element - 1) / bits_per_element;
}

private:
bitset_t* bitset_ptr_;
index_t bitset_len_;
Expand Down
102 changes: 92 additions & 10 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,12 +32,13 @@ struct test_spec_bitset {
uint64_t bitset_len;
uint64_t mask_len;
uint64_t query_len;
uint64_t repeat_times;
};

auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream&
{
os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len
<< ", query_len: " << ss.query_len << "}";
<< ", query_len: " << ss.query_len << ", repeat_times: " << ss.repeat_times << "}";
return os;
}

Expand Down Expand Up @@ -80,20 +81,68 @@ void flip_cpu_bitset(std::vector<bitset_t>& bitset)
}
}

template <typename bitset_t>
void repeat_cpu_bitset(std::vector<bitset_t>& input,
size_t input_bits,
size_t repeat,
std::vector<bitset_t>& output)
{
const size_t output_bits = input_bits * repeat;
const size_t output_units = (output_bits + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8);

std::memset(output.data(), 0, output_units * sizeof(bitset_t));

size_t output_bit_index = 0;

for (size_t r = 0; r < repeat; ++r) {
for (size_t i = 0; i < input_bits; ++i) {
size_t input_unit_index = i / (sizeof(bitset_t) * 8);
size_t input_bit_offset = i % (sizeof(bitset_t) * 8);
bool bit = (input[input_unit_index] >> input_bit_offset) & 1;

size_t output_unit_index = output_bit_index / (sizeof(bitset_t) * 8);
size_t output_bit_offset = output_bit_index % (sizeof(bitset_t) * 8);

output[output_unit_index] |= (static_cast<bitset_t>(bit) << output_bit_offset);

++output_bit_index;
}
}
}

template <typename bitset_t>
double sparsity_cpu_bitset(std::vector<bitset_t>& data, size_t total_bits)
{
size_t one_count = 0;
for (size_t i = 0; i < total_bits; ++i) {
size_t unit_index = i / (sizeof(bitset_t) * 8);
size_t bit_offset = i % (sizeof(bitset_t) * 8);
bool bit = (data[unit_index] >> bit_offset) & 1;
if (bit == 1) { ++one_count; }
}
return static_cast<double>((total_bits - one_count) / (1.0 * total_bits));
}

template <typename bitset_t, typename index_t>
class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
protected:
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
const test_spec_bitset spec;
std::vector<bitset_t> bitset_result;
std::vector<bitset_t> bitset_ref;
std::vector<bitset_t> bitset_repeat_ref;
std::vector<bitset_t> bitset_repeat_result;
raft::resources res;

public:
explicit BitsetTest()
: spec(testing::TestWithParam<test_spec_bitset>::GetParam()),
bitset_result(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))),
bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size)))
bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))),
bitset_repeat_ref(
raft::ceildiv(spec.bitset_len * spec.repeat_times, uint64_t(bitset_element_size))),
bitset_repeat_result(
raft::ceildiv(spec.bitset_len * spec.repeat_times, uint64_t(bitset_element_size)))
{
}

Expand Down Expand Up @@ -145,6 +194,37 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// test sparsity, repeat and eval_n_elements
if constexpr (std::is_same_v<bitset_t, uint32_t> || std::is_same_v<bitset_t, uint64_t>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why restrict the test to those data types?

auto my_bitset_view = my_bitset.view();
auto sparsity_result = my_bitset_view.sparsity(res);
auto sparsity_ref = sparsity_cpu_bitset(bitset_ref, size_t(spec.bitset_len));
ASSERT_EQ(sparsity_result, sparsity_ref);

auto eval_n_elements =
bitset_view<bitset_t, index_t>::eval_n_elements(spec.bitset_len * spec.repeat_times);
ASSERT_EQ(bitset_repeat_ref.size(), eval_n_elements);

auto repeat_device = raft::make_device_vector<bitset_t, index_t>(res, eval_n_elements);
RAFT_CUDA_TRY(cudaMemsetAsync(
repeat_device.data_handle(), 0, eval_n_elements * sizeof(bitset_t), stream));
repeat_cpu_bitset(
bitset_ref, size_t(spec.bitset_len), size_t(spec.repeat_times), bitset_repeat_ref);

my_bitset_view.repeat(res, index_t(spec.repeat_times), repeat_device.data_handle());

ASSERT_EQ(bitset_repeat_ref.size(), repeat_device.size());
update_host(
bitset_repeat_result.data(), repeat_device.data_handle(), repeat_device.size(), stream);
ASSERT_EQ(bitset_repeat_ref.size(), bitset_repeat_result.size());
ASSERT_TRUE(hostVecMatch(bitset_repeat_ref, bitset_repeat_result, raft::Compare<bitset_t>()));

// recheck the sparsity after repeat
sparsity_result =
sparsity_cpu_bitset(bitset_repeat_result, size_t(spec.bitset_len * spec.repeat_times));
ASSERT_EQ(sparsity_result, sparsity_ref);
}

// Flip the bitset and re-test
auto bitset_count = my_bitset.count(res);
my_bitset.flip(res);
Expand All @@ -167,13 +247,15 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
}
};

auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10},
test_spec_bitset{100, 30, 10},
test_spec_bitset{1024, 55, 100},
test_spec_bitset{10000, 1000, 1000},
test_spec_bitset{1 << 15, 1 << 3, 1 << 12},
test_spec_bitset{1 << 15, 1 << 24, 1 << 13},
test_spec_bitset{1 << 25, 1 << 23, 1 << 14});
auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10, 101},
test_spec_bitset{100, 30, 10, 13},
test_spec_bitset{1024, 55, 100, 1},
test_spec_bitset{10000, 1000, 1000, 100},
test_spec_bitset{1 << 15, 1 << 3, 1 << 12, 5},
test_spec_bitset{1 << 15, 1 << 24, 1 << 13, 3},
test_spec_bitset{1 << 25, 1 << 23, 1 << 14, 3},
test_spec_bitset{1 << 25, 1 << 23, 1 << 14, 201},
test_spec_bitset{10000000, 1 << 23, 1 << 14, 401});

using Uint16_32 = BitsetTest<uint16_t, uint32_t>;
TEST_P(Uint16_32, Run) { run(); }
Expand Down
Loading