diff --git a/cub/agent/agent_select_if.cuh b/cub/agent/agent_select_if.cuh index 807ea18c6d..aded8589e2 100644 --- a/cub/agent/agent_select_if.cuh +++ b/cub/agent/agent_select_if.cuh @@ -107,10 +107,6 @@ struct AgentSelectIf // The input value type using InputT = cub::detail::value_t; - // The output value type - using OutputT = - cub::detail::non_void_value_t; - // The flag value type using FlagT = cub::detail::value_t; @@ -156,7 +152,7 @@ struct AgentSelectIf FlagsInputIteratorT>; // Parameterized BlockLoad type for input data - using BlockLoadT = BlockLoad; @@ -168,7 +164,7 @@ struct AgentSelectIf AgentSelectIfPolicyT::LOAD_ALGORITHM>; // Parameterized BlockDiscontinuity type for items - using BlockDiscontinuityT = BlockDiscontinuity; + using BlockDiscontinuityT = BlockDiscontinuity; // Parameterized BlockScan type using BlockScanT = @@ -179,7 +175,7 @@ struct AgentSelectIf TilePrefixCallbackOp; // Item exchange type - typedef OutputT ItemExchangeT[TILE_ITEMS]; + typedef InputT ItemExchangeT[TILE_ITEMS]; // Shared memory type for this thread block union _TempStorage @@ -254,7 +250,7 @@ struct AgentSelectIf __device__ __forceinline__ void InitializeSelections( OffsetT /*tile_offset*/, OffsetT num_tile_items, - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { @@ -277,7 +273,7 @@ struct AgentSelectIf __device__ __forceinline__ void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, - OutputT (&/*items*/)[ITEMS_PER_THREAD], + InputT (&/*items*/)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { @@ -311,7 +307,7 @@ struct AgentSelectIf __device__ __forceinline__ void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { @@ -324,7 +320,7 @@ struct AgentSelectIf } else { - OutputT tile_predecessor; + InputT tile_predecessor; if (threadIdx.x == 0) tile_predecessor = d_in[tile_offset - 1]; @@ -353,7 +349,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void ScatterDirect( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], OffsetT num_selections) @@ -378,7 +374,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void ScatterTwoPhase( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int /*num_tile_items*/, ///< Number of valid items in this tile @@ -414,7 +410,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void ScatterTwoPhase( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, ///< Number of valid items in this tile @@ -454,7 +450,7 @@ struct AgentSelectIf num_items - num_rejected_prefix - rejection_idx - 1 : num_selections_prefix + selection_idx; - OutputT item = temp_storage.raw_exchange.Alias()[item_idx]; + InputT item = temp_storage.raw_exchange.Alias()[item_idx]; if (!IS_LAST_TILE || (item_idx < num_tile_items)) { @@ -469,7 +465,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void Scatter( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, ///< Number of valid items in this tile @@ -515,7 +511,7 @@ struct AgentSelectIf OffsetT tile_offset, ///< Tile offset ScanTileStateT& tile_state) ///< Global tile state descriptor { - OutputT items[ITEMS_PER_THREAD]; + InputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; @@ -575,7 +571,7 @@ struct AgentSelectIf OffsetT tile_offset, ///< Tile offset ScanTileStateT& tile_state) ///< Global tile state descriptor { - OutputT items[ITEMS_PER_THREAD]; + InputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index 94a17419a8..5654ba29a3 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -129,10 +129,8 @@ struct DispatchSelectIf * Types and constants ******************************************************************************/ - // The output value type - using OutputT = - cub::detail::non_void_value_t>; + // The input value type + using InputT = cub::detail::value_t; // The flag value type using FlagT = cub::detail::value_t; @@ -155,7 +153,7 @@ struct DispatchSelectIf { enum { NOMINAL_4B_ITEMS_PER_THREAD = 10, - ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))), + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(InputT)))), }; typedef AgentSelectIfPolicy< diff --git a/test/test_device_select_if.cu b/test/test_device_select_if.cu index c3cc1d8e2a..e1f04f1f84 100644 --- a/test/test_device_select_if.cu +++ b/test/test_device_select_if.cu @@ -41,6 +41,12 @@ #include #include +#include +#include +#include +#include +#include + #include "test_util.h" using namespace cub; @@ -652,6 +658,74 @@ void Test( } } +template +struct pair_to_col_t +{ + __host__ __device__ FT operator()(const thrust::tuple &in) + { + return thrust::get<0>(in); + } +}; + +template +struct select_t { + __host__ __device__ bool operator()(const thrust::tuple &in) { + return static_cast(thrust::get<0>(in)) > thrust::get<1>(in); + } +}; + +template +void TestMixedOp(int num_items) +{ + const FT target_value = static_cast(42); + thrust::device_vector col_a(num_items, target_value); + thrust::device_vector col_b(num_items, static_cast(4.2)); + + thrust::device_vector result(num_items); + + auto in = thrust::make_zip_iterator(col_a.begin(), col_b.begin()); + auto out = thrust::make_transform_output_iterator(result.begin(), pair_to_col_t{}); + + void *d_tmp_storage {}; + std::size_t tmp_storage_size{}; + cub::DeviceSelect::If( + d_tmp_storage, tmp_storage_size, + in, out, thrust::make_discard_iterator(), + num_items, select_t{}, + 0, true); + + thrust::device_vector tmp_storage(tmp_storage_size); + d_tmp_storage = thrust::raw_pointer_cast(tmp_storage.data()); + + cub::DeviceSelect::If( + d_tmp_storage, tmp_storage_size, + in, out, thrust::make_discard_iterator(), + num_items, select_t{}, + 0, true); + + AssertEquals(num_items, thrust::count(result.begin(), result.end(), target_value)); +} + +/** + * Test different input sizes + */ +template +void TestMixed(int num_items) +{ + if (num_items < 0) + { + TestMixedOp(0); + TestMixedOp(1); + TestMixedOp(100); + TestMixedOp(10000); + TestMixedOp(1000000); + } + else + { + TestMixedOp(num_items); + } +} + //--------------------------------------------------------------------- // Main //--------------------------------------------------------------------- @@ -708,6 +782,8 @@ int main(int argc, char** argv) Test(num_items); Test(num_items); + TestMixed(num_items); + return 0; }