Skip to content

Commit

Permalink
Use thread_index_type to avoid out of bounds accesses in conditional …
Browse files Browse the repository at this point in the history
…joins (#13971)

See #10368 (and more recently #13771

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Yunsong Wang (https://github.com/PointKernel)
  - David Wendt (https://github.com/davidwendt)

URL: #13971
  • Loading branch information
vyasr authored Sep 7, 2023
1 parent c9d8821 commit b4da39c
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions cpp/src/join/conditional_join_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,25 @@ __global__ void compute_conditional_join_output_size(
&intermediate_storage[threadIdx.x * device_expression_data.num_intermediates];

std::size_t thread_counter{0};
cudf::size_type const start_idx = threadIdx.x + blockIdx.x * block_size;
cudf::size_type const stride = block_size * gridDim.x;
cudf::size_type const left_num_rows = left_table.num_rows();
cudf::size_type const right_num_rows = right_table.num_rows();
auto const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows);
auto const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows);
auto const start_idx = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();

cudf::thread_index_type const left_num_rows = left_table.num_rows();
cudf::thread_index_type const right_num_rows = right_table.num_rows();
auto const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows);
auto const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows);

auto evaluator = cudf::ast::detail::expression_evaluator<has_nulls>(
left_table, right_table, device_expression_data);

for (cudf::size_type outer_row_index = start_idx; outer_row_index < outer_num_rows;
for (cudf::thread_index_type outer_row_index = start_idx; outer_row_index < outer_num_rows;
outer_row_index += stride) {
bool found_match = false;
for (cudf::size_type inner_row_index = 0; inner_row_index < inner_num_rows; inner_row_index++) {
auto output_dest = cudf::ast::detail::value_expression_result<bool, has_nulls>();
auto const left_row_index = swap_tables ? inner_row_index : outer_row_index;
auto const right_row_index = swap_tables ? outer_row_index : inner_row_index;
for (cudf::thread_index_type inner_row_index = 0; inner_row_index < inner_num_rows;
++inner_row_index) {
auto output_dest = cudf::ast::detail::value_expression_result<bool, has_nulls>();
cudf::size_type const left_row_index = swap_tables ? inner_row_index : outer_row_index;
cudf::size_type const right_row_index = swap_tables ? outer_row_index : inner_row_index;
evaluator.evaluate(
output_dest, left_row_index, right_row_index, 0, thread_intermediate_storage);
if (output_dest.is_valid() && output_dest.value()) {
Expand Down Expand Up @@ -161,18 +163,18 @@ __global__ void conditional_join(table_device_view left_table,
auto thread_intermediate_storage =
&intermediate_storage[threadIdx.x * device_expression_data.num_intermediates];

int const warp_id = threadIdx.x / detail::warp_size;
int const lane_id = threadIdx.x % detail::warp_size;
cudf::size_type const left_num_rows = left_table.num_rows();
cudf::size_type const right_num_rows = right_table.num_rows();
auto const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows);
auto const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows);
int const warp_id = threadIdx.x / detail::warp_size;
int const lane_id = threadIdx.x % detail::warp_size;
cudf::thread_index_type const left_num_rows = left_table.num_rows();
cudf::thread_index_type const right_num_rows = right_table.num_rows();
cudf::thread_index_type const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows);
cudf::thread_index_type const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows);

if (0 == lane_id) { current_idx_shared[warp_id] = 0; }

__syncwarp();

cudf::size_type outer_row_index = threadIdx.x + blockIdx.x * block_size;
auto outer_row_index = cudf::detail::grid_1d::global_thread_id();

unsigned int const activemask = __ballot_sync(0xffff'ffffu, outer_row_index < outer_num_rows);

Expand All @@ -181,7 +183,8 @@ __global__ void conditional_join(table_device_view left_table,

if (outer_row_index < outer_num_rows) {
bool found_match = false;
for (size_type inner_row_index(0); inner_row_index < inner_num_rows; ++inner_row_index) {
for (thread_index_type inner_row_index(0); inner_row_index < inner_num_rows;
++inner_row_index) {
auto output_dest = cudf::ast::detail::value_expression_result<bool, has_nulls>();
auto const left_row_index = swap_tables ? inner_row_index : outer_row_index;
auto const right_row_index = swap_tables ? outer_row_index : inner_row_index;
Expand Down

0 comments on commit b4da39c

Please sign in to comment.