Skip to content

Commit

Permalink
fix(ssh-tunnel): fix dataset creation flow through modal for DB with …
Browse files Browse the repository at this point in the history
…tunnel (apache#22581)
  • Loading branch information
hughhhh authored Jan 6, 2023
1 parent af34e45 commit d18c7d6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
23 changes: 12 additions & 11 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,13 +744,14 @@ def update_params_from_encrypted_extra(self, params: Dict[str, Any]) -> None:
def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
extra = self.get_extra()
meta = MetaData(**extra.get("metadata_params", {}))
return Table(
table_name,
meta,
schema=schema or None,
autoload=True,
autoload_with=self._get_sqla_engine(),
)
with self.get_sqla_engine_with_context() as engine:
return Table(
table_name,
meta,
schema=schema or None,
autoload=True,
autoload_with=engine,
)

def get_table_comment(
self, table_name: str, schema: Optional[str] = None
Expand Down Expand Up @@ -846,12 +847,12 @@ def get_perm(self) -> str:
return self.perm # type: ignore

def has_table(self, table: Table) -> bool:
engine = self._get_sqla_engine()
return engine.has_table(table.table_name, table.schema or None)
with self.get_sqla_engine_with_context() as engine:
return engine.has_table(table.table_name, table.schema or None)

def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
engine = self._get_sqla_engine()
return engine.has_table(table_name, schema)
with self.get_sqla_engine_with_context() as engine:
return engine.has_table(table_name, schema)

@classmethod
def _has_view(
Expand Down
13 changes: 13 additions & 0 deletions tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
from urllib.parse import quote

import superset.utils.database
from superset.utils.core import backend
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from sqlalchemy import Table

import pytest
import pytz
Expand Down Expand Up @@ -79,6 +81,7 @@
load_world_bank_dashboard_with_slices,
load_world_bank_data,
)
from tests.integration_tests.conftest import CTAS_SCHEMA_NAME

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1673,6 +1676,16 @@ def test_explore_redirect(self, mock_command: mock.Mock):
)
self.assertRedirects(rv, f"/explore/?form_data_key={random_key}")

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_has_table_by_name(self):
if backend() in ("sqlite", "mysql"):
return
example_db = superset.utils.database.get_example_database()
assert (
example_db.has_table_by_name(table_name="birth_names", schema="public")
is True
)


if __name__ == "__main__":
unittest.main()

0 comments on commit d18c7d6

Please sign in to comment.