diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 5ed0f56a59c70..291ddb3764859 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -990,7 +990,7 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool: if is_alias_used_in_orderby(col): col.name = f"{col.name}__" - def _get_sqla_row_level_filters( + def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor ) -> List[TextClause]: """ @@ -1394,7 +1394,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma _("Invalid filter operation type: %(op)s", op=op) ) if is_feature_enabled("ROW_LEVEL_SECURITY"): - where_clause_and += self._get_sqla_row_level_filters(template_processor) + where_clause_and += self.get_sqla_row_level_filters(template_processor) if extras: where = extras.get("where") if where: diff --git a/superset/sql_parse.py b/superset/sql_parse.py index eb555dcbcc4de..6bfb63c425c48 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -284,7 +284,7 @@ def get_statements(self) -> List[str]: return statements @staticmethod - def _get_table(tlist: TokenList) -> Optional[Table]: + def get_table(tlist: TokenList) -> Optional[Table]: """ Return the table if valid, i.e., conforms to the [[catalog.]schema.]table construct. @@ -325,7 +325,7 @@ def _process_tokenlist(self, token_list: TokenList) -> None: """ # exclude subselects if "(" not in str(token_list): - table = self._get_table(token_list) + table = self.get_table(token_list) if table and not table.table.startswith(CTE_PREFIX): self._tables.add(table) return @@ -555,7 +555,7 @@ def get_rls_for_table( if not isinstance(candidate, Identifier): candidate = Identifier([Token(Name, candidate.value)]) - table = ParsedQuery._get_table(candidate) # pylint: disable=protected-access + table = ParsedQuery.get_table(candidate) if not table: return None @@ -577,7 +577,7 @@ def get_rls_for_table( # pylint: disable=protected-access predicate = " AND ".join( str(filter_) - for filter_ in dataset._get_sqla_row_level_filters(template_processor) + for filter_ in dataset.get_sqla_row_level_filters(template_processor) ) if not predicate: return None diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index ab81ddae7cf55..4a1ff89d74cc6 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=invalid-name, too-many-lines +# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines import unittest from typing import Optional, Set @@ -23,11 +22,14 @@ import pytest import sqlparse from pytest_mock import MockerFixture -from sqlparse.sql import Token, TokenList +from sqlalchemy import text +from sqlparse.sql import Identifier, Token, TokenList +from sqlparse.tokens import Name from superset.exceptions import QueryClauseValidationException from superset.sql_parse import ( add_table_name, + get_rls_for_table, has_table_query, insert_rls, ParsedQuery, @@ -1438,3 +1440,31 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None: condition = sqlparse.parse(rls)[0] add_table_name(condition, table) assert str(condition) == expected + + +def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None: + """ + Tests for ``get_rls_for_table``. + """ + candidate = Identifier([Token(Name, "some_table")]) + db = mocker.patch("superset.db") + dataset = db.session.query().filter().one_or_none() + dataset.__str__.return_value = "some_table" + + dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")] + assert ( + str(get_rls_for_table(candidate, 1, "public")) + == "some_table.organization_id = 1" + ) + + dataset.get_sqla_row_level_filters.return_value = [ + text("organization_id = 1"), + text("foo = 'bar'"), + ] + assert ( + str(get_rls_for_table(candidate, 1, "public")) + == "some_table.organization_id = 1 AND some_table.foo = 'bar'" + ) + + dataset.get_sqla_row_level_filters.return_value = [] + assert get_rls_for_table(candidate, 1, "public") is None