diff --git a/src/cunumeric/binary/binary_red.cu b/src/cunumeric/binary/binary_red.cu
index ca4cb60f4..f5c11b83d 100644
--- a/src/cunumeric/binary/binary_red.cu
+++ b/src/cunumeric/binary/binary_red.cu
@@ -29,7 +29,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
 {
   const size_t idx = global_tid_1d();
   if (idx >= volume) return;
-  if (!func(in1[idx], in2[idx])) out <<= false;
+  if (!func(in1[idx], in2[idx])) out.reduce(false);
 }
 
 template
@@ -39,7 +39,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) gen
 {
   const size_t idx = global_tid_1d();
   if (idx >= volume) return;
   auto point = pitches.unflatten(idx, rect.lo);
-  if (!func(in1[point], in2[point])) out <<= false;
+  if (!func(in1[point], in2[point])) out.reduce(false);
 }
 
 template
@@ -64,8 +64,8 @@ struct BinaryRedImplBody {
   {
     size_t volume = rect.volume();
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
-    DeferredReduction> result;
-    auto stream = get_cached_stream();
+    auto stream = get_cached_stream();
+    DeviceScalarReductionBuffer> result(stream);
     if (dense) {
       auto in1ptr = in1.ptr(rect);
       auto in2ptr = in2.ptr(rect);
diff --git a/src/cunumeric/cuda_help.h b/src/cunumeric/cuda_help.h
index 0b36c419f..63bd6d4e1 100644
--- a/src/cunumeric/cuda_help.h
+++ b/src/cunumeric/cuda_help.h
@@ -21,6 +21,7 @@
 #include "core/cuda/stream_pool.h"
 #include "cunumeric/arg.h"
 #include "cunumeric/arg.inl"
+#include "cunumeric/device_scalar_reduction_buffer.h"
 #include
 #include
 #include
@@ -211,71 +212,35 @@ __device__ __forceinline__ T shuffle(unsigned mask, T var, int laneMask, int wid
   return var;
 }
 
-// Overload for complex
-// TBD: if compiler optimizes out the shuffle function we defined, we could make it the default
-// version
-template
-__device__ __forceinline__ void reduce_output(Legion::DeferredReduction result,
-                                              complex value)
-{
-  __shared__ complex trampoline[THREADS_PER_BLOCK / 32];
-  // Reduce across the warp
-  const int laneid = threadIdx.x & 0x1f;
-  const int warpid = threadIdx.x >> 5;
-  for (int i = 16; i >= 1; i /= 2) {
-    const complex shuffle_value = shuffle(0xffffffff, value, i, 32);
-    REDUCTION::template fold(value, shuffle_value);
-  }
-  // Write warp values into shared memory
-  if ((laneid == 0) && (warpid > 0)) trampoline[warpid] = value;
-  __syncthreads();
-  // Output reduction
-  if (threadIdx.x == 0) {
-    for (int i = 1; i < (THREADS_PER_BLOCK / 32); i++)
-      REDUCTION::template fold(value, trampoline[i]);
-    result <<= value;
-    // Make sure the result is visible externally
-    __threadfence_system();
-  }
-}
+template
+struct HasNativeShuffle {
+  static constexpr bool value = true;
+};
 
-// Overload for argval
-// TBD: if compiler optimizes out the shuffle function we defined, we could make it the default
-// version
-template
-__device__ __forceinline__ void reduce_output(Legion::DeferredReduction result,
-                                              Argval value)
-{
-  __shared__ Argval trampoline[THREADS_PER_BLOCK / 32];
-  // Reduce across the warp
-  const int laneid = threadIdx.x & 0x1f;
-  const int warpid = threadIdx.x >> 5;
-  for (int i = 16; i >= 1; i /= 2) {
-    const Argval shuffle_value = shuffle(0xffffffff, value, i, 32);
-    REDUCTION::template fold(value, shuffle_value);
-  }
-  // Write warp values into shared memory
-  if ((laneid == 0) && (warpid > 0)) trampoline[warpid] = value;
-  __syncthreads();
-  // Output reduction
-  if (threadIdx.x == 0) {
-    for (int i = 1; i < (THREADS_PER_BLOCK / 32); i++)
-      REDUCTION::template fold(value, trampoline[i]);
-    result <<= value;
-    // Make sure the result is visible externally
-    __threadfence_system();
-  }
-}
+template
+struct HasNativeShuffle> {
+  static constexpr bool value = false;
+};
+
+template
+struct HasNativeShuffle> {
+  static constexpr bool value = false;
+};
 
 template
-__device__ __forceinline__ void reduce_output(Legion::DeferredReduction result, T value)
+__device__ __forceinline__ void reduce_output(DeviceScalarReductionBuffer result,
+                                              T value)
 {
   __shared__ T trampoline[THREADS_PER_BLOCK / 32];
   // Reduce across the warp
   const int laneid = threadIdx.x & 0x1f;
   const int warpid = threadIdx.x >> 5;
   for (int i = 16; i >= 1; i /= 2) {
-    const T shuffle_value = __shfl_xor_sync(0xffffffff, value, i, 32);
+    T shuffle_value;
+    if constexpr (HasNativeShuffle::value)
+      shuffle_value = __shfl_xor_sync(0xffffffff, value, i, 32);
+    else
+      shuffle_value = shuffle(0xffffffff, value, i, 32);
     REDUCTION::template fold(value, shuffle_value);
   }
   // Write warp values into shared memory
@@ -285,190 +250,12 @@ __device__ __forceinline__ void reduce_output(Legion::DeferredReduction(value, trampoline[i]);
-    result <<= value;
+    result.reduce(value);
     // Make sure the result is visible externally
     __threadfence_system();
   }
 }
 
-__device__ __forceinline__ void reduce_bool(Legion::DeferredValue result, int value)
-{
-  __shared__ int trampoline[THREADS_PER_BLOCK / 32];
-  // Reduce across the warp
-  const int laneid = threadIdx.x & 0x1f;
-  const int warpid = threadIdx.x >> 5;
-  for (int i = 16; i >= 1; i /= 2) {
-    const int shuffle_value = __shfl_xor_sync(0xffffffff, value, i, 32);
-    if (shuffle_value == 0) value = 0;
-  }
-  // Write warp values into shared memory
-  if ((laneid == 0) && (warpid > 0)) trampoline[warpid] = value;
-  __syncthreads();
-  // Output reduction
-  if (threadIdx.x == 0) {
-    for (int i = 1; i < (THREADS_PER_BLOCK / 32); i++)
-      if (trampoline[i] == 0) {
-        value = 0;
-        break;
-      }
-    if (value == 0) {
-      result = false;
-      // Make sure the result is visible externally
-      __threadfence_system();
-    }
-  }
-}
-
-template
-__device__ __forceinline__ T load_cached(const T* ptr)
-{
-  return *ptr;
-}
-
-// Specializations to use PTX cache qualifiers to keep
-// all the input data in as many caches as we can
-// Use .ca qualifier to cache at all levels
-template <>
-__device__ __forceinline__ uint16_t load_cached(const uint16_t* ptr)
-{
-  uint16_t value;
-  asm volatile("ld.global.ca.u16 %0, [%1];" : "=h"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ uint32_t load_cached(const uint32_t* ptr)
-{
-  uint32_t value;
-  asm volatile("ld.global.ca.u32 %0, [%1];" : "=r"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ uint64_t load_cached(const uint64_t* ptr)
-{
-  uint64_t value;
-  asm volatile("ld.global.ca.u64 %0, [%1];" : "=l"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ int16_t load_cached(const int16_t* ptr)
-{
-  int16_t value;
-  asm volatile("ld.global.ca.s16 %0, [%1];" : "=h"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ int32_t load_cached(const int32_t* ptr)
-{
-  int32_t value;
-  asm volatile("ld.global.ca.s32 %0, [%1];" : "=r"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ int64_t load_cached(const int64_t* ptr)
-{
-  int64_t value;
-  asm volatile("ld.global.ca.s64 %0, [%1];" : "=l"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-// No half because inline ptx is dumb about the type
-
-template <>
-__device__ __forceinline__ float load_cached(const float* ptr)
-{
-  float value;
-  asm volatile("ld.global.ca.f32 %0, [%1];" : "=f"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ double load_cached(const double* ptr)
-{
-  double value;
-  asm volatile("ld.global.ca.f64 %0, [%1];" : "=d"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template
-__device__ __forceinline__ T load_l2(const T* ptr)
-{
-  return *ptr;
-}
-
-// Specializations to use PTX cache qualifiers to keep
-// data loaded into L2 but no higher in the hierarchy
-// Use .cg qualifier to cache at L2
-template <>
-__device__ __forceinline__ uint16_t load_l2(const uint16_t* ptr)
-{
-  uint16_t value;
-  asm volatile("ld.global.cg.u16 %0, [%1];" : "=h"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ uint32_t load_l2(const uint32_t* ptr)
-{
-  uint32_t value;
-  asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ uint64_t load_l2(const uint64_t* ptr)
-{
-  uint64_t value;
-  asm volatile("ld.global.cg.u64 %0, [%1];" : "=l"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ int16_t load_l2(const int16_t* ptr)
-{
-  int16_t value;
-  asm volatile("ld.global.cg.s16 %0, [%1];" : "=h"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ int32_t load_l2(const int32_t* ptr)
-{
-  int32_t value;
-  asm volatile("ld.global.cg.s32 %0, [%1];" : "=r"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ int64_t load_l2(const int64_t* ptr)
-{
-  int64_t value;
-  asm volatile("ld.global.cg.s64 %0, [%1];" : "=l"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-// No half because inline ptx is dumb about the type
-
-template <>
-__device__ __forceinline__ float load_l2(const float* ptr)
-{
-  float value;
-  asm volatile("ld.global.cg.f32 %0, [%1];" : "=f"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
-template <>
-__device__ __forceinline__ double load_l2(const double* ptr)
-{
-  double value;
-  asm volatile("ld.global.cg.f64 %0, [%1];" : "=d"(value) : "l"(ptr) : "memory");
-  return value;
-}
-
 template
 __device__ __forceinline__ T load_streaming(const T* ptr)
 {
diff --git a/src/cunumeric/device_scalar_reduction_buffer.h b/src/cunumeric/device_scalar_reduction_buffer.h
new file mode 100644
index 000000000..5e772649f
--- /dev/null
+++ b/src/cunumeric/device_scalar_reduction_buffer.h
@@ -0,0 +1,59 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include "core/cuda/cuda_help.h"
+#include "core/data/buffer.h"
+
+namespace cunumeric {
+
+template
+class DeviceScalarReductionBuffer {
+ private:
+  using VAL = typename REDOP::RHS;
+
+ public:
+  DeviceScalarReductionBuffer(cudaStream_t stream)
+    : buffer_(legate::create_buffer(1, Legion::Memory::Kind::GPU_FB_MEM))
+  {
+    VAL identity{REDOP::identity};
+    ptr_ = buffer_.ptr(0);
+    CHECK_CUDA(cudaMemcpyAsync(ptr_, &identity, sizeof(VAL), cudaMemcpyHostToDevice, stream));
+  }
+
+  template
+  __device__ void reduce(const VAL& value) const
+  {
+    REDOP::template fold(*ptr_, value);
+  }
+
+  __host__ VAL read(cudaStream_t stream) const
+  {
+    VAL result{REDOP::identity};
+    CHECK_CUDA(cudaMemcpyAsync(&result, ptr_, sizeof(VAL), cudaMemcpyDeviceToHost, stream));
+    CHECK_CUDA(cudaStreamSynchronize(stream));
+    return result;
+  }
+
+  __device__ VAL read() const { return *ptr_; }
+
+ private:
+  legate::Buffer buffer_;
+  VAL* ptr_;
+};
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/index/advanced_indexing.cu b/src/cunumeric/index/advanced_indexing.cu
index a5217b212..2470931b8 100644
--- a/src/cunumeric/index/advanced_indexing.cu
+++ b/src/cunumeric/index/advanced_indexing.cu
@@ -37,14 +37,14 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   const size_t skip_size,
   const size_t key_dim)
 {
-  size_t value = 0;
+  uint64_t value = 0;
   for (size_t i = 0; i < iters; i++) {
     size_t idx = (i * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
     if (idx > volume) break;
     auto point = pitches.unflatten(idx, origin);
     bool val = (index[point] && ((idx + 1) % skip_size == 0));
     offsets[idx] = static_cast(val);
-    SumReduction::fold(value, val);
+    SumReduction::fold(value, val);
   }
   // Every thread in the thread block must participate in the exchange to get correct results
   reduce_output(out, value);
@@ -90,7 +90,7 @@ struct AdvancedIndexingImplBody {
     const size_t skip_size,
     const size_t key_dim) const
   {
-    DeferredReduction> size;
+    DeviceScalarReductionBuffer> size(stream);
 
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
 
@@ -104,12 +104,12 @@ struct AdvancedIndexingImplBody {
       count_nonzero_kernel<<>>(
         volume, size, offsets, in, pitches, rect.lo, 1, skip_size, key_dim);
 
-    cudaStreamSynchronize(stream);
+    CHECK_CUDA_STREAM(stream);
 
     auto off_ptr = offsets.ptr(0);
     thrust::exclusive_scan(thrust::cuda::par.on(stream), off_ptr, off_ptr + volume, off_ptr);
 
-    return size.read();
+    return size.read(stream);
   }
 
   void operator()(Array& out_arr,
diff --git a/src/cunumeric/index/repeat.cu b/src/cunumeric/index/repeat.cu
index 09d6c7197..30f0c2aff 100644
--- a/src/cunumeric/index/repeat.cu
+++ b/src/cunumeric/index/repeat.cu
@@ -35,7 +35,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   const size_t iters,
   Buffer offsets)
 {
-  int64_t value = 0;
+  uint64_t value = 0;
   for (size_t idx = 0; idx < iters; idx++) {
     const int64_t offset = (idx * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
     if (offset < extent) {
@@ -43,7 +43,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
       p[axis] += offset;
      auto val = repeats[p];
       offsets[offset] = val;
-      SumReduction::fold(value, val);
+      SumReduction::fold(value, val);
     }
   }
   // Every thread in the thread block must participate in the exchange to get correct results
@@ -137,7 +137,7 @@ struct RepeatImplBody {
     int64_t extent = in_rect.hi[axis] - in_rect.lo[axis] + 1;
     auto offsets = create_buffer(Point<1>(extent), Memory::Kind::Z_COPY_MEM);
 
-    DeferredReduction> sum;
+    DeviceScalarReductionBuffer> sum(stream);
 
     const size_t blocks_count = (extent + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
     const size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(int64_t);
@@ -151,10 +151,8 @@ struct RepeatImplBody {
     }
     CHECK_CUDA_STREAM(stream);
 
-    cudaStreamSynchronize(stream);
-
     Point out_extents = in_rect.hi - in_rect.lo + Point::ONES();
-    out_extents[axis] = sum.read();
+    out_extents[axis] = static_cast(sum.read(stream));
 
     auto out = out_array.create_output_buffer(out_extents, true);
diff --git a/src/cunumeric/matrix/dot.cu b/src/cunumeric/matrix/dot.cu
index 5a44bc410..3d11e19c3 100644
--- a/src/cunumeric/matrix/dot.cu
+++ b/src/cunumeric/matrix/dot.cu
@@ -61,7 +61,7 @@ struct DotImplBody {
     const auto volume = rect.volume();
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
 
-    DeferredReduction> result;
+    DeviceScalarReductionBuffer> result(stream);
     size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(ACC);
 
     if (blocks >= MAX_REDUCTION_CTAS) {
diff --git a/src/cunumeric/search/nonzero.cu b/src/cunumeric/search/nonzero.cu
index 1180e1fb5..081865b3c 100644
--- a/src/cunumeric/search/nonzero.cu
+++ b/src/cunumeric/search/nonzero.cu
@@ -36,14 +36,14 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   size_t iters,
   Buffer offsets)
 {
-  int64_t value = 0;
+  uint64_t value = 0;
   for (size_t idx = 0; idx < iters; idx++) {
     const size_t offset = (idx * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
     if (offset < volume) {
       auto point = pitches.unflatten(offset, origin);
-      auto val = static_cast(in[point] != VAL(0));
+      auto val = static_cast(in[point] != VAL(0));
       offsets[offset] = val;
-      SumReduction::fold(value, val);
+      SumReduction::fold(value, val);
     }
   }
   // Every thread in the thread block must participate in the exchange to get correct results
@@ -85,7 +85,7 @@ struct NonzeroImplBody {
     Buffer& offsets,
     cudaStream_t stream)
   {
-    DeferredReduction> size;
+    DeviceScalarReductionBuffer> size(stream);
 
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
     size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(int64_t);
@@ -98,14 +98,12 @@ struct NonzeroImplBody {
       count_nonzero_kernel<<>>(
         volume, size, in, pitches, rect.lo, 1, offsets);
 
-    cudaStreamSynchronize(stream);
-
     auto p_offsets = offsets.ptr(0);
 
     exclusive_sum(p_offsets, volume, stream);
     CHECK_CUDA_STREAM(stream);
 
-    return size.read();
+    return size.read(stream);
   }
 
   void populate_nonzeros(const AccessorRO& in,
@@ -135,7 +133,6 @@ struct NonzeroImplBody {
     auto offsets = create_buffer(volume, Memory::Kind::GPU_FB_MEM);
     auto size = compute_offsets(in, pitches, rect, volume, offsets, stream);
-    CHECK_CUDA_STREAM(stream);
 
     for (auto& result : results) result = create_buffer(size, Memory::Kind::GPU_FB_MEM);
diff --git a/src/cunumeric/unary/scalar_unary_red.cu b/src/cunumeric/unary/scalar_unary_red.cu
index 6f2059847..485879a47 100644
--- a/src/cunumeric/unary/scalar_unary_red.cu
+++ b/src/cunumeric/unary/scalar_unary_red.cu
@@ -127,7 +127,7 @@ struct ScalarUnaryRedImplBody {
     const size_t volume = rect.volume();
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
 
-    DeferredReduction result;
+    DeviceScalarReductionBuffer result(stream);
     size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(LHS);
 
     if (blocks >= MAX_REDUCTION_CTAS) {
@@ -156,7 +156,7 @@ struct ScalarUnaryRedImplBody {
     const size_t volume = rect.volume();
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
 
-    DeferredReduction result;
+    DeviceScalarReductionBuffer result(stream);
    size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(LHS);
 
     if (blocks >= MAX_REDUCTION_CTAS) {
@@ -190,7 +190,7 @@ struct ScalarUnaryRedImplBody();
     const size_t volume = rect.volume();
     const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
 
-    DeferredReduction> result;
+    DeviceScalarReductionBuffer> result(stream);
     size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(bool);
 
     if (blocks >= MAX_REDUCTION_CTAS) {
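
Illustrative sketch (not part of the diff above): a minimal kernel/host pair showing how the new DeviceScalarReductionBuffer is typically driven, following the same pattern as the kernels in this change. The names count_positive_kernel and count_positive are hypothetical, as is the choice of SumReduction<uint64_t>; THREADS_PER_BLOCK, MIN_CTAS_PER_SM, reduce_output() and CHECK_CUDA_STREAM() are assumed to come from cuda_help.h, which now also pulls in device_scalar_reduction_buffer.h.

// Hypothetical usage example; assumes the same using-declarations as the .cu files above
// so that legate/Legion reduction types (e.g. SumReduction) are in scope.
#include "cunumeric/cuda_help.h"

using namespace cunumeric;

static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
  count_positive_kernel(size_t volume,
                        DeviceScalarReductionBuffer<SumReduction<uint64_t>> out,
                        const double* in)
{
  uint64_t value = 0;
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < volume) value = in[idx] > 0.0 ? 1 : 0;
  // Every thread in the block must reach this call: it folds the per-thread values across
  // the warp and the block, then thread 0 folds the block total into the device buffer.
  reduce_output(out, value);
}

static uint64_t count_positive(const double* in, size_t volume, cudaStream_t stream)
{
  // The buffer is initialized to the reduction identity (0 for SumReduction) on `stream`.
  DeviceScalarReductionBuffer<SumReduction<uint64_t>> result(stream);
  const size_t blocks     = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  const size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(uint64_t);
  count_positive_kernel<<<blocks, THREADS_PER_BLOCK, shmem_size, stream>>>(volume, result, in);
  CHECK_CUDA_STREAM(stream);
  // read(stream) copies the scalar back asynchronously and synchronizes the stream,
  // replacing the DeferredReduction::read() + explicit cudaStreamSynchronize pattern.
  return result.read(stream);
}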