Support collect_set on rolling window #7881

Merged
merged 14 commits into from
May 26, 2021
2 changes: 1 addition & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
@@ -581,7 +581,7 @@ class collect_list_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying COLLECT_SET aggregation
*/
-class collect_set_aggregation final : public aggregation {
+class collect_set_aggregation final : public rolling_aggregation {
Contributor

Why do we inherit collect set from rolling_aggregation instead of aggregation? We need it for both rolling window and groupby, don't we?

Contributor Author

@sperlingxx sperlingxx May 26, 2021

Because rolling_aggregation virtually inherits from aggregation, so collect_set_aggregation is still an aggregation and can be used for groupby too. I just followed the corresponding code for collect_list.
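
To illustrate the point about virtual inheritance, here is a minimal standalone sketch. The class names mirror the ones under discussion, but the bodies, the `name()` member, and the two consumer functions `run_groupby` and `run_rolling` are hypothetical stand-ins for illustration, not libcudf code.

```cpp
#include <cassert>
#include <string>

// Simplified stand-ins for the libcudf classes; only the inheritance shape matters.
class aggregation {
 public:
  virtual ~aggregation() = default;
  virtual std::string name() const = 0;  // hypothetical member, for the demo only
};

// rolling_aggregation virtually inherits from aggregation, as described above.
class rolling_aggregation : public virtual aggregation {};

class collect_set_aggregation final : public rolling_aggregation {
 public:
  std::string name() const override { return "COLLECT_SET"; }
};

// Hypothetical consumers: one accepts any aggregation, one only rolling ones.
std::string run_groupby(aggregation const& agg) { return "groupby:" + agg.name(); }
std::string run_rolling(rolling_aggregation const& agg) { return "rolling:" + agg.name(); }

// The same object is accepted by both consumers.
inline bool usable_in_both()
{
  collect_set_aggregation agg;
  return run_groupby(agg) == "groupby:COLLECT_SET" &&
         run_rolling(agg) == "rolling:COLLECT_SET";
}
```

Under these assumptions, deriving from `rolling_aggregation` does not stop the class from being passed wherever a plain `aggregation` is expected.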

Contributor

Maybe both are wrong, as collect_list_aggregation and collect_set_aggregation are used not only in rolling windows but also in groupby. Can you change them to:

class collect_list_aggregation final : public aggregation 
...
class collect_set_aggregation final : public aggregation 

and test if they can compile and unit tests all pass please?

Contributor Author

@sperlingxx sperlingxx May 26, 2021

Hi @ttnghia, I tried replacing rolling_aggregation with aggregation, and I got a compile error at src/aggregation/aggregation.cpp:454:

/home/alfredxu/workspace/codes/cudf/cpp/src/aggregation/aggregation.cpp:454:74: error: could not convert ‘std::make_unique(_Args&& ...) [with _Tp = cudf::detail::collect_list_aggregation; _Args = {cudf::null_policy&}; typename std::_MakeUniq<_Tp>::__single_object = std::unique_ptr<cudf::detail::collect_list_aggregation, std::default_delete<cudf::detail::collect_list_aggregation> >]()’ from ‘unique_ptr<cudf::detail::collect_list_aggregation,default_delete<cudf::detail::collect_list_aggregation>>’ to ‘unique_ptr<cudf::rolling_aggregation,default_delete<cudf::rolling_aggregation>>’

Contributor Author

/// Factory to create a COLLECT_LIST aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_collect_list_aggregation(null_policy null_handling)
{
  return std::make_unique<detail::collect_list_aggregation>(null_handling);
}
template std::unique_ptr<aggregation> make_collect_list_aggregation<aggregation>(
  null_policy null_handling);
template std::unique_ptr<rolling_aggregation> make_collect_list_aggregation<rolling_aggregation>(
  null_policy null_handling);

I think it is because, after the replacement, collect_list_aggregation no longer derives from rolling_aggregation, so the factory can no longer return std::unique_ptr<rolling_aggregation>.
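
The conversion failure can be reproduced outside libcudf. The sketch below is a reduced model under stated assumptions: `list_agg_rolling`, `list_agg_plain`, and `make_agg` are hypothetical names, and only the inheritance relationships are taken from the discussion above.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

struct aggregation {
  virtual ~aggregation() = default;
};
struct rolling_aggregation : virtual aggregation {};

// Variant 1: derives from rolling_aggregation (what the PR keeps).
struct list_agg_rolling final : rolling_aggregation {};
// Variant 2: derives from aggregation only (the change that was tried).
struct list_agg_plain final : aggregation {};

// Same shape as the factory quoted above: the return statement relies on an
// implicit unique_ptr<Derived> -> unique_ptr<Base> conversion.
template <typename Base, typename Derived>
std::unique_ptr<Base> make_agg()
{
  return std::make_unique<Derived>();
}

// Variant 1 converts to both base pointer types.
static_assert(std::is_convertible<std::unique_ptr<list_agg_rolling>,
                                  std::unique_ptr<aggregation>>::value,
              "rolling-derived converts to aggregation");
static_assert(std::is_convertible<std::unique_ptr<list_agg_rolling>,
                                  std::unique_ptr<rolling_aggregation>>::value,
              "rolling-derived converts to rolling_aggregation");

// Variant 2 converts to aggregation but NOT to rolling_aggregation, which is
// why instantiating the factory with Base = rolling_aggregation fails to compile.
static_assert(std::is_convertible<std::unique_ptr<list_agg_plain>,
                                  std::unique_ptr<aggregation>>::value,
              "plain converts to aggregation");
static_assert(!std::is_convertible<std::unique_ptr<list_agg_plain>,
                                   std::unique_ptr<rolling_aggregation>>::value,
              "plain does not convert to rolling_aggregation");

// Both instantiations work for the rolling-derived variant.
inline bool factories_work()
{
  std::unique_ptr<aggregation> a                 = make_agg<aggregation, list_agg_rolling>();
  std::unique_ptr<rolling_aggregation> r         = make_agg<rolling_aggregation, list_agg_rolling>();
  return a != nullptr && r != nullptr;
}
```

In this model, `make_agg<rolling_aggregation, list_agg_plain>()` would be rejected by the compiler with an error of the same kind as the one quoted above.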

Contributor

I see. Thanks 😄

public:
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
2 changes: 2 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
@@ -468,6 +468,8 @@ std::unique_ptr<Base> make_collect_set_aggregation(null_policy null_handling,
}
template std::unique_ptr<aggregation> make_collect_set_aggregation<aggregation>(
null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);
+template std::unique_ptr<rolling_aggregation> make_collect_set_aggregation<rolling_aggregation>(
+  null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);

/// Factory to create a LAG aggregation
template <typename Base = aggregation>
3 changes: 1 addition & 2 deletions cpp/src/rolling/rolling_collect_list.cuh
@@ -283,7 +283,7 @@ std::unique_ptr<column> rolling_collect_list(column_view const& input,
PrecedingIter preceding_begin_raw,
FollowingIter following_begin_raw,
size_type min_periods,
-rolling_aggregation const& agg,
+null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
@@ -321,7 +321,6 @@ std::unique_ptr<column> rolling_collect_list(column_view const& input,

// If gather_map collects null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
-auto null_handling = dynamic_cast<collect_list_aggregation const&>(agg)._null_handling;
if (null_handling == null_policy::EXCLUDE && input.has_nulls()) {
auto num_child_nulls = count_child_nulls(input, gather_map, stream);
if (num_child_nulls != 0) {
30 changes: 29 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
@@ -38,6 +38,7 @@
#include <cudf/detail/valid_if.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/dictionary/dictionary_factories.hpp>
+#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/rolling.hpp>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/types.hpp>
@@ -581,6 +582,14 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre
return {};
}

+  // COLLECT_SET aggregations do not perform a rolling operation at all. They get processed
+  // entirely in the finalize() step.
+  std::vector<std::unique_ptr<aggregation>> visit(
+    data_type col_type, cudf::detail::collect_set_aggregation const& agg) override
+  {
+    return {};
+  }
+
// LEAD and LAG have custom behaviors for non fixed-width types.
std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, cudf::detail::lead_lag_aggregation const& agg) override
@@ -678,11 +687,30 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
preceding_window_begin,
following_window_begin,
min_periods,
-agg,
+agg._null_handling,
stream,
mr);
}

+  // perform the actual COLLECT_SET operation entirely.
+  void visit(cudf::detail::collect_set_aggregation const& agg) override
+  {
+    auto const collected_list = rolling_collect_list(input,
+                                                     default_outputs,
+                                                     preceding_window_begin,
+                                                     following_window_begin,
+                                                     min_periods,
+                                                     agg._null_handling,
+                                                     stream,
+                                                     rmm::mr::get_current_device_resource());
+
+    result = lists::detail::drop_list_duplicates(lists_column_view(collected_list->view()),
+                                                 null_equality::EQUAL,
+                                                 nan_equality::UNEQUAL,
+                                                 stream,
+                                                 mr);
+  }

std::unique_ptr<column> get_result()
{
CUDF_EXPECTS(result != nullptr,
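
The `collect_set_aggregation` visitor above works in two steps: collect each row's window into a list, then drop duplicates within each list. The host-side sketch below models that composition on plain `std::vector<int>` data. All three function names are hypothetical, the window is given as fixed counts of rows before and after the current row (an assumption that differs from libcudf's window conventions), and the sort-then-unique deduplication is a stand-in whose output ordering should not be read as a statement about `drop_list_duplicates`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using list_column = std::vector<std::vector<int>>;

// Step 1: for every row, gather the elements inside its window.
// `preceding` and `following` count rows before and after the current row.
inline list_column rolling_collect(std::vector<int> const& input,
                                   std::size_t preceding,
                                   std::size_t following)
{
  list_column out;
  out.reserve(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    std::size_t const begin = (i >= preceding) ? i - preceding : 0;
    std::size_t const end   = std::min(input.size(), i + following + 1);
    out.emplace_back(input.begin() + static_cast<std::ptrdiff_t>(begin),
                     input.begin() + static_cast<std::ptrdiff_t>(end));
  }
  return out;
}

// Step 2: remove duplicate values inside each list independently.
inline list_column drop_duplicates_per_list(list_column lists)
{
  for (auto& l : lists) {
    std::sort(l.begin(), l.end());
    l.erase(std::unique(l.begin(), l.end()), l.end());
  }
  return lists;
}

// collect_set over a rolling window = collect_list, then per-list deduplication.
inline list_column rolling_collect_set(std::vector<int> const& input,
                                       std::size_t preceding,
                                       std::size_t following)
{
  return drop_duplicates_per_list(rolling_collect(input, preceding, following));
}
```

For example, with input `{1, 1, 2, 2, 3}` and one row on each side of the current row, the collected windows are `{1,1}`, `{1,1,2}`, `{1,2,2}`, `{2,2,3}`, `{2,3}`, and the deduplicated result is `{1}`, `{1,2}`, `{1,2}`, `{2,3}`, `{2,3}`.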
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
@@ -285,7 +285,8 @@ ConfigureTest(ROLLING_TEST
rolling/lead_lag_test.cpp
rolling/range_window_bounds_test.cpp
rolling/range_rolling_window_test.cpp
-rolling/collect_list_test.cpp)
+rolling/collect_ops_test.cpp
+)

###################################################################################################
# - filling test ----------------------------------------------------------------------------------