Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix CUDA version detection in CUB
Browse files Browse the repository at this point in the history
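This fixes the problem with CUB using the deprecated shfl/vote instructions when CUB
is compiled with clang (e.g. some TensorFlow builds). The version guards now also accept
`CUDA_VERSION >= 9000` in addition to `__CUDACC_VER_MAJOR__ >= 9`.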
  • Loading branch information
Artem-B authored and alliepiper committed Nov 5, 2020
1 parent f53dbc7 commit daaa127
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
7 changes: 4 additions & 3 deletions cub/util_arch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ namespace cub {

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

#if ((__CUDACC_VER_MAJOR__ >= 9) || defined(__NVCOMPILER_CUDA__)) && \
!defined(CUB_USE_COOPERATIVE_GROUPS)
#define CUB_USE_COOPERATIVE_GROUPS
#if ((__CUDACC_VER_MAJOR__ >= 9) || defined(__NVCOMPILER_CUDA__) || \
CUDA_VERSION >= 9000) && \
!defined(CUB_USE_COOPERATIVE_GROUPS)
#define CUB_USE_COOPERATIVE_GROUPS
#endif

/// In device code, CUB_PTX_ARCH expands to the PTX version for which we are
Expand Down
6 changes: 3 additions & 3 deletions cub/util_type.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include <limits>
#include <cfloat>

#if (__CUDACC_VER_MAJOR__ >= 9)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
#include <cuda_fp16.h>
#endif

Expand Down Expand Up @@ -1063,7 +1063,7 @@ struct FpLimits<double>
};


#if (__CUDACC_VER_MAJOR__ >= 9)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
template <>
struct FpLimits<__half>
{
Expand Down Expand Up @@ -1143,7 +1143,7 @@ template <> struct NumericTraits<unsigned long long> : BaseTraits<UNSIGNED_INTE

template <> struct NumericTraits<float> : BaseTraits<FLOATING_POINT, true, false, unsigned int, float> {};
template <> struct NumericTraits<double> : BaseTraits<FLOATING_POINT, true, false, unsigned long long, double> {};
#if (__CUDACC_VER_MAJOR__ >= 9)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
#endif

Expand Down
9 changes: 4 additions & 5 deletions test/test_device_radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include <algorithm>
#include <typeinfo>

#if (__CUDACC_VER_MAJOR__ >= 9)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
#include <cuda_fp16.h>
#endif

Expand Down Expand Up @@ -733,7 +733,7 @@ void Test(
ValueT *h_reference_values)
{
// Key alias type
#if (__CUDACC_VER_MAJOR__ >= 9)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
typedef typename If<Equals<KeyT, half_t>::VALUE, __half, KeyT>::Type KeyAliasT;
#else
typedef KeyT KeyAliasT;
Expand Down Expand Up @@ -1240,7 +1240,7 @@ int main(int argc, char** argv)

printf("\n-------------------------------\n");

#if (__CUDACC_VER_MAJOR__ >= 9)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
Test<CUB, half_t, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
#endif
Test<CUB, float, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
Expand Down Expand Up @@ -1299,7 +1299,7 @@ int main(int argc, char** argv)
TestGen<long long> (num_items, num_segments);
TestGen<unsigned long long> (num_items, num_segments);

#if (__CUDACC_VER_MAJOR__ >= 9)
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
TestGen<half_t> (num_items, num_segments);
#endif
TestGen<float> (num_items, num_segments);
Expand All @@ -1313,4 +1313,3 @@ int main(int argc, char** argv)

return 0;
}

0 comments on commit daaa127

Please sign in to comment.