Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve adhoc SQL validation #19454

Merged
merged 3 commits into from
Mar 31, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,11 @@ def adhoc_metric_to_sqla(
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
validate_adhoc_subquery(expression)
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
Expand Down Expand Up @@ -928,7 +932,11 @@ def adhoc_column_to_sqla(
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
validate_adhoc_subquery(expression)
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
Expand Down Expand Up @@ -984,15 +992,14 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool:

def _get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[str]:
) -> List[TextClause]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why mypy didn't catch this.

"""
Return the appropriate row level security filters for
this table and the current user.

:param BaseTemplateProcessor template_processor: The template
processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
:rtype: List[str]
"""
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
Expand Down Expand Up @@ -1145,6 +1152,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
col: Union[AdhocMetric, ColumnElement] = orig_col
if isinstance(col, dict):
col = cast(AdhocMetric, col)
if col.get("sqlExpression"):
col["sqlExpression"] = validate_adhoc_subquery(
col["sqlExpression"],
self.database_id,
self.schema,
)
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(col, columns_by_name)
Expand Down Expand Up @@ -1194,7 +1207,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
elif selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
else:
validate_adhoc_subquery(selected)
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
Expand All @@ -1207,7 +1224,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
select_exprs.append(outer)
elif columns:
for selected in columns:
validate_adhoc_subquery(selected)
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
select_exprs.append(
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
Expand Down Expand Up @@ -1420,7 +1441,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):
validate_adhoc_subquery(str(col.expression))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this up, where it's still a string.

col = literal_column(col.name)
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
Expand Down
40 changes: 25 additions & 15 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, ParsedQuery, Table
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.tables.models import Table as NewTable

if TYPE_CHECKING:
Expand Down Expand Up @@ -136,29 +136,39 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
return cols


def validate_adhoc_subquery(raw_sql: str) -> None:
def validate_adhoc_subquery(
sql: str,
database_id: int,
default_schema: str,
) -> str:
"""
Check if adhoc SQL contains sub-queries or nested sub-queries with table
:param raw_sql: adhoc sql expression
Check if adhoc SQL contains sub-queries or nested sub-queries with table.

If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS
predicates to it.

:param sql: adhoc sql expression
:raise SupersetSecurityException if sql contains sub-queries or
nested sub-queries with table
"""
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled

if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
return

for statement in sqlparse.parse(raw_sql):
statements = []
for statement in sqlparse.parse(sql):
if has_table_query(statement):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
)
)
)
return
statement = insert_rls(statement, database_id, default_schema)
statements.append(statement)

return ";\n".join(str(statement) for statement in statements)


def load_or_create_tables( # pylint: disable=too-many-arguments
Expand Down
77 changes: 51 additions & 26 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import cast, List, Optional, Set, Tuple
from urllib import parse

import sqlparse
from sqlalchemy import and_
from sqlparse.sql import (
Identifier,
IdentifierList,
Expand Down Expand Up @@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
state = InsertRLSState.SCANNING
for token in token_list.tokens:

# # Recurse into child token list
# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True

Expand All @@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:

def add_table_name(rls: TokenList, table: str) -> None:
"""
Modify a RLS expression ensuring columns are fully qualified.
Modify a RLS expression inplace ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
Expand All @@ -539,45 +540,67 @@ def add_table_name(rls: TokenList, table: str) -> None:
tokens.extend(token.tokens)


def matches_table_name(candidate: Token, table: str) -> bool:
def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
) -> Optional[TokenList]:
"""
Returns if the token represents a reference to the table.

Tables can be fully qualified with periods.

Note that in theory a table should be represented as an identifier, but due to
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
classified as a keyword.
Given a table name, return any associated RLS predicates.
"""
# pylint: disable=import-outside-toplevel
from superset import db
from superset.connectors.sqla.models import SqlaTable

if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])

target = sqlparse.parse(table)[0].tokens[0]
if not isinstance(target, Identifier):
target = Identifier([Token(Name, target.value)])
table = ParsedQuery._get_table(candidate) # pylint: disable=protected-access
if not table:
return None

# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
if left.value != right.value:
return False
dataset = (
db.session.query(SqlaTable)
.filter(
and_(
SqlaTable.database_id == database_id,
SqlaTable.schema == (table.schema or default_schema),
SqlaTable.table_name == table.table,
)
)
.one_or_none()
)
if not dataset:
return None

return True
template_processor = dataset.get_template_processor()
# pylint: disable=protected-access
predicate = " AND ".join(
str(filter_)
for filter_ in dataset._get_sqla_row_level_filters(template_processor)
)
rls = sqlparse.parse(predicate)[0]
add_table_name(rls, str(dataset))

return rls
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved

def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:

def insert_rls(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
) -> TokenList:
"""
Update a statement inplace applying an RLS associated with a given table.
Update a statement inplace applying any associated RLS predicates.
"""
# make sure the identifier has the table name
add_table_name(rls, table)

rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:

# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, table, rls)
token_list.tokens[i] = insert_rls(token, database_id, default_schema)

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
Expand All @@ -587,12 +610,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
if matches_table_name(token, table):
rls = get_rls_for_table(token, database_id, default_schema)
if rls:
state = InsertRLSState.FOUND_TABLE

# Found WHERE clause, insert RLS. Note that we insert it even it already exists,
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
rls = cast(TokenList, rls)
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
token.tokens.extend(
[
Expand Down
48 changes: 29 additions & 19 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@

import pytest
import sqlparse
from pytest_mock import MockerFixture
from sqlparse.sql import Token, TokenList

from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
has_table_query,
insert_rls,
matches_table_name,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
Expand Down Expand Up @@ -1391,13 +1392,37 @@ def test_has_table_query(sql: str, expected: bool) -> None:
),
],
)
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
def test_insert_rls(
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
"""
Insert into a statement a given RLS condition associated with a table.
"""
statement = sqlparse.parse(sql)[0]
condition = sqlparse.parse(rls)[0]
assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
add_table_name(condition, table)

# pylint: disable=unused-argument
def get_rls_for_table(
candidate: Token, database_id: int, default_schema: str
) -> TokenList:
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
"""
Return the RLS ``condition`` if ``candidate`` matches ``table``.
"""
# compare ignoring schema
for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
if left != right:
return None
return condition

mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)

statement = sqlparse.parse(sql)[0]
assert (
str(
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
).strip()
== expected.strip()
)


@pytest.mark.parametrize(
Expand All @@ -1413,18 +1438,3 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
condition = sqlparse.parse(rls)[0]
add_table_name(condition, table)
assert str(condition) == expected


@pytest.mark.parametrize(
"candidate,table,expected",
[
("table", "table", True),
("schema.table", "table", True),
("table", "schema.table", True),
('schema."my table"', '"my table"', True),
('schema."my.table"', '"my.table"', True),
],
)
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
token = sqlparse.parse(candidate)[0].tokens[0]
assert matches_table_name(token, table) == expected