Skip to content

Commit

Permalink
fix(apache#13378): Ensure g.user is set for impersonation (apache#13878)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjreinhart authored and Allan Caetano de Oliveira committed May 21, 2021
1 parent f0d8d97 commit f2c91e6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
17 changes: 15 additions & 2 deletions superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
import logging
from typing import Any, cast, Dict, Optional

from flask import current_app
from flask import current_app, g

from superset import app
from superset.exceptions import SupersetVizException
from superset.extensions import async_query_manager, cache_manager, celery_app
from superset.extensions import (
async_query_manager,
cache_manager,
celery_app,
security_manager,
)
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.views.utils import get_datasource_info, get_viz

Expand All @@ -32,6 +37,12 @@
] # TODO: new config key


def ensure_user_is_set(user_id: Optional[int]) -> None:
user_is_set = hasattr(g, "user") and g.user is not None
if not user_is_set and user_id is not None:
g.user = security_manager.get_user_by_id(user_id)


@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
Expand All @@ -42,6 +53,7 @@ def load_chart_data_into_cache(

with app.app_context(): # type: ignore
try:
ensure_user_is_set(job_metadata.get("user_id"))
command = ChartDataCommand()
command.set_query_context(form_data)
result = command.run(cache=True)
Expand Down Expand Up @@ -72,6 +84,7 @@ def load_explore_json_into_cache(
with app.app_context(): # type: ignore
cache_key_prefix = "ejr-" # ejr: explore_json request
try:
ensure_user_is_set(job_metadata.get("user_id"))
datasource_id, datasource_type = get_datasource_info(None, None, form_data)

viz_obj = get_viz(
Expand Down
53 changes: 39 additions & 14 deletions tests/tasks/async_queries_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager
from superset.extensions import async_query_manager, security_manager
from superset.tasks import async_queries
from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
Expand All @@ -48,17 +49,24 @@ class TestAsyncQueries(SupersetTestCase):
def test_load_chart_data_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}

load_chart_data_into_cache(job_metadata, query_context)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)

mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
ensure_user_is_set.assert_called_once_with(user.id)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)

@mock.patch.object(
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
Expand All @@ -67,25 +75,31 @@ def test_load_chart_data_into_cache(self, mock_update_job):
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}
with pytest.raises(ChartDataQueryFailedError):
load_chart_data_into_cache(job_metadata, query_context)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)
ensure_user_is_set.assert_called_once_with(user.id)

mock_run_command.assert_called_with(cache=True)
mock_run_command.assert_called_once_with(cache=True)
errors = [{"message": "Error: foo"}]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
user = security_manager.find_user("gamma")
form_data = {
"datasource": f"{table.id}__table",
"viz_type": "dist_bar",
Expand All @@ -100,29 +114,40 @@ def test_load_explore_json_into_cache(self, mock_update_job):
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}

load_explore_json_into_cache(job_metadata, form_data)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)

mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
ensure_user_is_set.assert_called_once_with(user.id)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)

@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache_error(self, mock_update_job):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}

with pytest.raises(SupersetException):
load_explore_json_into_cache(job_metadata, form_data)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id)

errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)

0 comments on commit f2c91e6

Please sign in to comment.