Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Update the scan implementation to follow P0571's guidance.
Browse files Browse the repository at this point in the history
  • Loading branch information
alliepiper committed Sep 25, 2020
1 parent 4bf55ed commit 2fba463
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 28 deletions.
11 changes: 6 additions & 5 deletions cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ struct AgentScan
//---------------------------------------------------------------------

// The input value type
typedef typename std::iterator_traits<InputIteratorT>::value_type InputT;
using InputT = typename std::iterator_traits<InputIteratorT>::value_type;

// The output value type
typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type
// The output value type -- used as the intermediate accumulator
// Per https://wg21.link/P0571, use InitValueT if provided, otherwise the
// input iterator's value type.
using OutputT =
typename If<Equals<InitValueT, NullType>::VALUE, InputT, InitValueT>::Type;

// Tile status descriptor interface type
typedef ScanTileState<OutputT> ScanTileStateT;
Expand Down
7 changes: 3 additions & 4 deletions cub/device/device_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,9 @@ struct DeviceScan
// Signed integer type for global offsets
typedef int OffsetT;

// The output value type
typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type
// The output value type -- used as the intermediate accumulator
// Use the input value type per https://wg21.link/P0571
typedef typename std::iterator_traits<InputIteratorT>::value_type OutputT;

// Initial value
OutputT init_value = 0;
Expand Down
18 changes: 11 additions & 7 deletions cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,10 @@ template <
typename InitValueT, ///< The init_value element type for ScanOpT (cub::NullType for inclusive scans)
typename OffsetT, ///< Signed integer type for global offsets
typename SelectedPolicy = DeviceScanPolicy<
typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type> >
// Accumulator type.
typename If<Equals<InitValueT, NullType>::VALUE,
typename std::iterator_traits<InputIteratorT>::value_type,
InitValueT>::Type>>
struct DispatchScan:
SelectedPolicy
{
Expand All @@ -269,11 +270,14 @@ struct DispatchScan:
INIT_KERNEL_THREADS = 128
};

// The output value type
typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type
// The input value type
using InputT = typename std::iterator_traits<InputIteratorT>::value_type;

// The output value type -- used as the intermediate accumulator
// Per https://wg21.link/P0571, use InitValueT if provided, otherwise the
// input iterator's value type.
using OutputT =
typename If<Equals<InitValueT, NullType>::VALUE, InputT, InitValueT>::Type;

void* d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t& temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
Expand Down
47 changes: 35 additions & 12 deletions test/test_device_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -544,24 +544,28 @@ void Initialize(
template <
typename InputIteratorT,
typename OutputT,
typename ScanOpT>
typename ScanOpT,
typename InitialValueT>
void Solve(
InputIteratorT h_in,
OutputT *h_reference,
int num_items,
ScanOpT scan_op,
OutputT initial_value)
InitialValueT initial_value)
{
// Use the initial value type for accumulation per P0571
using AccumT = InitialValueT;

if (num_items > 0)
{
OutputT val = h_in[0];
h_reference[0] = initial_value;
OutputT inclusive = scan_op(initial_value, val);
AccumT val = static_cast<AccumT>(h_in[0]);
h_reference[0] = initial_value;
AccumT inclusive = scan_op(initial_value, val);

for (int i = 1; i < num_items; ++i)
{
val = h_in[i];
h_reference[i] = inclusive;
val = static_cast<AccumT>(h_in[i]);
h_reference[i] = static_cast<OutputT>(inclusive);
inclusive = scan_op(inclusive, val);
}
}
Expand All @@ -582,16 +586,20 @@ void Solve(
ScanOpT scan_op,
NullType)
{
// When no initial value type is supplied, use InputT for accumulation
// per P0571
using AccumT = typename std::iterator_traits<InputIteratorT>::value_type;

if (num_items > 0)
{
OutputT inclusive = h_in[0];
h_reference[0] = inclusive;
AccumT inclusive = h_in[0];
h_reference[0] = static_cast<OutputT>(inclusive);

for (int i = 1; i < num_items; ++i)
{
OutputT val = h_in[i];
AccumT val = h_in[i];
inclusive = scan_op(inclusive, val);
h_reference[i] = inclusive;
h_reference[i] = static_cast<OutputT>(inclusive);
}
}
}
Expand Down Expand Up @@ -746,7 +754,22 @@ void TestPointer(

// Initialize problem and solution
Initialize(gen_mode, h_in, num_items);
Solve(h_in, h_reference, num_items, scan_op, initial_value);

// If the output type is primitive and the operator is cub::Sum, the test
// dispatcher throws away scan_op and initial_value for exclusive scan.
// Without an initial_value arg, the accumulator switches to the input value
// type.
// Do the same thing here:
if (Traits<OutputT>::PRIMITIVE &&
Equals<ScanOpT, cub::Sum>::VALUE &&
!Equals<InitialValueT, NullType>::VALUE)
{
Solve(h_in, h_reference, num_items, cub::Sum{}, InputT{});
}
else
{
Solve(h_in, h_reference, num_items, scan_op, initial_value);
}

// Allocate problem device arrays
InputT *d_in = NULL;
Expand Down

0 comments on commit 2fba463

Please sign in to comment.