Skip to content

Commit

Permalink
fix: Validate required fields in sql_json API (#21003)
Browse files Browse the repository at this point in the history
* fix: Validate required params for sql_json API

* Test required params in sql_json API

* Refactoring: use marshmallow Schema for validation sql_json API

* Update SqlJsonPayloadSchema

* Update SqlJsonPayloadSchema

* Refactoring

* Refactoring

* Refactoring
  • Loading branch information
EugeneTorap authored Aug 11, 2022
1 parent 4f1996d commit a2b21b5
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 3 deletions.
2 changes: 1 addition & 1 deletion superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def init_views(self) -> None:
from superset.views.log.api import LogRestApi
from superset.views.log.views import LogModelView
from superset.views.redirects import R
from superset.views.sql_lab import (
from superset.views.sql_lab.views import (
SavedQueryView,
SavedQueryViewApi,
SqlLab,
Expand Down
5 changes: 5 additions & 0 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
json_success,
validate_sqlatable,
)
from superset.views.sql_lab.schemas import SqlJsonPayloadSchema
from superset.views.utils import (
_deserialize_results_payload,
bootstrap_user_data,
Expand Down Expand Up @@ -2433,6 +2434,10 @@ def validate_sql_json(
@event_logger.log_this
@expose("/sql_json/", methods=["POST"])
def sql_json(self) -> FlaskResponse:
errors = SqlJsonPayloadSchema().validate(request.json)
if errors:
return json_error_response(status=400, payload=errors)

try:
log_params = {
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
Expand Down
16 changes: 16 additions & 0 deletions superset/views/sql_lab/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
35 changes: 35 additions & 0 deletions superset/views/sql_lab/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from marshmallow import fields, Schema


class SqlJsonPayloadSchema(Schema):
database_id = fields.Integer(required=True)
sql = fields.String(required=True)
client_id = fields.String(allow_none=True)
queryLimit = fields.Integer(allow_none=True)
sql_editor_id = fields.String(allow_none=True)
schema = fields.String(allow_none=True)
tab = fields.String(allow_none=True)
ctas_method = fields.String(allow_none=True)
templateParams = fields.String(allow_none=True)
tmp_table_name = fields.String(allow_none=True)
select_as_cta = fields.Boolean(allow_none=True)
json = fields.Boolean(allow_none=True)
runAsync = fields.Boolean(allow_none=True)
expand_data = fields.Boolean(allow_none=True)
8 changes: 6 additions & 2 deletions superset/views/sql_lab.py → superset/views/sql_lab/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@
from superset.superset_typing import FlaskResponse
from superset.utils import core as utils
from superset.utils.core import get_user_id

from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView
from superset.views.base import (
BaseSupersetView,
DeleteMixin,
json_success,
SupersetModelView,
)

logger = logging.getLogger(__name__)

Expand Down
46 changes: 46 additions & 0 deletions tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,52 @@ def test_extra_table_metadata(self):
f"/superset/extra_table_metadata/{example_db.id}/birth_names/{schema}/"
)

def test_required_params_in_sql_json(self):
self.login()
client_id = "{}".format(random.getrandbits(64))[:10]

data = {"client_id": client_id}
rv = self.client.post(
"/superset/sql_json/",
json=data,
)
failed_resp = {
"sql": ["Missing data for required field."],
"database_id": ["Missing data for required field."],
}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)

data = {"sql": "SELECT 1", "client_id": client_id}
rv = self.client.post(
"/superset/sql_json/",
json=data,
)
failed_resp = {"database_id": ["Missing data for required field."]}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)

data = {"database_id": 1, "client_id": client_id}
rv = self.client.post(
"/superset/sql_json/",
json=data,
)
failed_resp = {"sql": ["Missing data for required field."]}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, failed_resp)
self.assertEqual(rv.status_code, 400)

data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id}
rv = self.client.post(
"/superset/sql_json/",
json=data,
)
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(resp_data.get("status"), "success")
self.assertEqual(rv.status_code, 200)

def test_templated_sql_json(self):
if superset.utils.database.get_example_database().backend == "presto":
# TODO: make it work for presto
Expand Down

0 comments on commit a2b21b5

Please sign in to comment.