diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index b8db82b9afe2d..f5f3c14f6b956 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -18,11 +18,16 @@ import logging from typing import Any, cast, Dict, Optional -from flask import current_app +from flask import current_app, g from superset import app from superset.exceptions import SupersetVizException -from superset.extensions import async_query_manager, cache_manager, celery_app +from superset.extensions import ( + async_query_manager, + cache_manager, + celery_app, + security_manager, +) from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.views.utils import get_datasource_info, get_viz @@ -32,6 +37,12 @@ ] # TODO: new config key +def ensure_user_is_set(user_id: Optional[int]) -> None: + user_is_set = hasattr(g, "user") and g.user is not None + if not user_is_set and user_id is not None: + g.user = security_manager.get_user_by_id(user_id) + + @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) def load_chart_data_into_cache( job_metadata: Dict[str, Any], form_data: Dict[str, Any], @@ -42,6 +53,7 @@ def load_chart_data_into_cache( with app.app_context(): # type: ignore try: + ensure_user_is_set(job_metadata.get("user_id")) command = ChartDataCommand() command.set_query_context(form_data) result = command.run(cache=True) @@ -72,6 +84,7 @@ def load_explore_json_into_cache( with app.app_context(): # type: ignore cache_key_prefix = "ejr-" # ejr: explore_json request try: + ensure_user_is_set(job_metadata.get("user_id")) datasource_id, datasource_type = get_datasource_info(None, None, form_data) viz_obj = get_viz( diff --git a/tests/tasks/async_queries_tests.py b/tests/tasks/async_queries_tests.py index 5a7b86a4d6297..cd4f0c0ce7669 100644 --- a/tests/tasks/async_queries_tests.py +++ b/tests/tasks/async_queries_tests.py @@ -26,7 +26,8 @@ from superset.charts.commands.exceptions import ChartDataQueryFailedError from superset.connectors.sqla.models import SqlaTable from superset.exceptions import SupersetException -from superset.extensions import async_query_manager +from superset.extensions import async_query_manager, security_manager +from superset.tasks import async_queries from superset.tasks.async_queries import ( load_chart_data_into_cache, load_explore_json_into_cache, @@ -48,17 +49,24 @@ class TestAsyncQueries(SupersetTestCase): def test_load_chart_data_into_cache(self, mock_update_job): async_query_manager.init_app(app) query_context = get_query_context("birth_names") + user = security_manager.find_user("gamma") job_metadata = { "channel_id": str(uuid4()), "job_id": str(uuid4()), - "user_id": 1, + "user_id": user.id, "status": "pending", "errors": [], } - load_chart_data_into_cache(job_metadata, query_context) + with mock.patch.object( + async_queries, "ensure_user_is_set" + ) as ensure_user_is_set: + load_chart_data_into_cache(job_metadata, query_context) - mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY) + ensure_user_is_set.assert_called_once_with(user.id) + mock_update_job.assert_called_once_with( + job_metadata, "done", result_url=mock.ANY + ) @mock.patch.object( ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo") @@ -67,25 +75,31 @@ def test_load_chart_data_into_cache(self, mock_update_job): def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command): async_query_manager.init_app(app) query_context = get_query_context("birth_names") + user = security_manager.find_user("gamma") job_metadata = { "channel_id": str(uuid4()), "job_id": str(uuid4()), - "user_id": 1, + "user_id": user.id, "status": "pending", "errors": [], } with pytest.raises(ChartDataQueryFailedError): - load_chart_data_into_cache(job_metadata, query_context) + with mock.patch.object( + async_queries, "ensure_user_is_set" + ) as ensure_user_is_set: + load_chart_data_into_cache(job_metadata, query_context) + ensure_user_is_set.assert_called_once_with(user.id) - mock_run_command.assert_called_with(cache=True) + mock_run_command.assert_called_once_with(cache=True) errors = [{"message": "Error: foo"}] - mock_update_job.assert_called_with(job_metadata, "error", errors=errors) + mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") def test_load_explore_json_into_cache(self, mock_update_job): async_query_manager.init_app(app) table = get_table_by_name("birth_names") + user = security_manager.find_user("gamma") form_data = { "datasource": f"{table.id}__table", "viz_type": "dist_bar", @@ -100,29 +114,40 @@ def test_load_explore_json_into_cache(self, mock_update_job): job_metadata = { "channel_id": str(uuid4()), "job_id": str(uuid4()), - "user_id": 1, + "user_id": user.id, "status": "pending", "errors": [], } - load_explore_json_into_cache(job_metadata, form_data) + with mock.patch.object( + async_queries, "ensure_user_is_set" + ) as ensure_user_is_set: + load_explore_json_into_cache(job_metadata, form_data) - mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY) + ensure_user_is_set.assert_called_once_with(user.id) + mock_update_job.assert_called_once_with( + job_metadata, "done", result_url=mock.ANY + ) @mock.patch.object(async_query_manager, "update_job") def test_load_explore_json_into_cache_error(self, mock_update_job): async_query_manager.init_app(app) + user = security_manager.find_user("gamma") form_data = {} job_metadata = { "channel_id": str(uuid4()), "job_id": str(uuid4()), - "user_id": 1, + "user_id": user.id, "status": "pending", "errors": [], } with pytest.raises(SupersetException): - load_explore_json_into_cache(job_metadata, form_data) + with mock.patch.object( + async_queries, "ensure_user_is_set" + ) as ensure_user_is_set: + load_explore_json_into_cache(job_metadata, form_data) + ensure_user_is_set.assert_called_once_with(user.id) errors = ["The dataset associated with this chart no longer exists"] - mock_update_job.assert_called_with(job_metadata, "error", errors=errors) + mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)