Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ChartDataCommand): into two separate commands #17425

Merged
merged 1 commit into from
Nov 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@

from superset import is_feature_enabled, security_manager
from superset.charts.api import ChartRestApi
from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.charts.data.commands import (
ChartDataCommand,
CreateAsyncChartDataJobCommand,
)
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.post_processing import apply_post_process
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
Expand Down Expand Up @@ -145,7 +148,7 @@ def get_data(self, pk: int) -> Response:
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(command)
return self._run_async(json_body, command)

try:
form_data = json.loads(chart.params)
Expand Down Expand Up @@ -231,7 +234,7 @@ def data(self) -> Response:
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(command)
return self._run_async(json_body, command)

return self._get_data_response(command)

Expand Down Expand Up @@ -289,7 +292,9 @@ def data_from_cache(self, cache_key: str) -> Response:

return self._get_data_response(command, True)

def _run_async(self, command: ChartDataCommand) -> Response:
def _run_async(
self, form_data: Dict[str, Any], command: ChartDataCommand
) -> Response:
"""
Execute command as an async query.
"""
Expand All @@ -309,12 +314,13 @@ def _run_async(self, command: ChartDataCommand) -> Response:
# Clients will either poll or be notified of query completion,
# at which point they will call the /data/<cache_key> endpoint
# to retrieve the results.
async_command = CreateAsyncChartDataJobCommand()
try:
command.validate_async_request(request)
async_command.validate(request)
except AsyncQueryTokenException:
return self.response_401()

result = command.run_async(g.user.get_id())
result = async_command.run(form_data, g.user.get_id())
return self.response(202, **result)

def _send_chart_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@


class ChartDataCommand(BaseCommand):
def __init__(self) -> None:
self._form_data: Dict[str, Any]
self._query_context: QueryContext
self._async_channel_id: str
_query_context: QueryContext

def run(self, **kwargs: Any) -> Dict[str, Any]:
# caching is handled in query_context.get_df_payload
Expand Down Expand Up @@ -66,26 +63,27 @@ def run(self, **kwargs: Any) -> Dict[str, Any]:

return return_value

def run_async(self, user_id: Optional[str]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, self._form_data)

return job_metadata

def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
self._form_data = form_data
try:
self._query_context = ChartDataQueryContextSchema().load(self._form_data)
self._query_context = ChartDataQueryContextSchema().load(form_data)
except KeyError as ex:
raise ValidationError("Request is incorrect") from ex
except ValidationError as error:
raise error

return self._query_context

def validate(self) -> None:
self._query_context.raise_for_access()

def validate_async_request(self, request: Request) -> None:

class CreateAsyncChartDataJobCommand:
_async_channel_id: str

def validate(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]

def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
2 changes: 1 addition & 1 deletion superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
) -> None:
# pylint: disable=import-outside-toplevel
from superset.charts.commands.data import ChartDataCommand
from superset.charts.data.commands import ChartDataCommand

try:
ensure_user_is_set(job_metadata.get("user_id"))
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import pytest

from superset.charts.commands.data import ChartDataCommand
from superset.charts.data.commands import ChartDataCommand
from superset.connectors.sqla.models import TableColumn, SqlaTable
from superset.errors import SupersetErrorType
from superset.extensions import async_query_manager, db
Expand Down
4 changes: 1 addition & 3 deletions tests/integration_tests/tasks/async_queries_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
from celery.exceptions import SoftTimeLimitExceeded
from flask import g

from superset import db
from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.connectors.sqla.models import SqlaTable
from superset.charts.data.commands import ChartDataCommand
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager, security_manager
from superset.tasks import async_queries
Expand Down