Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix select if for mixed types #444

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions cub/agent/agent_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ struct AgentSelectIf
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

// The output value type
using OutputT =
cub::detail::non_void_value_t<SelectedOutputIteratorT, InputT>;

// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;

Expand Down Expand Up @@ -156,7 +152,7 @@ struct AgentSelectIf
FlagsInputIteratorT>;

// Parameterized BlockLoad type for input data
using BlockLoadT = BlockLoad<OutputT,
using BlockLoadT = BlockLoad<InputT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentSelectIfPolicyT::LOAD_ALGORITHM>;
Expand All @@ -168,7 +164,7 @@ struct AgentSelectIf
AgentSelectIfPolicyT::LOAD_ALGORITHM>;

// Parameterized BlockDiscontinuity type for items
using BlockDiscontinuityT = BlockDiscontinuity<OutputT, BLOCK_THREADS>;
using BlockDiscontinuityT = BlockDiscontinuity<InputT, BLOCK_THREADS>;

// Parameterized BlockScan type
using BlockScanT =
Expand All @@ -179,7 +175,7 @@ struct AgentSelectIf
TilePrefixCallbackOp<OffsetT, cub::Sum, ScanTileStateT>;

// Item exchange type
typedef OutputT ItemExchangeT[TILE_ITEMS];
typedef InputT ItemExchangeT[TILE_ITEMS];

