Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix select if for mixed types
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Mar 16, 2022
1 parent 74baeed commit f52cfbb
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 23 deletions.
32 changes: 14 additions & 18 deletions cub/agent/agent_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ struct AgentSelectIf
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

// The output value type
using OutputT =
cub::detail::non_void_value_t<SelectedOutputIteratorT, InputT>;

// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;

Expand Down Expand Up @@ -156,7 +152,7 @@ struct AgentSelectIf
FlagsInputIteratorT>;

// Parameterized BlockLoad type for input data
using BlockLoadT = BlockLoad<OutputT,
using BlockLoadT = BlockLoad<InputT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentSelectIfPolicyT::LOAD_ALGORITHM>;
Expand All @@ -168,7 +164,7 @@ struct AgentSelectIf
AgentSelectIfPolicyT::LOAD_ALGORITHM>;

// Parameterized BlockDiscontinuity type for items
using BlockDiscontinuityT = BlockDiscontinuity<OutputT, BLOCK_THREADS>;
using BlockDiscontinuityT = BlockDiscontinuity<InputT, BLOCK_THREADS>;

// Parameterized BlockScan type
using BlockScanT =
Expand All @@ -179,7 +175,7 @@ struct AgentSelectIf
TilePrefixCallbackOp<OffsetT, cub::Sum, ScanTileStateT>;

// Item exchange type
typedef OutputT ItemExchangeT[TILE_ITEMS];
typedef InputT ItemExchangeT[TILE_ITEMS];

// Shared memory type for this thread block
union _TempStorage
Expand Down Expand Up @@ -254,7 +250,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT /*tile_offset*/,
OffsetT num_tile_items,
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_SELECT_OP> /*select_method*/)
{
Expand All @@ -277,7 +273,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT tile_offset,
OffsetT num_tile_items,
OutputT (&/*items*/)[ITEMS_PER_THREAD],
InputT (&/*items*/)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_SELECT_FLAGS> /*select_method*/)
{
Expand Down Expand Up @@ -311,7 +307,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT tile_offset,
OffsetT num_tile_items,
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_DISCONTINUITY> /*select_method*/)
{
Expand All @@ -324,7 +320,7 @@ struct AgentSelectIf
}
else
{
OutputT tile_predecessor;
InputT tile_predecessor;
if (threadIdx.x == 0)
tile_predecessor = d_in[tile_offset - 1];

Expand Down Expand Up @@ -353,7 +349,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterDirect(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
OffsetT num_selections)
Expand All @@ -378,7 +374,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int /*num_tile_items*/, ///< Number of valid items in this tile
Expand Down Expand Up @@ -414,7 +410,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items, ///< Number of valid items in this tile
Expand Down Expand Up @@ -454,7 +450,7 @@ struct AgentSelectIf
num_items - num_rejected_prefix - rejection_idx - 1 :
num_selections_prefix + selection_idx;

OutputT item = temp_storage.raw_exchange.Alias()[item_idx];
InputT item = temp_storage.raw_exchange.Alias()[item_idx];

if (!IS_LAST_TILE || (item_idx < num_tile_items))
{
Expand All @@ -469,7 +465,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void Scatter(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items, ///< Number of valid items in this tile
Expand Down Expand Up @@ -515,7 +511,7 @@ struct AgentSelectIf
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state) ///< Global tile state descriptor
{
OutputT items[ITEMS_PER_THREAD];
InputT items[ITEMS_PER_THREAD];
OffsetT selection_flags[ITEMS_PER_THREAD];
OffsetT selection_indices[ITEMS_PER_THREAD];

Expand Down Expand Up @@ -575,7 +571,7 @@ struct AgentSelectIf
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state) ///< Global tile state descriptor
{
OutputT items[ITEMS_PER_THREAD];
InputT items[ITEMS_PER_THREAD];
OffsetT selection_flags[ITEMS_PER_THREAD];
OffsetT selection_indices[ITEMS_PER_THREAD];

Expand Down
8 changes: 3 additions & 5 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,8 @@ struct DispatchSelectIf
* Types and constants
******************************************************************************/

// The output value type
using OutputT =
cub::detail::non_void_value_t<SelectedOutputIteratorT,
cub::detail::value_t<InputIteratorT>>;
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;
Expand All @@ -155,7 +153,7 @@ struct DispatchSelectIf
{
enum {
NOMINAL_4B_ITEMS_PER_THREAD = 10,
ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))),
ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(InputT)))),
};

