diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 3941d776f75..00562f12633 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -581,7 +581,7 @@ class collect_list_aggregation final : public rolling_aggregation { /** * @brief Derived aggregation class for specifying COLLECT_SET aggregation */ -class collect_set_aggregation final : public aggregation { +class collect_set_aggregation final : public rolling_aggregation { public: explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE, null_equality nulls_equal = null_equality::EQUAL, diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 3a2215eaa53..a878dbe1535 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -468,6 +468,8 @@ std::unique_ptr make_collect_set_aggregation(null_policy null_handling, } template std::unique_ptr make_collect_set_aggregation( null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal); +template std::unique_ptr make_collect_set_aggregation( + null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal); /// Factory to create a LAG aggregation template diff --git a/cpp/src/rolling/rolling_collect_list.cuh b/cpp/src/rolling/rolling_collect_list.cuh index f5a2e59fd2a..0ffafe349b9 100644 --- a/cpp/src/rolling/rolling_collect_list.cuh +++ b/cpp/src/rolling/rolling_collect_list.cuh @@ -283,7 +283,7 @@ std::unique_ptr rolling_collect_list(column_view const& input, PrecedingIter preceding_begin_raw, FollowingIter following_begin_raw, size_type min_periods, - rolling_aggregation const& agg, + null_policy null_handling, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -321,7 +321,6 @@ std::unique_ptr rolling_collect_list(column_view const& input, // If gather_map collects null elements, 
and null_policy == EXCLUDE, // those elements must be filtered out, and offsets recomputed. - auto null_handling = dynamic_cast(agg)._null_handling; if (null_handling == null_policy::EXCLUDE && input.has_nulls()) { auto num_child_nulls = count_child_nulls(input, gather_map, stream); if (num_child_nulls != 0) { diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index 9e6d135b153..6f1776e40a3 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -581,6 +582,14 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre return {}; } + // COLLECT_SET aggregations do not perform a rolling operation at all. They get processed + // entirely in the finalize() step. + std::vector> visit( + data_type col_type, cudf::detail::collect_set_aggregation const& agg) override + { + return {}; + } + // LEAD and LAG have custom behaviors for non fixed-width types. std::vector> visit( data_type col_type, cudf::detail::lead_lag_aggregation const& agg) override @@ -678,11 +687,30 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation preceding_window_begin, following_window_begin, min_periods, - agg, + agg._null_handling, stream, mr); } + // perform the actual COLLECT_SET operation entirely: collect window lists, then drop duplicates. 
+ void visit(cudf::detail::collect_set_aggregation const& agg) override + { + auto const collected_list = rolling_collect_list(input, + default_outputs, + preceding_window_begin, + following_window_begin, + min_periods, + agg._null_handling, + stream, + rmm::mr::get_current_device_resource()); + + result = lists::detail::drop_list_duplicates(lists_column_view(collected_list->view()), + null_equality::EQUAL, + nan_equality::UNEQUAL, + stream, + mr); + } + std::unique_ptr get_result() { CUDF_EXPECTS(result != nullptr, diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index a3df5989c3b..23b92250549 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -285,7 +285,8 @@ ConfigureTest(ROLLING_TEST rolling/lead_lag_test.cpp rolling/range_window_bounds_test.cpp rolling/range_rolling_window_test.cpp - rolling/collect_list_test.cpp) + rolling/collect_ops_test.cpp + ) ################################################################################################### # - filling test ---------------------------------------------------------------------------------- diff --git a/cpp/tests/rolling/collect_list_test.cpp b/cpp/tests/rolling/collect_ops_test.cpp similarity index 64% rename from cpp/tests/rolling/collect_list_test.cpp rename to cpp/tests/rolling/collect_ops_test.cpp index 8322dd0eee9..f97e13b49f1 100644 --- a/cpp/tests/rolling/collect_list_test.cpp +++ b/cpp/tests/rolling/collect_ops_test.cpp @@ -1281,3 +1281,762 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } + +struct CollectSetTest : public cudf::test::BaseFixture { +}; + +template +struct TypedCollectSetTest : public CollectSetTest { +}; + +using TypesForSetTest = cudf::test::Concat; + +TYPED_TEST_CASE(TypedCollectSetTest, TypesForSetTest); + +TYPED_TEST(TypedCollectSetTest, BasicRollingWindow) +{ + using namespace cudf; + using namespace 
cudf::test; + + using T = TypeParam; + + auto const input_column = fixed_width_column_wrapper{10, 10, 11, 12, 11}; + + auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2}; + auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0}; + + EXPECT_EQ(static_cast(prev_column).size(), + static_cast(foll_column).size()); + + auto const result_column_based_window = + rolling_window(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()); + + auto const expected_result = + lists_column_wrapper{ + {10}, + {10, 11}, + {10, 11, 12}, + {11, 12}, + {11, 12}, + } + .release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); + + auto const result_fixed_window = + rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation()); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectSetTest, RollingWindowWithEmptyOutputLists) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const input_column = fixed_width_column_wrapper{10, 11, 11, 11, 14, 15}; + + auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 0, 2, 2}; + auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 0, 1, 0}; + + EXPECT_EQ(static_cast(prev_column).size(), + static_cast(foll_column).size()); + + auto const result_column_based_window = + rolling_window(input_column, + prev_column, + foll_column, + 0, + *make_collect_set_aggregation()); + + auto const expected_result = + lists_column_wrapper{ + {10, 11}, + {10, 11}, + {11}, + {}, + {11, 14, 15}, + {14, 15}, + } + .release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), 
result_column_based_window->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + prev_column, + foll_column, + 0, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectSetTest, RollingWindowHonoursMinPeriods) +{ + // Test that when the number of observations is fewer than min_periods, + // the result is null. + + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const input_column = fixed_width_column_wrapper{0, 1, 2, 2, 4, 5}; + auto const num_elements = static_cast(input_column).size(); + + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto const expected_result = lists_column_wrapper{ + {{}, {0, 1, 2}, {1, 2}, {2, 4}, {2, 4, 5}, {}}, + cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { + return i != 0 && i != (num_elements - 1); + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); + + preceding = 2; + following = 2; + min_periods = 4; + + auto result_2 = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + auto expected_result_2 = lists_column_wrapper{ + {{}, {0, 1, 2}, {1, 2, 4}, {2, 4, 5}, {}, {}}, + cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { + return i != 0 && i < 4; + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); + + auto 
result_2_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), + result_2_with_nulls_excluded->view()); +} + +TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsOnStrings) +{ + // Test that when the number of observations is fewer than min_periods, + // the result is null. + + using namespace cudf; + using namespace cudf::test; + + auto const input_column = strings_column_wrapper{"0", "1", "2", "2", "4", "4"}; + auto const num_elements = static_cast(input_column).size(); + + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto const expected_result = lists_column_wrapper{ + {{}, {"0", "1", "2"}, {"1", "2"}, {"2", "4"}, {"2", "4"}, {}}, + cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { + return i != 0 && i != (num_elements - 1); + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); + + preceding = 2; + following = 2; + min_periods = 4; + + auto result_2 = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + auto expected_result_2 = lists_column_wrapper{ + {{}, {"0", "1", "2"}, {"1", "2", "4"}, {"2", "4"}, {}, {}}, + cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { + return i != 0 && i < 4; + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); + + auto result_2_with_nulls_excluded = + 
rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), + result_2_with_nulls_excluded->view()); +} + +TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsWithDecimal) +{ + // Test that when the number of observations is fewer than min_periods, + // the result is null. + + using namespace cudf; + using namespace cudf::test; + + auto const input_column = + fixed_point_column_wrapper{{0, 0, 1, 2, 3, 3}, numeric::scale_type{0}}; + + { + // One result row at each end should be null. + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto expected_result_child_values = std::vector{0, 1, 0, 1, 2, 1, 2, 3, 2, 3}; + auto expected_result_child = + fixed_point_column_wrapper{expected_result_child_values.begin(), + expected_result_child_values.end(), + numeric::scale_type{0}}; + auto expected_offsets = fixed_width_column_wrapper{0, 0, 2, 5, 8, 10, 10}.release(); + auto expected_num_rows = expected_offsets->size() - 1; + auto null_mask_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [expected_num_rows](auto i) { return i != 0 && i != (expected_num_rows - 1); }); + + auto expected_result = make_lists_column( + expected_num_rows, + std::move(expected_offsets), + expected_result_child.release(), + 2, + cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), + result_with_nulls_excluded->view()); + } + + { + // First result row, 
and the last two result rows should be null. + auto preceding = 2; + auto following = 2; + auto min_periods = 4; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto expected_result_child_values = std::vector{0, 1, 2, 0, 1, 2, 3, 1, 2, 3}; + auto expected_result_child = + fixed_point_column_wrapper{expected_result_child_values.begin(), + expected_result_child_values.end(), + numeric::scale_type{0}}; + auto expected_offsets = fixed_width_column_wrapper{0, 0, 3, 7, 10, 10, 10}.release(); + auto expected_num_rows = expected_offsets->size() - 1; + auto null_mask_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [expected_num_rows](auto i) { return i > 0 && i < 4; }); + + auto expected_result = make_lists_column( + expected_num_rows, + std::move(expected_offsets), + expected_result_child.release(), + 3, + cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), + result_with_nulls_excluded->view()); + } +} + +TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindow) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = + fixed_width_column_wrapper{10, 11, 11, 13, 13, 20, 21, 20, 23}; + + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto const expected_result = + 
lists_column_wrapper{ + {10, 11}, {10, 11}, {11, 13}, {11, 13}, {13}, {20, 21}, {20, 21}, {20, 21, 23}, {20, 23}} + .release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = grouped_rolling_window( + table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = fixed_width_column_wrapper{ + {10, 11, 12, 13, 13, 20, 21, 21, 23}, {1, 0, 0, 1, 1, 1, 0, 1, 1}}; + + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + + { + // Nulls included. + auto const result = + grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + // Null values are sorted to the tails of lists (sets) + auto expected_child = fixed_width_column_wrapper{ + {10, 11, 10, 11, 13, 11, 13, 12, 13, 20, 21, 20, 21, 21, 21, 23, 21, 21, 23}, + {1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1}}; + auto expected_offsets = fixed_width_column_wrapper{0, 2, 4, 6, 8, 9, 11, 14, 17, 19}; + + auto expected_result = make_lists_column(static_cast(group_column).size(), + expected_offsets.release(), + expected_child.release(), + 0, + {}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } + + { + // Nulls excluded. 
+ auto const result = grouped_rolling_window( + table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + auto expected_child = + fixed_width_column_wrapper{10, 10, 13, 13, 13, 20, 20, 21, 21, 23, 21, 23}; + + auto expected_offsets = fixed_width_column_wrapper{0, 1, 2, 3, 4, 5, 6, 8, 10, 12}; + + auto expected_result = make_lists_column(static_cast(group_column).size(), + expected_offsets.release(), + expected_child.release(), + 0, + {}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } +} + +TYPED_TEST(TypedCollectSetTest, BasicGroupedTimeRangeRollingWindow) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = // values are all distinct, so each set equals its window + fixed_width_column_wrapper{10, 11, 12, 13, 14, 20, 21, 22, 23}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto const expected_result = lists_column_wrapper{ + {10, 11, 12, 13}, + {10, 11, 12, 13}, + {10, 11, 12, 13, 14}, + {10, 11, 12, 13, 14}, + {10, 11, 12, 13, 14}, + {20}, + {21, 22}, + {21, 22, 23}, + {21, 22, 23}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + 
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectSetTest, GroupedTimeRangeRollingWindowWithNulls) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = fixed_width_column_wrapper{ + {10, 10, 12, 13, 14, 20, 21, 22, 22}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto null_at_1 = iterator_with_null_at(1); + auto null_at_2 = iterator_with_null_at(2); + auto null_at_3 = iterator_with_null_at(3); + auto null_at_4 = iterator_with_null_at(4); + + // In the results, the trailing `10` and `21` entries are null placeholders (input rows 1 and 6). + auto const expected_result = lists_column_wrapper{ + {{10, 12, 13, 10}, null_at_3}, + {{10, 12, 13, 10}, null_at_3}, + {{10, 12, 13, 14, 10}, null_at_4}, + {{10, 12, 13, 14, 10}, null_at_4}, + {{10, 12, 13, 14, 10}, null_at_4}, + {{20}, null_at_1}, + {{22, 21}, null_at_1}, + {{22, 21}, null_at_1}, + {{22, 21}, null_at_1}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, the null rows (the second `10` and `21`) should not appear. 
+ auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {10, 12, 13}, + {10, 12, 13}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {20}, + {22}, + {22}, + {22}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectSetTest, SlicedGroupedRollingWindow) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const group_original = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_original = + fixed_width_column_wrapper{10, 11, 11, 13, 13, 20, 21, 21, 23}; + auto const group_col = cudf::slice(group_original, {2, 7})[0]; // { 1, 1, 1, 2, 2 } + auto const input_col = cudf::slice(input_original, {2, 7})[0]; // { 11, 13, 13, 20, 21 } + + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = grouped_rolling_window(table_view{std::vector{group_col}}, + input_col, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto const expected_result = + lists_column_wrapper{{11, 13}, {11, 13}, {13}, {20, 21}, {20, 21}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); +} + +TEST_F(CollectSetTest, BoolRollingWindow) +{ + using namespace cudf; + using namespace cudf::test; + + auto const input_column = fixed_width_column_wrapper{false, false, true, true, true}; + + auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2}; + auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0}; + + EXPECT_EQ(static_cast(prev_column).size(), + static_cast(foll_column).size()); + + auto const result_column_based_window = + rolling_window(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()); + + auto const expected_result = + lists_column_wrapper{ + {false}, + {false, true}, + {false, true}, + {true}, + {true}, + } + .release(); 
+ + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); + + auto const result_fixed_window = + rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation()); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TEST_F(CollectSetTest, BoolGroupedRollingWindow) +{ + using namespace cudf; + using namespace cudf::test; + + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = + fixed_width_column_wrapper{false, true, false, true, false, false, false, true, true}; + + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto const expected_result = lists_column_wrapper{ + {false, true}, + {false, true}, + {false, true}, + {false, true}, + {false, true}, + {false}, + {false, true}, + {false, true}, + {true}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = grouped_rolling_window( + table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TEST_F(CollectSetTest, BasicRollingWindowWithNaNs) +{ + using namespace cudf; + using namespace cudf::test; + + auto const input_column = + fixed_width_column_wrapper{1.23, 0.2341, std::nan("1"), std::nan("1"), -5.23e9}; + + auto const prev_column 
= fixed_width_column_wrapper{1, 2, 2, 2, 2}; + auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0}; + + EXPECT_EQ(static_cast(prev_column).size(), + static_cast(foll_column).size()); + + auto const result_column_based_window = + rolling_window(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()); + + auto const expected_result = + lists_column_wrapper{ + {0.2341, 1.23}, + {0.2341, 1.23, std::nan("1")}, + {0.2341, std::nan("1"), std::nan("1")}, + {-5.23e9, std::nan("1"), std::nan("1")}, + {-5.23e9, std::nan("1")}, + } + .release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); + + auto const result_fixed_window = + rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation()); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TEST_F(CollectSetTest, ListTypeRollingWindow) +{ + using namespace cudf; + using namespace cudf::test; + + auto const input_column = lists_column_wrapper{{1, 2, 3}, {4, 5}, {6}, {7, 8, 9}, {10}}; + + auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2}; + auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0}; + + EXPECT_THROW(rolling_window(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()), + cudf::logic_error); +} + +TEST_F(CollectSetTest, StructTypeRollingWindow) +{ + using namespace cudf; + using namespace cudf::test; + + auto col1 = fixed_width_column_wrapper{1, 2, 3, 4, 5}; + auto col2 = strings_column_wrapper{"a", "b", "c", "d", "e"}; + auto const input_column = cudf::test::structs_column_wrapper{{col1, col2}}; + auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 
2}; + auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0}; + + EXPECT_THROW(rolling_window(input_column, + prev_column, + foll_column, + 1, + *make_collect_set_aggregation()), + cudf::logic_error); +}