Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dead code in bincount #546

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 0 additions & 78 deletions src/cunumeric/stat/bincount.cu
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,6 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
}
}

template <typename VAL>
static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
bincount_kernel_rw(AccessorRW<int64_t, 1> lhs,
AccessorRO<VAL, 1> rhs,
const size_t volume,
const size_t num_bins,
Point<1> origin)
{
extern __shared__ char array[];
auto bins = reinterpret_cast<int32_t*>(array);
_bincount(bins, rhs, volume, num_bins, origin);
// Now do the atomics out to global memory
for (int32_t bin = threadIdx.x; bin < num_bins; bin += blockDim.x) {
const auto count = bins[bin];
if (count > 0) SumReduction<int64_t>::fold<false>(lhs[bin], count);
}
}

template <typename VAL>
static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
weighted_bincount_kernel_rd(AccessorRD<SumReduction<double>, false, 1> lhs,
Expand All @@ -135,25 +117,6 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
}
}

template <typename VAL>
static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
weighted_bincount_kernel_rw(AccessorRW<double, 1> lhs,
AccessorRO<VAL, 1> rhs,
AccessorRO<double, 1> weights,
const size_t volume,
const size_t num_bins,
Point<1> origin)
{
extern __shared__ char array[];
auto bins = reinterpret_cast<double*>(array);
_weighted_bincount(bins, rhs, weights, volume, num_bins, origin);
// Now do the atomics out to global memory
for (int32_t bin = threadIdx.x; bin < num_bins; bin += blockDim.x) {
const auto weight = bins[bin];
SumReduction<double>::fold<false>(lhs[bin], weight);
}
}

template <LegateTypeCode CODE>
struct BincountImplBody<VariantKind::GPU, CODE> {
using VAL = legate_type_of<CODE>;
Expand All @@ -178,26 +141,6 @@ struct BincountImplBody<VariantKind::GPU, CODE> {
CHECK_CUDA_STREAM(stream);
}

void operator()(const AccessorRW<int64_t, 1>& lhs,
const AccessorRO<VAL, 1>& rhs,
const Rect<1>& rect,
const Rect<1>& lhs_rect) const
{
const auto volume = rect.volume();
const auto num_bins = lhs_rect.volume();
const auto bin_size = num_bins * sizeof(int32_t);

int32_t num_ctas = 0;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_ctas, bincount_kernel_rw<VAL>, THREADS_PER_BLOCK, bin_size);
assert(num_ctas > 0);
// Launch a kernel with this number of CTAs
auto stream = get_cached_stream();
bincount_kernel_rw<VAL>
<<<num_ctas, THREADS_PER_BLOCK, bin_size, stream>>>(lhs, rhs, volume, num_bins, rect.lo);
CHECK_CUDA_STREAM(stream);
}

void operator()(AccessorRD<SumReduction<double>, false, 1> lhs,
const AccessorRO<VAL, 1>& rhs,
const AccessorRO<double, 1>& weights,
Expand All @@ -218,27 +161,6 @@ struct BincountImplBody<VariantKind::GPU, CODE> {
lhs, rhs, weights, volume, num_bins, rect.lo);
CHECK_CUDA_STREAM(stream);
}

void operator()(const AccessorRW<double, 1>& lhs,
const AccessorRO<VAL, 1>& rhs,
const AccessorRO<double, 1>& weights,
const Rect<1>& rect,
const Rect<1>& lhs_rect) const
{
const auto volume = rect.volume();
const auto num_bins = lhs_rect.volume();
const auto bin_size = num_bins * sizeof(double);

int32_t num_ctas = 0;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_ctas, weighted_bincount_kernel_rw<VAL>, THREADS_PER_BLOCK, bin_size);
assert(num_ctas > 0);
// Launch a kernel with this number of CTAs
auto stream = get_cached_stream();
weighted_bincount_kernel_rw<VAL><<<num_ctas, THREADS_PER_BLOCK, bin_size, stream>>>(
lhs, rhs, weights, volume, num_bins, rect.lo);
CHECK_CUDA_STREAM(stream);
}
};

/*static*/ void BincountTask::gpu_variant(TaskContext& context)
Expand Down
23 changes: 0 additions & 23 deletions src/cunumeric/stat/bincount_omp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,6 @@ struct BincountImplBody<VariantKind::OMP, CODE> {
lhs.reduce(bin_num, local_bins[bin_num]);
}

void operator()(const AccessorRW<int64_t, 1>& lhs,
const AccessorRO<VAL, 1>& rhs,
const Rect<1>& rect,
const Rect<1>& lhs_rect) const
{
auto all_local_bins = _bincount(rhs, rect, lhs_rect);
for (auto& local_bins : all_local_bins)
for (size_t bin_num = 0; bin_num < local_bins.size(); ++bin_num)
lhs[bin_num] += local_bins[bin_num];
}

void operator()(AccessorRD<SumReduction<double>, true, 1> lhs,
const AccessorRO<VAL, 1>& rhs,
const AccessorRO<double, 1>& weights,
Expand All @@ -112,18 +101,6 @@ struct BincountImplBody<VariantKind::OMP, CODE> {
for (size_t bin_num = 0; bin_num < local_bins.size(); ++bin_num)
lhs.reduce(bin_num, local_bins[bin_num]);
}

void operator()(const AccessorRW<double, 1>& lhs,
const AccessorRO<VAL, 1>& rhs,
const AccessorRO<double, 1>& weights,
const Rect<1>& rect,
const Rect<1>& lhs_rect) const
{
auto all_local_bins = _bincount(rhs, weights, rect, lhs_rect);
for (auto& local_bins : all_local_bins)
for (size_t bin_num = 0; bin_num < local_bins.size(); ++bin_num)
lhs[bin_num] += local_bins[bin_num];
}
};

/*static*/ void BincountTask::omp_variant(TaskContext& context)
Expand Down