From 7e29e210dbb3d90514522b7900e1bd6f672363a9 Mon Sep 17 00:00:00 2001 From: Mayur Date: Thu, 29 Sep 2022 10:48:58 +0530 Subject: [PATCH 1/4] check slice cache timeout --- superset/common/query_context.py | 6 ++++++ superset/common/query_context_factory.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 4a91c6ad6db17..b9414fddd1dca 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -27,6 +27,7 @@ QueryContextProcessor, ) from superset.common.query_object import QueryObject +from superset.models.slice import Slice if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource @@ -46,6 +47,7 @@ class QueryContext: enforce_numerical_metrics: ClassVar[bool] = True datasource: BaseDatasource + slice_id: Optional[int] = None queries: List[QueryObject] form_data: Optional[Dict[str, Any]] result_type: ChartDataResultType @@ -64,6 +66,7 @@ def __init__( *, datasource: BaseDatasource, queries: List[QueryObject], + slice: Optional[Slice], form_data: Optional[Dict[str, Any]], result_type: ChartDataResultType, result_format: ChartDataResultFormat, @@ -72,6 +75,7 @@ def __init__( cache_values: Dict[str, Any], ) -> None: self.datasource = datasource + self.slice = slice self.result_type = result_type self.result_format = result_format self.queries = queries @@ -98,6 +102,8 @@ def get_payload( def get_cache_timeout(self) -> Optional[int]: if self.custom_cache_timeout is not None: return self.custom_cache_timeout + if self.slice and self.slice.cache_timeout: + return self.slice.cache_timeout if self.datasource.cache_timeout is not None: return self.datasource.cache_timeout if hasattr(self.datasource, "database"): diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index dc43d28de9d58..84fdeb7109cb2 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -19,10 +19,12 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING from superset import app, db +from superset.charts.dao import ChartDAO from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context import QueryContext from superset.common.query_object_factory import QueryObjectFactory from superset.datasource.dao import DatasourceDAO +from superset.models.slice import Slice from superset.utils.core import DatasourceDict, DatasourceType if TYPE_CHECKING: @@ -55,6 +57,12 @@ def create( datasource_model_instance = None if datasource: datasource_model_instance = self._convert_to_model(datasource) + + slice = None + if form_data and form_data.get("slice_id") is not None: + slice_id = form_data.get("slice_id") + slice = self._get_slice(slice_id) + result_type = result_type or ChartDataResultType.FULL result_format = result_format or ChartDataResultFormat.JSON queries_ = [ @@ -72,6 +80,7 @@ def create( return QueryContext( datasource=datasource_model_instance, queries=queries_, + slice=slice, form_data=form_data, result_type=result_type, result_format=result_format, @@ -88,3 +97,6 @@ def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: datasource_type=DatasourceType(datasource["type"]), datasource_id=int(datasource["id"]), ) + + def _get_slice(self, slice_id: Any) -> Optional[Slice]: + return ChartDAO.find_by_id(slice_id) From 1e6e2cfa8ad8628cbbdf26623b8a2d4e34b4886a Mon Sep 17 00:00:00 2001 From: Mayur Date: Thu, 29 Sep 2022 10:49:15 +0530 Subject: [PATCH 2/4] check slice cache timeout --- superset/common/query_context_factory.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 84fdeb7109cb2..ee6fdefe34319 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -60,8 +60,7 @@ def create( slice = None if form_data and form_data.get("slice_id") is not None: - slice_id = form_data.get("slice_id") - slice = self._get_slice(slice_id) + slice = self._get_slice(form_data.get("slice_id")) result_type = result_type or ChartDataResultType.FULL result_format = result_format or ChartDataResultFormat.JSON From 877ac2f1cf2b4c28ce3795712b6aa24c5e90e4cb Mon Sep 17 00:00:00 2001 From: Mayur Date: Thu, 29 Sep 2022 14:44:32 +0530 Subject: [PATCH 3/4] add tests --- superset/common/query_context.py | 4 +- .../charts/data/api_tests.py | 85 ++++++++++++++++++- .../fixtures/energy_dashboard.py | 14 +-- 3 files changed, 94 insertions(+), 9 deletions(-) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index b9414fddd1dca..3ff5f914d3c44 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -47,7 +47,7 @@ class QueryContext: enforce_numerical_metrics: ClassVar[bool] = True datasource: BaseDatasource - slice_id: Optional[int] = None + slice: Optional[Slice] = None queries: List[QueryObject] form_data: Optional[Dict[str, Any]] result_type: ChartDataResultType @@ -102,7 +102,7 @@ def get_payload( def get_cache_timeout(self) -> Optional[int]: if self.custom_cache_timeout is not None: return self.custom_cache_timeout - if self.slice and self.slice.cache_timeout: + if self.slice and self.slice.cache_timeout is not None: return self.slice.cache_timeout if self.datasource.cache_timeout is not None: return self.datasource.cache_timeout diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 73d33cd793b76..4a56d0df29480 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -21,7 +21,7 @@ import copy from datetime import datetime from io import BytesIO -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from unittest import mock from zipfile import ZipFile @@ -38,8 +38,12 @@ load_birth_names_data, ) from tests.integration_tests.test_app import app - +from tests.integration_tests.fixtures.energy_dashboard import ( + load_energy_table_with_slice, + load_energy_table_data, +) import pytest +from superset.models.slice import Slice from superset.charts.data.commands.get_data_command import ChartDataCommand from superset.connectors.sqla.models import TableColumn, SqlaTable @@ -976,3 +980,80 @@ def test_data_cache_default_timeout( ): rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 3456 + + +def test_chart_cache_timeout( + test_client, + login_as_admin, + physical_query_context, + load_energy_table_with_slice: List[Slice], +): + # should override datasource cache timeout + + slice_with_cache_timeout = load_energy_table_with_slice[0] + slice_with_cache_timeout.cache_timeout = 20 + db.session.merge(slice_with_cache_timeout) + + datasource: SqlaTable = ( + db.session.query(SqlaTable) + .filter(SqlaTable.id == physical_query_context["datasource"]["id"]) + .first() + ) + datasource.cache_timeout = 1254 + db.session.merge(datasource) + + db.session.commit() + + physical_query_context["form_data"] = {"slice_id": slice_with_cache_timeout.id} + + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 20 + + +@mock.patch( + "superset.common.query_context_processor.config", + { + **app.config, + "DATA_CACHE_CONFIG": { + **app.config["DATA_CACHE_CONFIG"], + "CACHE_DEFAULT_TIMEOUT": 1010, + }, + }, +) +def test_chart_cache_timeout_not_present( + test_client, login_as_admin, physical_query_context +): + # should use datasource cache, if it's present + + datasource: SqlaTable = ( + db.session.query(SqlaTable) + .filter(SqlaTable.id == physical_query_context["datasource"]["id"]) + .first() + ) + datasource.cache_timeout = 1980 + db.session.merge(datasource) + db.session.commit() + + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 1980 + + +@mock.patch( + "superset.common.query_context_processor.config", + { + **app.config, + "DATA_CACHE_CONFIG": { + **app.config["DATA_CACHE_CONFIG"], + "CACHE_DEFAULT_TIMEOUT": 1010, + }, + }, +) +def test_chart_cache_timeout_chart_not_found( + test_client, login_as_admin, physical_query_context +): + # should use default timeout + + physical_query_context["form_data"] = {"slice_id": 0} + + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 1010 diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py index 0279fe8ff2f5c..436ba1ce55b6a 100644 --- a/tests/integration_tests/fixtures/energy_dashboard.py +++ b/tests/integration_tests/fixtures/energy_dashboard.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import random -from typing import Dict, Set +from typing import Dict, List, Set import pandas as pd import pytest @@ -59,8 +59,8 @@ def load_energy_table_data(): @pytest.fixture() def load_energy_table_with_slice(load_energy_table_data): with app.app_context(): - _create_energy_table() - yield + slices = _create_energy_table() + yield slices _cleanup() @@ -69,7 +69,7 @@ def _get_dataframe(): return pd.DataFrame.from_dict(data) -def _create_energy_table(): +def _create_energy_table() -> List[Slice]: table = create_table_metadata( table_name=ENERGY_USAGE_TBL_NAME, database=get_example_database(), @@ -86,13 +86,17 @@ def _create_energy_table(): db.session.commit() table.fetch_metadata() + slices = [] for slice_data in _get_energy_slices(): - _create_and_commit_energy_slice( + + slice = _create_and_commit_energy_slice( table, slice_data["slice_title"], slice_data["viz_type"], slice_data["params"], ) + slices.append(slice) + return slices def _create_and_commit_energy_slice( From 65968cc3cb17a52514610b8a4963fd54a0f9a6e9 Mon Sep 17 00:00:00 2001 From: Mayur Date: Thu, 29 Sep 2022 17:21:54 +0530 Subject: [PATCH 4/4] pylint --- superset/common/query_context.py | 10 +++++----- superset/common/query_context_factory.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 3ff5f914d3c44..3f3667709938b 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -47,7 +47,7 @@ class QueryContext: enforce_numerical_metrics: ClassVar[bool] = True datasource: BaseDatasource - slice: Optional[Slice] = None + slice_: Optional[Slice] = None queries: List[QueryObject] form_data: Optional[Dict[str, Any]] result_type: ChartDataResultType @@ -66,7 +66,7 @@ def __init__( *, datasource: BaseDatasource, queries: List[QueryObject], - slice: Optional[Slice], + slice_: Optional[Slice], form_data: Optional[Dict[str, Any]], result_type: ChartDataResultType, result_format: ChartDataResultFormat, @@ -75,7 +75,7 @@ def __init__( cache_values: Dict[str, Any], ) -> None: self.datasource = datasource - self.slice = slice + self.slice_ = slice_ self.result_type = result_type self.result_format = result_format self.queries = queries @@ -102,8 +102,8 @@ def get_payload( def get_cache_timeout(self) -> Optional[int]: if self.custom_cache_timeout is not None: return self.custom_cache_timeout - if self.slice and self.slice.cache_timeout is not None: - return self.slice.cache_timeout + if self.slice_ and self.slice_.cache_timeout is not None: + return self.slice_.cache_timeout if self.datasource.cache_timeout is not None: return self.datasource.cache_timeout if hasattr(self.datasource, "database"): diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index ee6fdefe34319..360ee449a4135 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -58,9 +58,9 @@ def create( if datasource: datasource_model_instance = self._convert_to_model(datasource) - slice = None + slice_ = None if form_data and form_data.get("slice_id") is not None: - slice = self._get_slice(form_data.get("slice_id")) + slice_ = self._get_slice(form_data.get("slice_id")) result_type = result_type or ChartDataResultType.FULL result_format = result_format or ChartDataResultFormat.JSON @@ -79,7 +79,7 @@ def create( return QueryContext( datasource=datasource_model_instance, queries=queries_, - slice=slice, + slice_=slice_, form_data=form_data, result_type=result_type, result_format=result_format,