From d18c7d6128d7e99f7756ad4006f79df3fb0cc3d6 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Fri, 6 Jan 2023 13:52:05 -0500 Subject: [PATCH] fix(ssh-tunnel): fix dataset creation flow through modal for DB with tunnel (#22581) --- superset/models/core.py | 23 ++++++++++++----------- tests/integration_tests/core_tests.py | 13 +++++++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index 173bd5b590752..9e042eeab573a 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -744,13 +744,14 @@ def update_params_from_encrypted_extra(self, params: Dict[str, Any]) -> None: def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) - return Table( - table_name, - meta, - schema=schema or None, - autoload=True, - autoload_with=self._get_sqla_engine(), - ) + with self.get_sqla_engine_with_context() as engine: + return Table( + table_name, + meta, + schema=schema or None, + autoload=True, + autoload_with=engine, + ) def get_table_comment( self, table_name: str, schema: Optional[str] = None @@ -846,12 +847,12 @@ def get_perm(self) -> str: return self.perm # type: ignore def has_table(self, table: Table) -> bool: - engine = self._get_sqla_engine() - return engine.has_table(table.table_name, table.schema or None) + with self.get_sqla_engine_with_context() as engine: + return engine.has_table(table.table_name, table.schema or None) def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool: - engine = self._get_sqla_engine() - return engine.has_table(table_name, schema) + with self.get_sqla_engine_with_context() as engine: + return engine.has_table(table_name, schema) @classmethod def _has_view( diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 6e9f1a8d33c48..86246084fb81d 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -27,10 +27,12 @@ from urllib.parse import quote import superset.utils.database +from superset.utils.core import backend from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, load_birth_names_data, ) +from sqlalchemy import Table import pytest import pytz @@ -79,6 +81,7 @@ load_world_bank_dashboard_with_slices, load_world_bank_data, ) +from tests.integration_tests.conftest import CTAS_SCHEMA_NAME logger = logging.getLogger(__name__) @@ -1673,6 +1676,16 @@ def test_explore_redirect(self, mock_command: mock.Mock): ) self.assertRedirects(rv, f"/explore/?form_data_key={random_key}") + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_has_table_by_name(self): + if backend() in ("sqlite", "mysql"): + return + example_db = superset.utils.database.get_example_database() + assert ( + example_db.has_table_by_name(table_name="birth_names", schema="public") + is True + ) + if __name__ == "__main__": unittest.main()