Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve get_db_engine_spec_for_backend #21171

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
#
# e.g.:
#
# class AesGcmEncryptedAdapter( # pylint: disable=too-few-public-methods
# class AesGcmEncryptedAdapter(
# AbstractEncryptedFieldAdapter
# ):
# def create(
Expand Down
4 changes: 2 additions & 2 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,8 @@ def available(self) -> Response:
"preferred": engine_spec.engine_name in preferred_databases,
}

if hasattr(engine_spec, "default_driver"):
payload["default_driver"] = engine_spec.default_driver # type: ignore
if engine_spec.default_driver:
payload["default_driver"] = engine_spec.default_driver

# show configuration parameters for DBs that support it
if (
Expand Down
27 changes: 3 additions & 24 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
)
from superset.databases.dao import DatabaseDAO
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.base import BasicParametersMixin
from superset.db_engine_specs import get_engine_spec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
Expand All @@ -45,25 +44,13 @@ def __init__(self, parameters: Dict[str, Any]):

def run(self) -> None:
engine = self._properties["engine"]
engine_specs = get_engine_specs()
driver = self._properties.get("driver")

if engine in BYPASS_VALIDATION_ENGINES:
# Skip engines that are only validated onCreate
return

if engine not in engine_specs:
raise InvalidEngineError(
SupersetError(
message=__(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={"allowed": list(engine_specs), "provided": engine},
),
)
engine_spec = engine_specs[engine]
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "parameters_schema"):
raise InvalidEngineError(
SupersetError(
Expand All @@ -73,14 +60,6 @@ def run(self) -> None:
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={
"allowed": [
name
for name, engine_spec in engine_specs.items()
if issubclass(engine_spec, BasicParametersMixin)
],
"provided": engine,
},
),
)

Expand Down
43 changes: 16 additions & 27 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import inspect
import json
from typing import Any, Dict, Optional, Type
from typing import Any, Dict

from flask import current_app
from flask_babel import lazy_gettext as _
Expand All @@ -28,7 +28,7 @@
from superset import db
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
from superset.security.analytics_db_safety import check_sqlalchemy_uri
Expand Down Expand Up @@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
[
_(
"Invalid connection string, a valid string usually follows: "
"driver://user:password@database-host/database-name"
"backend+driver://user:password@database-host/database-name"
)
]
) from ex
Expand Down Expand Up @@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
"""

engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(),
Expand Down Expand Up @@ -262,10 +263,20 @@ def build_sqlalchemy_uri(
or parameters.pop("engine", None)
or data.pop("backend", None)
)
driver = data.pop("driver", None)

configuration_method = data.get("configuration_method")
if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
engine_spec = get_engine_spec(engine)
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_spec = get_engine_spec(engine, driver)

if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
engine_spec, "parameters_schema"
Expand Down Expand Up @@ -295,34 +306,12 @@ def build_sqlalchemy_uri(
return data


def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_specs = get_engine_specs()
if engine not in engine_specs:
raise ValidationError(
[
_(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
)
]
)
return engine_specs[engine]


class DatabaseValidateParametersSchema(Schema):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE

engine = fields.String(required=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(allow_none=True),
Expand Down
48 changes: 33 additions & 15 deletions superset/db_engine_specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,34 @@
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Set, Type
from typing import Any, Dict, List, Optional, Set, Type

import sqlalchemy.databases
import sqlalchemy.dialects
from pkg_resources import iter_entry_points
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.url import URL

from superset.db_engine_specs.base import BaseEngineSpec

logger = logging.getLogger(__name__)


def is_engine_spec(attr: Any) -> bool:
def is_engine_spec(obj: Any) -> bool:
"""
Return true if a given object is a DB engine spec.
"""
return (
inspect.isclass(attr)
and issubclass(attr, BaseEngineSpec)
and attr != BaseEngineSpec
inspect.isclass(obj)
and issubclass(obj, BaseEngineSpec)
and obj != BaseEngineSpec
)


def load_engine_specs() -> List[Type[BaseEngineSpec]]:
"""
Load all engine specs, native and 3rd party.
"""
engine_specs: List[Type[BaseEngineSpec]] = []

# load standard engines
Expand All @@ -78,20 +85,31 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
return engine_specs


def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@john-bodley does AirBnB use this function outside of the Superset code base?

def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
"""
Return the DB engine spec associated with a given SQLAlchemy URL.

