diff --git a/cpp/src/join/conditional_join_kernels.cuh b/cpp/src/join/conditional_join_kernels.cuh index dc455ad9cef..f665aba698f 100644 --- a/cpp/src/join/conditional_join_kernels.cuh +++ b/cpp/src/join/conditional_join_kernels.cuh @@ -67,23 +67,25 @@ __global__ void compute_conditional_join_output_size( &intermediate_storage[threadIdx.x * device_expression_data.num_intermediates]; std::size_t thread_counter{0}; - cudf::size_type const start_idx = threadIdx.x + blockIdx.x * block_size; - cudf::size_type const stride = block_size * gridDim.x; - cudf::size_type const left_num_rows = left_table.num_rows(); - cudf::size_type const right_num_rows = right_table.num_rows(); - auto const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows); - auto const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows); + auto const start_idx = cudf::detail::grid_1d::global_thread_id(); + auto const stride = cudf::detail::grid_1d::grid_stride(); + + cudf::thread_index_type const left_num_rows = left_table.num_rows(); + cudf::thread_index_type const right_num_rows = right_table.num_rows(); + auto const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows); + auto const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows); auto evaluator = cudf::ast::detail::expression_evaluator( left_table, right_table, device_expression_data); - for (cudf::size_type outer_row_index = start_idx; outer_row_index < outer_num_rows; + for (cudf::thread_index_type outer_row_index = start_idx; outer_row_index < outer_num_rows; outer_row_index += stride) { bool found_match = false; - for (cudf::size_type inner_row_index = 0; inner_row_index < inner_num_rows; inner_row_index++) { - auto output_dest = cudf::ast::detail::value_expression_result(); - auto const left_row_index = swap_tables ? inner_row_index : outer_row_index; - auto const right_row_index = swap_tables ? outer_row_index : inner_row_index; + for (cudf::thread_index_type inner_row_index = 0; inner_row_index < inner_num_rows; + ++inner_row_index) { + auto output_dest = cudf::ast::detail::value_expression_result(); + cudf::size_type const left_row_index = swap_tables ? inner_row_index : outer_row_index; + cudf::size_type const right_row_index = swap_tables ? outer_row_index : inner_row_index; evaluator.evaluate( output_dest, left_row_index, right_row_index, 0, thread_intermediate_storage); if (output_dest.is_valid() && output_dest.value()) { @@ -161,18 +163,18 @@ __global__ void conditional_join(table_device_view left_table, auto thread_intermediate_storage = &intermediate_storage[threadIdx.x * device_expression_data.num_intermediates]; - int const warp_id = threadIdx.x / detail::warp_size; - int const lane_id = threadIdx.x % detail::warp_size; - cudf::size_type const left_num_rows = left_table.num_rows(); - cudf::size_type const right_num_rows = right_table.num_rows(); - auto const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows); - auto const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows); + int const warp_id = threadIdx.x / detail::warp_size; + int const lane_id = threadIdx.x % detail::warp_size; + cudf::thread_index_type const left_num_rows = left_table.num_rows(); + cudf::thread_index_type const right_num_rows = right_table.num_rows(); + cudf::thread_index_type const outer_num_rows = (swap_tables ? right_num_rows : left_num_rows); + cudf::thread_index_type const inner_num_rows = (swap_tables ? left_num_rows : right_num_rows); if (0 == lane_id) { current_idx_shared[warp_id] = 0; } __syncwarp(); - cudf::size_type outer_row_index = threadIdx.x + blockIdx.x * block_size; + auto outer_row_index = cudf::detail::grid_1d::global_thread_id(); unsigned int const activemask = __ballot_sync(0xffff'ffffu, outer_row_index < outer_num_rows); @@ -181,7 +183,8 @@ __global__ void conditional_join(table_device_view left_table, if (outer_row_index < outer_num_rows) { bool found_match = false; - for (size_type inner_row_index(0); inner_row_index < inner_num_rows; ++inner_row_index) { + for (thread_index_type inner_row_index(0); inner_row_index < inner_num_rows; + ++inner_row_index) { auto output_dest = cudf::ast::detail::value_expression_result(); auto const left_row_index = swap_tables ? inner_row_index : outer_row_index; auto const right_row_index = swap_tables ? outer_row_index : inner_row_index;