From 25da4714f4aaf731976daef0fd34cce1d52c2c8a Mon Sep 17 00:00:00 2001 From: ofekisr Date: Sun, 14 Nov 2021 23:48:36 +0200 Subject: [PATCH] refactor: ChartDataCommand into two separate commands --- superset/charts/data/api.py | 18 ++++++++----- .../{commands/data.py => data/commands.py} | 26 +++++++++---------- superset/tasks/async_queries.py | 2 +- .../charts/data/api_tests.py | 2 +- .../tasks/async_queries_tests.py | 4 +-- 5 files changed, 27 insertions(+), 25 deletions(-) rename superset/charts/{commands/data.py => data/commands.py} (88%) diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 37703339e7f13..534101bae6be1 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -28,11 +28,14 @@ from superset import is_feature_enabled, security_manager from superset.charts.api import ChartRestApi -from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.exceptions import ( ChartDataCacheLoadError, ChartDataQueryFailedError, ) +from superset.charts.data.commands import ( + ChartDataCommand, + CreateAsyncChartDataJobCommand, +) from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader from superset.charts.post_processing import apply_post_process from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType @@ -145,7 +148,7 @@ def get_data(self, pk: int) -> Response: and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): - return self._run_async(command) + return self._run_async(json_body, command) try: form_data = json.loads(chart.params) @@ -231,7 +234,7 @@ def data(self) -> Response: and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): - return self._run_async(command) + return self._run_async(json_body, command) return self._get_data_response(command) @@ -289,7 +292,9 @@ def data_from_cache(self, cache_key: str) -> Response: return self._get_data_response(command, True) - def _run_async(self, command: ChartDataCommand) -> Response: + def _run_async( + self, form_data: Dict[str, Any], command: ChartDataCommand + ) -> Response: """ Execute command as an async query. """ @@ -309,12 +314,13 @@ def _run_async(self, command: ChartDataCommand) -> Response: # Clients will either poll or be notified of query completion, # at which point they will call the /data/ endpoint # to retrieve the results. + async_command = CreateAsyncChartDataJobCommand() try: - command.validate_async_request(request) + async_command.validate(request) except AsyncQueryTokenException: return self.response_401() - result = command.run_async(g.user.get_id()) + result = async_command.run(form_data, g.user.get_id()) return self.response(202, **result) def _send_chart_response( diff --git a/superset/charts/commands/data.py b/superset/charts/data/commands.py similarity index 88% rename from superset/charts/commands/data.py rename to superset/charts/data/commands.py index ec63362a5c3d0..d434f79a17101 100644 --- a/superset/charts/commands/data.py +++ b/superset/charts/data/commands.py @@ -35,10 +35,7 @@ class ChartDataCommand(BaseCommand): - def __init__(self) -> None: - self._form_data: Dict[str, Any] - self._query_context: QueryContext - self._async_channel_id: str + _query_context: QueryContext def run(self, **kwargs: Any) -> Dict[str, Any]: # caching is handled in query_context.get_df_payload @@ -66,26 +63,27 @@ def run(self, **kwargs: Any) -> Dict[str, Any]: return return_value - def run_async(self, user_id: Optional[str]) -> Dict[str, Any]: - job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) - load_chart_data_into_cache.delay(job_metadata, self._form_data) - - return job_metadata - def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext: - self._form_data = form_data try: - self._query_context = ChartDataQueryContextSchema().load(self._form_data) + self._query_context = ChartDataQueryContextSchema().load(form_data) except KeyError as ex: raise ValidationError("Request is incorrect") from ex except ValidationError as error: raise error - return self._query_context def validate(self) -> None: self._query_context.raise_for_access() - def validate_async_request(self, request: Request) -> None: + +class CreateAsyncChartDataJobCommand: + _async_channel_id: str + + def validate(self, request: Request) -> None: jwt_data = async_query_manager.parse_jwt_from_request(request) self._async_channel_id = jwt_data["channel"] + + def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]: + job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) + load_chart_data_into_cache.delay(job_metadata, form_data) + return job_metadata diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 18094323ec1ec..c50dbb9a94436 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -55,7 +55,7 @@ def load_chart_data_into_cache( job_metadata: Dict[str, Any], form_data: Dict[str, Any], ) -> None: # pylint: disable=import-outside-toplevel - from superset.charts.commands.data import ChartDataCommand + from superset.charts.data.commands import ChartDataCommand try: ensure_user_is_set(job_metadata.get("user_id")) diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 45d300b7381d6..1b2ade28f4360 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -37,7 +37,7 @@ import pytest -from superset.charts.commands.data import ChartDataCommand +from superset.charts.data.commands import ChartDataCommand from superset.connectors.sqla.models import TableColumn, SqlaTable from superset.errors import SupersetErrorType from superset.extensions import async_query_manager, db diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 3ea1c6f0ce6de..e2cf21c552624 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -22,10 +22,8 @@ from celery.exceptions import SoftTimeLimitExceeded from flask import g -from superset import db -from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.exceptions import ChartDataQueryFailedError -from superset.connectors.sqla.models import SqlaTable +from superset.charts.data.commands import ChartDataCommand from superset.exceptions import SupersetException from superset.extensions import async_query_manager, security_manager from superset.tasks import async_queries