Note that if a driver is not specified the function returns the first DB engine spec
that supports the backend. Also, if a driver is specified but no DB engine explicitly
supporting that driver exists then a backend-only match is done, in order to allow new
drivers to work with Superset even if they are not listed in the DB engine spec
drivers.
"""
engine_specs = load_engine_specs()

# build map from name/alias -> spec
engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
for engine_spec in engine_specs:
names = [engine_spec.engine]
if engine_spec.engine_aliases:
names.extend(engine_spec.engine_aliases)
if driver is not None:
for engine_spec in engine_specs:
if engine_spec.supports_backend(backend, driver):
return engine_spec

for name in names:
engine_specs_map[name] = engine_spec
# check ignoring the driver, in order to support new drivers; this will return a
# random DB engine spec that supports the engine
for engine_spec in engine_specs:
if engine_spec.supports_backend(backend):
return engine_spec

return engine_specs_map
# default to the generic DB engine spec
return BaseEngineSpec


# there's a mismatch between the dialect name reported by the driver in these
Expand Down
62 changes: 60 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
having to add the same aggregation in SELECT.
"""

engine_name: Optional[str] = None # for user messages, overridden in child classes

# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
# see the ``supports_url`` and ``supports_backend`` methods below.
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set()
engine_name: Optional[str] = None # for user messages, overridden in child classes
drivers: Dict[str, str] = {}
default_driver: Optional[str] = None

_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
Expand Down Expand Up @@ -355,6 +361,58 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
] = {}

@classmethod
def supports_url(cls, url: URL) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy URL.

As an example, if a given DB engine spec has:

class PostgresDBEngineSpec:
engine = "postgresql"
engine_aliases = "postgres"
drivers = {
"psycopg2": "The default Postgres driver",
"asyncpg": "An asynchronous Postgres driver",
}

It would be used for all the following SQLAlchemy URIs:

- postgres://user:password@host/db
- postgresql://user:password@host/db
- postgres+asyncpg://user:password@host/db
- postgres+psycopg2://user:password@host/db
- postgresql+asyncpg://user:password@host/db
- postgresql+psycopg2://user:password@host/db

Note that SQLAlchemy has a default driver even if one is not specified:

>>> from sqlalchemy.engine.url import make_url
>>> make_url('postgres://').get_driver_name()
'psycopg2'

"""
backend = url.get_backend_name()
driver = url.get_driver_name()
return cls.supports_backend(backend, driver)

@classmethod
def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
"""
# check the backend first
if backend != cls.engine and backend not in cls.engine_aliases:
return False

# originally DB engine specs didn't declare any drivers and the check was made
# only on the engine; if that's the case, ignore the driver for backwards
# compatibility
if not cls.drivers or driver is None:
return True

return driver in cls.drivers

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
Expand Down Expand Up @@ -394,7 +452,7 @@ def get_allow_cost_estimate( # pylint: disable=unused-argument
@classmethod
def get_text_clause(cls, clause: str) -> TextClause:
"""
SQLALchemy wrapper to ensure text clauses are escaped properly
SQLAlchemy wrapper to ensure text clauses are escaped properly

:param clause: string clause with potentially unescaped characters
:return: text clause with escaped characters
Expand Down
19 changes: 13 additions & 6 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,23 @@


class DatabricksHiveEngineSpec(HiveEngineSpec):
engine = "databricks"
engine_name = "Databricks Interactive Cluster"
driver = "pyhive"

engine = "databricks"
drivers = {"pyhive": "Hive driver for Interactive Cluster"}
default_driver = "pyhive"

_show_functions_column = "function"

_time_grain_expressions = time_grain_expressions


class DatabricksODBCEngineSpec(BaseEngineSpec):
engine = "databricks"
engine_name = "Databricks SQL Endpoint"
driver = "pyodbc"

engine = "databricks"
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "pyodbc"

_time_grain_expressions = time_grain_expressions

Expand All @@ -74,9 +79,11 @@ def epoch_to_dttm(cls) -> str:


class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
engine = "databricks"
engine_name = "Databricks Native Connector"
driver = "connector"

engine = "databricks"
drivers = {"connector": "Native all-purpose driver"}
default_driver = "connector"

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
Expand Down
6 changes: 5 additions & 1 deletion superset/db_engine_specs/shillelagh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
class ShillelaghEngineSpec(SqliteEngineSpec):
"""Engine for shillelagh"""

engine = "shillelagh"
engine_name = "Shillelagh"
engine = "shillelagh"
drivers = {"apsw": "SQLite driver"}
default_driver = "apsw"
sqlalchemy_uri_placeholder = "shillelagh://"

allows_joins = True
allows_subqueries = True
2 changes: 1 addition & 1 deletion superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from superset.models.core import Database

try:
from trino.dbapi import Cursor # pylint: disable=unused-import
from trino.dbapi import Cursor
except ImportError:
pass

Expand Down
Loading