diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp
index 3941d776f75..00562f12633 100644
--- a/cpp/include/cudf/detail/aggregation/aggregation.hpp
+++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp
@@ -581,7 +581,7 @@ class collect_list_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying COLLECT_SET aggregation
*/
-class collect_set_aggregation final : public aggregation {
+class collect_set_aggregation final : public rolling_aggregation {
public:
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp
index 3a2215eaa53..a878dbe1535 100644
--- a/cpp/src/aggregation/aggregation.cpp
+++ b/cpp/src/aggregation/aggregation.cpp
@@ -468,6 +468,8 @@ std::unique_ptr make_collect_set_aggregation(null_policy null_handling,
}
template std::unique_ptr make_collect_set_aggregation(
null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);
+template std::unique_ptr make_collect_set_aggregation(
+ null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);
/// Factory to create a LAG aggregation
template
diff --git a/cpp/src/rolling/rolling_collect_list.cuh b/cpp/src/rolling/rolling_collect_list.cuh
index f5a2e59fd2a..0ffafe349b9 100644
--- a/cpp/src/rolling/rolling_collect_list.cuh
+++ b/cpp/src/rolling/rolling_collect_list.cuh
@@ -283,7 +283,7 @@ std::unique_ptr rolling_collect_list(column_view const& input,
PrecedingIter preceding_begin_raw,
FollowingIter following_begin_raw,
size_type min_periods,
- rolling_aggregation const& agg,
+ null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
@@ -321,7 +321,6 @@ std::unique_ptr rolling_collect_list(column_view const& input,
// If gather_map collects null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
- auto null_handling = dynamic_cast(agg)._null_handling;
if (null_handling == null_policy::EXCLUDE && input.has_nulls()) {
auto num_child_nulls = count_child_nulls(input, gather_map, stream);
if (num_child_nulls != 0) {
diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh
index 9e6d135b153..6f1776e40a3 100644
--- a/cpp/src/rolling/rolling_detail.cuh
+++ b/cpp/src/rolling/rolling_detail.cuh
@@ -38,6 +38,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -581,6 +582,14 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre
return {};
}
+ // COLLECT_SET aggregations do not peform a rolling operation at all. They get processed
+ // entirely in the finalize() step.
+ std::vector> visit(
+ data_type col_type, cudf::detail::collect_set_aggregation const& agg) override
+ {
+ return {};
+ }
+
// LEAD and LAG have custom behaviors for non fixed-width types.
std::vector> visit(
data_type col_type, cudf::detail::lead_lag_aggregation const& agg) override
@@ -678,11 +687,30 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
preceding_window_begin,
following_window_begin,
min_periods,
- agg,
+ agg._null_handling,
stream,
mr);
}
+ // perform the actual COLLECT_SET operation entirely.
+ void visit(cudf::detail::collect_set_aggregation const& agg) override
+ {
+ auto const collected_list = rolling_collect_list(input,
+ default_outputs,
+ preceding_window_begin,
+ following_window_begin,
+ min_periods,
+ agg._null_handling,
+ stream,
+ rmm::mr::get_current_device_resource());
+
+ result = lists::detail::drop_list_duplicates(lists_column_view(collected_list->view()),
+ null_equality::EQUAL,
+ nan_equality::UNEQUAL,
+ stream,
+ mr);
+ }
+
std::unique_ptr get_result()
{
CUDF_EXPECTS(result != nullptr,
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index a3df5989c3b..23b92250549 100644
--- a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -285,7 +285,8 @@ ConfigureTest(ROLLING_TEST
rolling/lead_lag_test.cpp
rolling/range_window_bounds_test.cpp
rolling/range_rolling_window_test.cpp
- rolling/collect_list_test.cpp)
+ rolling/collect_ops_test.cpp
+ )
###################################################################################################
# - filling test ----------------------------------------------------------------------------------
diff --git a/cpp/tests/rolling/collect_list_test.cpp b/cpp/tests/rolling/collect_ops_test.cpp
similarity index 64%
rename from cpp/tests/rolling/collect_list_test.cpp
rename to cpp/tests/rolling/collect_ops_test.cpp
index 8322dd0eee9..f97e13b49f1 100644
--- a/cpp/tests/rolling/collect_list_test.cpp
+++ b/cpp/tests/rolling/collect_ops_test.cpp
@@ -1281,3 +1281,762 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
}
+
+struct CollectSetTest : public cudf::test::BaseFixture {
+};
+
+template
+struct TypedCollectSetTest : public CollectSetTest {
+};
+
+using TypesForSetTest = cudf::test::Concat;
+
+TYPED_TEST_CASE(TypedCollectSetTest, TypesForSetTest);
+
+TYPED_TEST(TypedCollectSetTest, BasicRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const input_column = fixed_width_column_wrapper{10, 10, 11, 12, 11};
+
+ auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2};
+ auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0};
+
+ EXPECT_EQ(static_cast(prev_column).size(),
+ static_cast(foll_column).size());
+
+ auto const result_column_based_window =
+ rolling_window(input_column,
+ prev_column,
+ foll_column,
+ 1,
+ *make_collect_set_aggregation());
+
+ auto const expected_result =
+ lists_column_wrapper{
+ {10},
+ {10, 11},
+ {10, 11, 12},
+ {11, 12},
+ {11, 12},
+ }
+ .release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view());
+
+ auto const result_fixed_window =
+ rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation());
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ 2,
+ 1,
+ 1,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+}
+
+TYPED_TEST(TypedCollectSetTest, RollingWindowWithEmptyOutputLists)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const input_column = fixed_width_column_wrapper{10, 11, 11, 11, 14, 15};
+
+ auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 0, 2, 2};
+ auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 0, 1, 0};
+
+ EXPECT_EQ(static_cast(prev_column).size(),
+ static_cast(foll_column).size());
+
+ auto const result_column_based_window =
+ rolling_window(input_column,
+ prev_column,
+ foll_column,
+ 0,
+ *make_collect_set_aggregation());
+
+ auto const expected_result =
+ lists_column_wrapper{
+ {10, 11},
+ {10, 11},
+ {11},
+ {},
+ {11, 14, 15},
+ {14, 15},
+ }
+ .release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ prev_column,
+ foll_column,
+ 0,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+}
+
+TYPED_TEST(TypedCollectSetTest, RollingWindowHonoursMinPeriods)
+{
+ // Test that when the number of observations is fewer than min_periods,
+ // the result is null.
+
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const input_column = fixed_width_column_wrapper{0, 1, 2, 2, 4, 5};
+ auto const num_elements = static_cast(input_column).size();
+
+ auto preceding = 2;
+ auto following = 1;
+ auto min_periods = 3;
+ auto const result = rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto const expected_result = lists_column_wrapper{
+ {{}, {0, 1, 2}, {1, 2}, {2, 4}, {2, 4, 5}, {}},
+ cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) {
+ return i != 0 && i != (num_elements - 1);
+ })}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+
+ preceding = 2;
+ following = 2;
+ min_periods = 4;
+
+ auto result_2 = rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+ auto expected_result_2 = lists_column_wrapper{
+ {{}, {0, 1, 2}, {1, 2, 4}, {2, 4, 5}, {}, {}},
+ cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) {
+ return i != 0 && i < 4;
+ })}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view());
+
+ auto result_2_with_nulls_excluded =
+ rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(),
+ result_2_with_nulls_excluded->view());
+}
+
+TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsOnStrings)
+{
+ // Test that when the number of observations is fewer than min_periods,
+ // the result is null.
+
+ using namespace cudf;
+ using namespace cudf::test;
+
+ auto const input_column = strings_column_wrapper{"0", "1", "2", "2", "4", "4"};
+ auto const num_elements = static_cast(input_column).size();
+
+ auto preceding = 2;
+ auto following = 1;
+ auto min_periods = 3;
+ auto const result = rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto const expected_result = lists_column_wrapper{
+ {{}, {"0", "1", "2"}, {"1", "2"}, {"2", "4"}, {"2", "4"}, {}},
+ cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) {
+ return i != 0 && i != (num_elements - 1);
+ })}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+
+ preceding = 2;
+ following = 2;
+ min_periods = 4;
+
+ auto result_2 = rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+ auto expected_result_2 = lists_column_wrapper{
+ {{}, {"0", "1", "2"}, {"1", "2", "4"}, {"2", "4"}, {}, {}},
+ cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) {
+ return i != 0 && i < 4;
+ })}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view());
+
+ auto result_2_with_nulls_excluded =
+ rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(),
+ result_2_with_nulls_excluded->view());
+}
+
+TEST_F(CollectSetTest, RollingWindowHonoursMinPeriodsWithDecimal)
+{
+ // Test that when the number of observations is fewer than min_periods,
+ // the result is null.
+
+ using namespace cudf;
+ using namespace cudf::test;
+
+ auto const input_column =
+ fixed_point_column_wrapper{{0, 0, 1, 2, 3, 3}, numeric::scale_type{0}};
+
+ {
+ // One result row at each end should be null.
+ auto preceding = 2;
+ auto following = 1;
+ auto min_periods = 3;
+ auto const result = rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto expected_result_child_values = std::vector{0, 1, 0, 1, 2, 1, 2, 3, 2, 3};
+ auto expected_result_child =
+ fixed_point_column_wrapper{expected_result_child_values.begin(),
+ expected_result_child_values.end(),
+ numeric::scale_type{0}};
+ auto expected_offsets = fixed_width_column_wrapper{0, 0, 2, 5, 8, 10, 10}.release();
+ auto expected_num_rows = expected_offsets->size() - 1;
+ auto null_mask_iter = cudf::detail::make_counting_transform_iterator(
+ size_type{0}, [expected_num_rows](auto i) { return i != 0 && i != (expected_num_rows - 1); });
+
+ auto expected_result = make_lists_column(
+ expected_num_rows,
+ std::move(expected_offsets),
+ expected_result_child.release(),
+ 2,
+ cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(),
+ result_with_nulls_excluded->view());
+ }
+
+ {
+ // First result row, and the last two result rows should be null.
+ auto preceding = 2;
+ auto following = 2;
+ auto min_periods = 4;
+ auto const result = rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto expected_result_child_values = std::vector{0, 1, 2, 0, 1, 2, 3, 1, 2, 3};
+ auto expected_result_child =
+ fixed_point_column_wrapper{expected_result_child_values.begin(),
+ expected_result_child_values.end(),
+ numeric::scale_type{0}};
+ auto expected_offsets = fixed_width_column_wrapper{0, 0, 3, 7, 10, 10, 10}.release();
+ auto expected_num_rows = expected_offsets->size() - 1;
+ auto null_mask_iter = cudf::detail::make_counting_transform_iterator(
+ size_type{0}, [expected_num_rows](auto i) { return i > 0 && i < 4; });
+
+ auto expected_result = make_lists_column(
+ expected_num_rows,
+ std::move(expected_offsets),
+ expected_result_child.release(),
+ 3,
+ cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(),
+ result_with_nulls_excluded->view());
+ }
+}
+
+TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2};
+ auto const input_column =
+ fixed_width_column_wrapper{10, 11, 11, 13, 13, 20, 21, 20, 23};
+
+ auto const preceding = 2;
+ auto const following = 1;
+ auto const min_periods = 1;
+ auto const result = grouped_rolling_window(table_view{std::vector{group_column}},
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto const expected_result =
+ lists_column_wrapper{
+ {10, 11}, {10, 11}, {11, 13}, {11, 13}, {13}, {20, 21}, {20, 21}, {20, 21, 23}, {20, 23}}
+ .release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded = grouped_rolling_window(
+ table_view{std::vector{group_column}},
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+}
+
+TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2};
+ auto const input_column = fixed_width_column_wrapper{
+ {10, 11, 12, 13, 13, 20, 21, 21, 23}, {1, 0, 0, 1, 1, 1, 0, 1, 1}};
+
+ auto const preceding = 2;
+ auto const following = 1;
+ auto const min_periods = 1;
+
+ {
+ // Nulls included.
+ auto const result =
+ grouped_rolling_window(table_view{std::vector{group_column}},
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+ // Null values are sorted to the tails of lists (sets)
+ auto expected_child = fixed_width_column_wrapper{
+ {10, 11, 10, 11, 13, 11, 13, 12, 13, 20, 21, 20, 21, 21, 21, 23, 21, 21, 23},
+ {1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1}};
+ auto expected_offsets = fixed_width_column_wrapper{0, 2, 4, 6, 8, 9, 11, 14, 17, 19};
+
+ auto expected_result = make_lists_column(static_cast(group_column).size(),
+ expected_offsets.release(),
+ expected_child.release(),
+ 0,
+ {});
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+ }
+
+ {
+ // Nulls excluded.
+ auto const result = grouped_rolling_window(
+ table_view{std::vector{group_column}},
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ auto expected_child =
+ fixed_width_column_wrapper{10, 10, 13, 13, 13, 20, 20, 21, 21, 23, 21, 23};
+
+ auto expected_offsets = fixed_width_column_wrapper{0, 1, 2, 3, 4, 5, 6, 8, 10, 12};
+
+ auto expected_result = make_lists_column(static_cast(group_column).size(),
+ expected_offsets.release(),
+ expected_child.release(),
+ 0,
+ {});
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+ }
+}
+
+TYPED_TEST(TypedCollectSetTest, BasicGroupedTimeRangeRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const time_column = fixed_width_column_wrapper{
+ 1, 1, 2, 2, 3, 1, 4, 5, 6};
+ auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2};
+ auto const input_column =
+ fixed_width_column_wrapper{10, 11, 12, 13, 14, 20, 21, 22, 23};
+ auto const preceding = 2;
+ auto const following = 1;
+ auto const min_periods = 1;
+ auto const result =
+ grouped_time_range_rolling_window(table_view{std::vector{group_column}},
+ time_column,
+ cudf::order::ASCENDING,
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_list_aggregation());
+
+ auto const expected_result = lists_column_wrapper{
+ {10, 11, 12, 13},
+ {10, 11, 12, 13},
+ {10, 11, 12, 13, 14},
+ {10, 11, 12, 13, 14},
+ {10, 11, 12, 13, 14},
+ {20},
+ {21, 22},
+ {21, 22, 23},
+ {21, 22, 23}}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded = grouped_time_range_rolling_window(
+ table_view{std::vector{group_column}},
+ time_column,
+ cudf::order::ASCENDING,
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_list_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+}
+
+TYPED_TEST(TypedCollectSetTest, GroupedTimeRangeRollingWindowWithNulls)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const time_column = fixed_width_column_wrapper{
+ 1, 1, 2, 2, 3, 1, 4, 5, 6};
+ auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2};
+ auto const input_column = fixed_width_column_wrapper{
+ {10, 10, 12, 13, 14, 20, 21, 22, 22}, {1, 0, 1, 1, 1, 1, 0, 1, 1}};
+ auto const preceding = 2;
+ auto const following = 1;
+ auto const min_periods = 1;
+ auto const result =
+ grouped_time_range_rolling_window(table_view{std::vector{group_column}},
+ time_column,
+ cudf::order::ASCENDING,
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto null_at_1 = iterator_with_null_at(1);
+ auto null_at_2 = iterator_with_null_at(2);
+ auto null_at_3 = iterator_with_null_at(3);
+ auto null_at_4 = iterator_with_null_at(4);
+
+ // In the results, `11` and `21` should be nulls.
+ auto const expected_result = lists_column_wrapper{
+ {{10, 12, 13, 10}, null_at_3},
+ {{10, 12, 13, 10}, null_at_3},
+ {{10, 12, 13, 14, 10}, null_at_4},
+ {{10, 12, 13, 14, 10}, null_at_4},
+ {{10, 12, 13, 14, 10}, null_at_4},
+ {{20}, null_at_1},
+ {{22, 21}, null_at_1},
+ {{22, 21}, null_at_1},
+ {{22, 21}, null_at_1}}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded = grouped_time_range_rolling_window(
+ table_view{std::vector{group_column}},
+ time_column,
+ cudf::order::ASCENDING,
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ // After null exclusion, `11`, `21`, and `null` should not appear.
+ auto const expected_result_with_nulls_excluded = lists_column_wrapper{
+ {10, 12, 13},
+ {10, 12, 13},
+ {10, 12, 13, 14},
+ {10, 12, 13, 14},
+ {10, 12, 13, 14},
+ {20},
+ {22},
+ {22},
+ {22}}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(),
+ result_with_nulls_excluded->view());
+}
+
+TYPED_TEST(TypedCollectSetTest, SlicedGroupedRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ using T = TypeParam;
+
+ auto const group_original = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2};
+ auto const input_original =
+ fixed_width_column_wrapper{10, 11, 11, 13, 13, 20, 21, 21, 23};
+ auto const group_col = cudf::slice(group_original, {2, 7})[0]; // { 1, 1, 1, 2, 2 }
+ auto const input_col = cudf::slice(input_original, {2, 7})[0]; // { 11, 13, 13, 20, 21 }
+
+ auto const preceding = 2;
+ auto const following = 1;
+ auto const min_periods = 1;
+ auto const result = grouped_rolling_window(table_view{std::vector{group_col}},
+ input_col,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto const expected_result =
+ lists_column_wrapper{{11, 13}, {11, 13}, {13}, {20, 21}, {20, 21}}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+}
+
+TEST_F(CollectSetTest, BoolRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ auto const input_column = fixed_width_column_wrapper{false, false, true, true, true};
+
+ auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2};
+ auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0};
+
+ EXPECT_EQ(static_cast(prev_column).size(),
+ static_cast(foll_column).size());
+
+ auto const result_column_based_window =
+ rolling_window(input_column,
+ prev_column,
+ foll_column,
+ 1,
+ *make_collect_set_aggregation());
+
+ auto const expected_result =
+ lists_column_wrapper{
+ {false},
+ {false, true},
+ {false, true},
+ {true},
+ {true},
+ }
+ .release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view());
+
+ auto const result_fixed_window =
+ rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation());
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ 2,
+ 1,
+ 1,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+}
+
+TEST_F(CollectSetTest, BoolGroupedRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2};
+ auto const input_column =
+ fixed_width_column_wrapper{false, true, false, true, false, false, false, true, true};
+
+ auto const preceding = 2;
+ auto const following = 1;
+ auto const min_periods = 1;
+ auto const result = grouped_rolling_window(table_view{std::vector{group_column}},
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation());
+
+ auto const expected_result = lists_column_wrapper{
+ {false, true},
+ {false, true},
+ {false, true},
+ {false, true},
+ {false, true},
+ {false},
+ {false, true},
+ {false, true},
+ {true}}.release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
+
+ auto const result_with_nulls_excluded = grouped_rolling_window(
+ table_view{std::vector{group_column}},
+ input_column,
+ preceding,
+ following,
+ min_periods,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+}
+
+TEST_F(CollectSetTest, BasicRollingWindowWithNaNs)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ auto const input_column =
+ fixed_width_column_wrapper{1.23, 0.2341, std::nan("1"), std::nan("1"), -5.23e9};
+
+ auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2};
+ auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0};
+
+ EXPECT_EQ(static_cast(prev_column).size(),
+ static_cast(foll_column).size());
+
+ auto const result_column_based_window =
+ rolling_window(input_column,
+ prev_column,
+ foll_column,
+ 1,
+ *make_collect_set_aggregation());
+
+ auto const expected_result =
+ lists_column_wrapper{
+ {0.2341, 1.23},
+ {0.2341, 1.23, std::nan("1")},
+ {0.2341, std::nan("1"), std::nan("1")},
+ {-5.23e9, std::nan("1"), std::nan("1")},
+ {-5.23e9, std::nan("1")},
+ }
+ .release();
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view());
+
+ auto const result_fixed_window =
+ rolling_window(input_column, 2, 1, 1, *make_collect_set_aggregation());
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view());
+
+ auto const result_with_nulls_excluded =
+ rolling_window(input_column,
+ 2,
+ 1,
+ 1,
+ *make_collect_set_aggregation(null_policy::EXCLUDE));
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
+}
+
+TEST_F(CollectSetTest, ListTypeRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ auto const input_column = lists_column_wrapper{{1, 2, 3}, {4, 5}, {6}, {7, 8, 9}, {10}};
+
+ auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2};
+ auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0};
+
+ EXPECT_THROW(rolling_window(input_column,
+ prev_column,
+ foll_column,
+ 1,
+ *make_collect_set_aggregation()),
+ cudf::logic_error);
+}
+
+TEST_F(CollectSetTest, StructTypeRollingWindow)
+{
+ using namespace cudf;
+ using namespace cudf::test;
+
+ auto col1 = fixed_width_column_wrapper{1, 2, 3, 4, 5};
+ auto col2 = strings_column_wrapper{"a", "b", "c", "d", "e"};
+ auto const input_column = cudf::test::structs_column_wrapper{{col1, col2}};
+ auto const prev_column = fixed_width_column_wrapper{1, 2, 2, 2, 2};
+ auto const foll_column = fixed_width_column_wrapper{1, 1, 1, 1, 0};
+
+ EXPECT_THROW(rolling_window(input_column,
+ prev_column,
+ foll_column,
+ 1,
+ *make_collect_set_aggregation()),
+ cudf::logic_error);
+}