From a2b21b55be8941e1756bd6c10f5b3dd063a20ee3 Mon Sep 17 00:00:00 2001 From: EugeneTorap Date: Thu, 11 Aug 2022 18:37:53 +0300 Subject: [PATCH] fix: Validate required fields in sql_json API (#21003) * fix: Validate required params for sql_json API * Test required params in sql_json API * Refactoring: use marshmallow Schema for validation sql_json API * Update SqlJsonPayloadSchema * Update SqlJsonPayloadSchema * Refactoring * Refactoring * Refactoring --- superset/initialization/__init__.py | 2 +- superset/views/core.py | 5 ++ superset/views/sql_lab/__init__.py | 16 +++++++ superset/views/sql_lab/schemas.py | 35 ++++++++++++++ .../views/{sql_lab.py => sql_lab/views.py} | 8 +++- tests/integration_tests/core_tests.py | 46 +++++++++++++++++++ 6 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 superset/views/sql_lab/__init__.py create mode 100644 superset/views/sql_lab/schemas.py rename superset/views/{sql_lab.py => sql_lab/views.py} (99%) diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 8dfeff9942aba..2fe5591dac65b 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -176,7 +176,7 @@ def init_views(self) -> None: from superset.views.log.api import LogRestApi from superset.views.log.views import LogModelView from superset.views.redirects import R - from superset.views.sql_lab import ( + from superset.views.sql_lab.views import ( SavedQueryView, SavedQueryViewApi, SqlLab, diff --git a/superset/views/core.py b/superset/views/core.py index 24c56e61d5e56..4f392337902b8 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -152,6 +152,7 @@ json_success, validate_sqlatable, ) +from superset.views.sql_lab.schemas import SqlJsonPayloadSchema from superset.views.utils import ( _deserialize_results_payload, bootstrap_user_data, @@ -2433,6 +2434,10 @@ def validate_sql_json( @event_logger.log_this @expose("/sql_json/", methods=["POST"]) def sql_json(self) -> FlaskResponse: + errors = SqlJsonPayloadSchema().validate(request.json) + if errors: + return json_error_response(status=400, payload=errors) + try: log_params = { "user_agent": cast(Optional[str], request.headers.get("USER_AGENT")) diff --git a/superset/views/sql_lab/__init__.py b/superset/views/sql_lab/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/views/sql_lab/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/views/sql_lab/schemas.py b/superset/views/sql_lab/schemas.py new file mode 100644 index 0000000000000..399665afc1bf2 --- /dev/null +++ b/superset/views/sql_lab/schemas.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from marshmallow import fields, Schema + + +class SqlJsonPayloadSchema(Schema): + database_id = fields.Integer(required=True) + sql = fields.String(required=True) + client_id = fields.String(allow_none=True) + queryLimit = fields.Integer(allow_none=True) + sql_editor_id = fields.String(allow_none=True) + schema = fields.String(allow_none=True) + tab = fields.String(allow_none=True) + ctas_method = fields.String(allow_none=True) + templateParams = fields.String(allow_none=True) + tmp_table_name = fields.String(allow_none=True) + select_as_cta = fields.Boolean(allow_none=True) + json = fields.Boolean(allow_none=True) + runAsync = fields.Boolean(allow_none=True) + expand_data = fields.Boolean(allow_none=True) diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab/views.py similarity index 99% rename from superset/views/sql_lab.py rename to superset/views/sql_lab/views.py index 1042b8f920ac2..509ff4211aacc 100644 --- a/superset/views/sql_lab.py +++ b/superset/views/sql_lab/views.py @@ -30,8 +30,12 @@ from superset.superset_typing import FlaskResponse from superset.utils import core as utils from superset.utils.core import get_user_id - -from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView +from superset.views.base import ( + BaseSupersetView, + DeleteMixin, + json_success, + SupersetModelView, +) logger = logging.getLogger(__name__) diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 5c99a1e870c5a..471926d6e2582 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -763,6 +763,52 @@ def test_extra_table_metadata(self): f"/superset/extra_table_metadata/{example_db.id}/birth_names/{schema}/" ) + def test_required_params_in_sql_json(self): + self.login() + client_id = "{}".format(random.getrandbits(64))[:10] + + data = {"client_id": client_id} + rv = self.client.post( + "/superset/sql_json/", + json=data, + ) + failed_resp = { + "sql": ["Missing data for required field."], + "database_id": ["Missing data for required field."], + } + resp_data = json.loads(rv.data.decode("utf-8")) + self.assertDictEqual(resp_data, failed_resp) + self.assertEqual(rv.status_code, 400) + + data = {"sql": "SELECT 1", "client_id": client_id} + rv = self.client.post( + "/superset/sql_json/", + json=data, + ) + failed_resp = {"database_id": ["Missing data for required field."]} + resp_data = json.loads(rv.data.decode("utf-8")) + self.assertDictEqual(resp_data, failed_resp) + self.assertEqual(rv.status_code, 400) + + data = {"database_id": 1, "client_id": client_id} + rv = self.client.post( + "/superset/sql_json/", + json=data, + ) + failed_resp = {"sql": ["Missing data for required field."]} + resp_data = json.loads(rv.data.decode("utf-8")) + self.assertDictEqual(resp_data, failed_resp) + self.assertEqual(rv.status_code, 400) + + data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id} + rv = self.client.post( + "/superset/sql_json/", + json=data, + ) + resp_data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(resp_data.get("status"), "success") + self.assertEqual(rv.status_code, 200) + def test_templated_sql_json(self): if superset.utils.database.get_example_database().backend == "presto": # TODO: make it work for presto