typedef AgentSelectIfPolicy<
Expand Down
76 changes: 76 additions & 0 deletions test/test_device_select_if.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
#include <cub/device/device_partition.cuh>
#include <cub/iterator/counting_input_iterator.cuh>

#include <thrust/count.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/device_vector.h>

#include "test_util.h"

using namespace cub;
Expand Down Expand Up @@ -652,6 +658,74 @@ void Test(
}
}

/**
 * Unary functor that projects a thrust::tuple<FT, ST> onto its first
 * element. TestMixedOp installs it as the transform of a
 * thrust::transform_output_iterator, so a tuple written to that iterator
 * is stored as its FT component only.
 */
template <class FT, class ST>
struct pair_to_col_t
{
  __host__ __device__ FT operator()(const thrust::tuple<FT, ST> &row)
  {
    // Bind by reference so the only copy made is the one for the return value.
    const FT &first_column = thrust::get<0>(row);
    return first_column;
  }
};

/**
 * Selection predicate over thrust::tuple<FT, ST>.
 *
 * Returns true when the first tuple element, converted to ST, is strictly
 * greater than the second tuple element.
 */
template <class FT, class ST>
struct select_t
{
  __host__ __device__ bool operator()(const thrust::tuple<FT, ST> &row)
  {
    // Compare in the ST domain: the FT component is converted explicitly.
    const ST converted_first = static_cast<ST>(thrust::get<0>(row));
    return converted_first > thrust::get<1>(row);
  }
};

template <typename FT, typename ST>
void TestMixedOp(int num_items)
{
const FT target_value = static_cast<FT>(42);
thrust::device_vector<FT> col_a(num_items, target_value);
thrust::device_vector<ST> col_b(num_items, static_cast<ST>(4.2));

thrust::device_vector<FT> result(num_items);

auto in = thrust::make_zip_iterator(col_a.begin(), col_b.begin());
auto out = thrust::make_transform_output_iterator(result.begin(), pair_to_col_t<FT, ST>{});

void *d_tmp_storage {};
std::size_t tmp_storage_size{};
cub::DeviceSelect::If(
d_tmp_storage, tmp_storage_size,
in, out, thrust::make_discard_iterator(),
num_items, select_t<FT, ST>{},
0, true);

thrust::device_vector<char> tmp_storage(tmp_storage_size);
d_tmp_storage = thrust::raw_pointer_cast(tmp_storage.data());

cub::DeviceSelect::If(
d_tmp_storage, tmp_storage_size,
in, out, thrust::make_discard_iterator(),
num_items, select_t<FT, ST>{},
0, true);

AssertEquals(num_items, thrust::count(result.begin(), result.end(), target_value));
}

/**
 * Test different input sizes.
 *
 * A non-negative num_items runs TestMixedOp once with exactly that size.
 * A negative num_items runs TestMixedOp over a fixed sweep of sizes
 * instead: 0, 1, 100, 10000 and 1000000, in that order.
 */
template <typename FT, typename ST>
void TestMixed(int num_items)
{
  if (num_items >= 0)
  {
    TestMixedOp<FT, ST>(num_items);
    return;
  }

  // Negative size requested: run the fixed sweep of sizes.
  const int sweep_sizes[] = {0, 1, 100, 10000, 1000000};
  for (int sweep_size : sweep_sizes)
  {
    TestMixedOp<FT, ST>(sweep_size);
  }
}

//---------------------------------------------------------------------
// Main
//---------------------------------------------------------------------
Expand Down Expand Up @@ -708,6 +782,8 @@ int main(int argc, char** argv)
Test<TestFoo>(num_items);
Test<TestBar>(num_items);

TestMixed<int, double>(num_items);

return 0;
}

Expand Down

0 comments on commit f52cfbb

Please sign in to comment.