-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* Unify the template for device reduction tree and do some cleanup * Fix performance bugs in scalar reduction kernels: * Use unsigned 64-bit integers instead of signed integers wherever possible; CUDA hasn't added an atomic intrinsic for the latter yet. * Move reduction buffers from zero-copy memory to framebuffer. This makes the slow atomic update code path in reduction operators run much more efficiently. * Use thew new scalar reduction buffer in binary reductions as well * Use only the RHS type in the reduction buffer as we never call apply * Minor clean up per review * Rename the buffer class and method to make the intent explicit * Flip the polarity of reduce's template parameter Co-authored-by: Wonchan Lee <wonchanl@nvidia.com>
- Loading branch information
1 parent
ffb37b4
commit 2959f0a
Showing
8 changed files
with
103 additions
and
262 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* Copyright 2022 NVIDIA Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "core/cuda/cuda_help.h" | ||
#include "core/data/buffer.h" | ||
|
||
namespace cunumeric { | ||
|
||
template <typename REDOP> | ||
class DeviceScalarReductionBuffer { | ||
private: | ||
using VAL = typename REDOP::RHS; | ||
|
||
public: | ||
DeviceScalarReductionBuffer(cudaStream_t stream) | ||
: buffer_(legate::create_buffer<VAL>(1, Legion::Memory::Kind::GPU_FB_MEM)) | ||
{ | ||
VAL identity{REDOP::identity}; | ||
ptr_ = buffer_.ptr(0); | ||
CHECK_CUDA(cudaMemcpyAsync(ptr_, &identity, sizeof(VAL), cudaMemcpyHostToDevice, stream)); | ||
} | ||
|
||
template <bool EXCLUSIVE> | ||
__device__ void reduce(const VAL& value) const | ||
{ | ||
REDOP::template fold<EXCLUSIVE /*exclusive*/>(*ptr_, value); | ||
} | ||
|
||
__host__ VAL read(cudaStream_t stream) const | ||
{ | ||
VAL result{REDOP::identity}; | ||
CHECK_CUDA(cudaMemcpyAsync(&result, ptr_, sizeof(VAL), cudaMemcpyDeviceToHost, stream)); | ||
CHECK_CUDA(cudaStreamSynchronize(stream)); | ||
return result; | ||
} | ||
|
||
__device__ VAL read() const { return *ptr_; } | ||
|
||
private: | ||
legate::Buffer<VAL> buffer_; | ||
VAL* ptr_; | ||
}; | ||
|
||
} // namespace cunumeric |
Oops, something went wrong.