// Shared memory type for this thread block
union _TempStorage
Expand Down Expand Up @@ -254,7 +250,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT /*tile_offset*/,
OffsetT num_tile_items,
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_SELECT_OP> /*select_method*/)
{
Expand All @@ -277,7 +273,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT tile_offset,
OffsetT num_tile_items,
OutputT (&/*items*/)[ITEMS_PER_THREAD],
InputT (&/*items*/)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_SELECT_FLAGS> /*select_method*/)
{
Expand Down Expand Up @@ -311,7 +307,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT tile_offset,
OffsetT num_tile_items,
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_DISCONTINUITY> /*select_method*/)
{
Expand All @@ -324,7 +320,7 @@ struct AgentSelectIf
}
else
{
OutputT tile_predecessor;
InputT tile_predecessor;
if (threadIdx.x == 0)
tile_predecessor = d_in[tile_offset - 1];

Expand Down Expand Up @@ -353,7 +349,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterDirect(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
OffsetT num_selections)
Expand All @@ -378,7 +374,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int /*num_tile_items*/, ///< Number of valid items in this tile
Expand Down Expand Up @@ -414,7 +410,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items, ///< Number of valid items in this tile
Expand Down Expand Up @@ -454,7 +450,7 @@ struct AgentSelectIf
num_items - num_rejected_prefix - rejection_idx - 1 :
num_selections_prefix + selection_idx;

OutputT item = temp_storage.raw_exchange.Alias()[item_idx];
InputT item = temp_storage.raw_exchange.Alias()[item_idx];

if (!IS_LAST_TILE || (item_idx < num_tile_items))
{
Expand All @@ -469,7 +465,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void Scatter(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items, ///< Number of valid items in this tile
Expand Down Expand Up @@ -515,7 +511,7 @@ struct AgentSelectIf
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state) ///< Global tile state descriptor
{
OutputT items[ITEMS_PER_THREAD];
InputT items[ITEMS_PER_THREAD];
OffsetT selection_flags[ITEMS_PER_THREAD];
OffsetT selection_indices[ITEMS_PER_THREAD];

Expand Down Expand Up @@ -575,7 +571,7 @@ struct AgentSelectIf
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state) ///< Global tile state descriptor
{
OutputT items[ITEMS_PER_THREAD];
InputT items[ITEMS_PER_THREAD];
OffsetT selection_flags[ITEMS_PER_THREAD];
OffsetT selection_indices[ITEMS_PER_THREAD];

Expand Down
13 changes: 9 additions & 4 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1278,9 +1278,11 @@ struct DispatchRadixSort :
const PortionOffsetT PORTION_SIZE = ((1 << 28) - 1) / ONESWEEP_TILE_ITEMS * ONESWEEP_TILE_ITEMS;
int num_passes = cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS);
OffsetT num_portions = static_cast<OffsetT>(cub::DivideAndRoundUp(num_items, PORTION_SIZE));
PortionOffsetT max_num_blocks = cub::DivideAndRoundUp(CUB_MIN(num_items, PORTION_SIZE),
ONESWEEP_TILE_ITEMS);

PortionOffsetT max_num_blocks = cub::DivideAndRoundUp(
static_cast<int>(
CUB_MIN(num_items, static_cast<OffsetT>(PORTION_SIZE))),
ONESWEEP_TILE_ITEMS);

size_t value_size = KEYS_ONLY ? 0 : sizeof(ValueT);
size_t allocation_sizes[] =
{
Expand Down Expand Up @@ -1355,7 +1357,10 @@ struct DispatchRadixSort :
int num_bits = CUB_MIN(end_bit - current_bit, RADIX_BITS);
for (OffsetT portion = 0; portion < num_portions; ++portion)
{
PortionOffsetT portion_num_items = CUB_MIN(num_items - portion * PORTION_SIZE, PORTION_SIZE);
PortionOffsetT portion_num_items =
static_cast<PortionOffsetT>(
CUB_MIN(num_items - portion * PORTION_SIZE,
static_cast<OffsetT>(PORTION_SIZE)));
PortionOffsetT num_blocks =
cub::DivideAndRoundUp(portion_num_items, ONESWEEP_TILE_ITEMS);
if (CubDebug(error = cudaMemsetAsync(
Expand Down
8 changes: 3 additions & 5 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,8 @@ struct DispatchSelectIf
* Types and constants
******************************************************************************/

// The output value type
using OutputT =
cub::detail::non_void_value_t<SelectedOutputIteratorT,
cub::detail::value_t<InputIteratorT>>;
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;
Comment on lines +132 to +133
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch 👍 I probably would have missed adjusting the tuning policies.


// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;
Expand All @@ -155,7 +153,7 @@ struct DispatchSelectIf
{
enum {
NOMINAL_4B_ITEMS_PER_THREAD = 10,
ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))),
ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(InputT)))),
};

typedef AgentSelectIfPolicy<
Expand Down
17 changes: 13 additions & 4 deletions test/test_device_radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include <algorithm>
#include <cstdio>
#include <limits>
#include <memory>
#include <random>
#include <type_traits>
Expand Down Expand Up @@ -283,9 +284,11 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

return DeviceSegmentedRadixSort::SortPairs(
d_temp_storage, temp_storage_bytes,
d_keys, d_values, num_items,
d_keys, d_values, static_cast<int>(num_items),
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
num_segments, d_segment_begin_offsets, d_segment_end_offsets,
begin_bit, end_bit, stream, debug_synchronous);
}
Expand Down Expand Up @@ -317,13 +320,15 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

KeyT const *const_keys_itr = d_keys.Current();
ValueT const *const_values_itr = d_values.Current();

cudaError_t retval = DeviceSegmentedRadixSort::SortPairs(
d_temp_storage, temp_storage_bytes,
const_keys_itr, d_keys.Alternate(), const_values_itr, d_values.Alternate(),
num_items, num_segments, d_segment_begin_offsets, d_segment_end_offsets,
static_cast<int>(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets,
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
begin_bit, end_bit, stream, debug_synchronous);

d_keys.selector ^= 1;
Expand Down Expand Up @@ -359,9 +364,11 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

return DeviceSegmentedRadixSort::SortPairsDescending(
d_temp_storage, temp_storage_bytes,
d_keys, d_values, num_items,
d_keys, d_values, static_cast<int>(num_items),
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
num_segments, d_segment_begin_offsets, d_segment_end_offsets,
begin_bit, end_bit, stream, debug_synchronous);
}
Expand Down Expand Up @@ -393,13 +400,15 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

KeyT const *const_keys_itr = d_keys.Current();
ValueT const *const_values_itr = d_values.Current();

cudaError_t retval = DeviceSegmentedRadixSort::SortPairsDescending(
d_temp_storage, temp_storage_bytes,
const_keys_itr, d_keys.Alternate(), const_values_itr, d_values.Alternate(),
num_items, num_segments, d_segment_begin_offsets, d_segment_end_offsets,
static_cast<int>(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets,
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
begin_bit, end_bit, stream, debug_synchronous);

d_keys.selector ^= 1;
Expand Down
76 changes: 76 additions & 0 deletions test/test_device_select_if.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
#include <cub/device/device_partition.cuh>
#include <cub/iterator/counting_input_iterator.cuh>

#include <thrust/count.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/device_vector.h>

#include "test_util.h"

using namespace cub;
Expand Down Expand Up @@ -652,6 +658,74 @@ void Test(
}
}

template<class T0, class T1>
struct pair_to_col_t
{
__host__ __device__ T0 operator()(const thrust::tuple<T0, T1> &in)
{
return thrust::get<0>(in);
}
};

template<class T0, class T1>
struct select_t {
__host__ __device__ bool operator()(const thrust::tuple<T0, T1> &in) {
return static_cast<T1>(thrust::get<0>(in)) > thrust::get<1>(in);
}
};

template <typename T0, typename T1>
void TestMixedOp(int num_items)
{
const T0 target_value = static_cast<T0>(42);
thrust::device_vector<T0> col_a(num_items, target_value);
thrust::device_vector<T1> col_b(num_items, static_cast<T1>(4.2));

thrust::device_vector<T0> result(num_items);

auto in = thrust::make_zip_iterator(col_a.begin(), col_b.begin());
auto out = thrust::make_transform_output_iterator(result.begin(), pair_to_col_t<T0, T1>{});

void *d_tmp_storage {};
std::size_t tmp_storage_size{};
cub::DeviceSelect::If(
d_tmp_storage, tmp_storage_size,
in, out, thrust::make_discard_iterator(),
num_items, select_t<T0, T1>{},
0, true);

thrust::device_vector<char> tmp_storage(tmp_storage_size);
d_tmp_storage = thrust::raw_pointer_cast(tmp_storage.data());

cub::DeviceSelect::If(
d_tmp_storage, tmp_storage_size,
in, out, thrust::make_discard_iterator(),
num_items, select_t<T0, T1>{},
0, true);

AssertEquals(num_items, thrust::count(result.begin(), result.end(), target_value));
}

/**
* Test different input sizes
*/
template <typename T0, typename T1>
void TestMixed(int num_items)
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
{
if (num_items < 0)
{
TestMixedOp<T0, T1>(0);
TestMixedOp<T0, T1>(1);
TestMixedOp<T0, T1>(100);
TestMixedOp<T0, T1>(10000);
TestMixedOp<T0, T1>(1000000);
}
else
{
TestMixedOp<T0, T1>(num_items);
}
}

//---------------------------------------------------------------------
// Main
//---------------------------------------------------------------------
Expand Down Expand Up @@ -708,6 +782,8 @@ int main(int argc, char** argv)
Test<TestFoo>(num_items);
Test<TestBar>(num_items);

TestMixed<int, double>(num_items);

return 0;
}

Expand Down
Loading