Skip to content

Commit

Permalink
fix: improve get_db_engine_spec_for_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Aug 25, 2022
1 parent 6a0b7e5 commit b2e8c66
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 78 deletions.
26 changes: 3 additions & 23 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from superset.databases.dao import DatabaseDAO
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs import get_engine_spec
from superset.db_engine_specs.base import BasicParametersMixin
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
Expand All @@ -45,25 +45,13 @@ def __init__(self, parameters: Dict[str, Any]):

def run(self) -> None:
engine = self._properties["engine"]
engine_specs = get_engine_specs()
driver = self._properties["driver"]

if engine in BYPASS_VALIDATION_ENGINES:
# Skip engines that are only validated onCreate
return

if engine not in engine_specs:
raise InvalidEngineError(
SupersetError(
message=__(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={"allowed": list(engine_specs), "provided": engine},
),
)
engine_spec = engine_specs[engine]
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "parameters_schema"):
raise InvalidEngineError(
SupersetError(
Expand All @@ -73,14 +61,6 @@ def run(self) -> None:
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={
"allowed": [
name
for name, engine_spec in engine_specs.items()
if issubclass(engine_spec, BasicParametersMixin)
],
"provided": engine,
},
),
)

Expand Down
39 changes: 14 additions & 25 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from superset import db
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
from superset.db_engine_specs import BaseEngineSpec, get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
from superset.security.analytics_db_safety import check_sqlalchemy_uri
Expand Down Expand Up @@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
"""

engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(),
Expand Down Expand Up @@ -262,10 +263,20 @@ def build_sqlalchemy_uri(
or parameters.pop("engine", None)
or data.pop("backend", None)
)
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
driver = data.pop("driver", None)

configuration_method = data.get("configuration_method")
if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
engine_spec = get_engine_spec(engine)
engine_spec = get_engine_spec(engine, driver)

if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
engine_spec, "parameters_schema"
Expand Down Expand Up @@ -295,34 +306,12 @@ def build_sqlalchemy_uri(
return data


def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_specs = get_engine_specs()
if engine not in engine_specs:
raise ValidationError(
[
_(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
)
]
)
return engine_specs[engine]


class DatabaseValidateParametersSchema(Schema):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE

engine = fields.String(required=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(allow_none=True),
Expand Down
35 changes: 20 additions & 15 deletions superset/db_engine_specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,34 @@
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Set, Type
from typing import Any, Dict, List, Optional, Set, Type

import sqlalchemy.databases
import sqlalchemy.dialects
from pkg_resources import iter_entry_points
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.url import URL

from superset.db_engine_specs.base import BaseEngineSpec

logger = logging.getLogger(__name__)


def is_engine_spec(attr: Any) -> bool:
def is_engine_spec(obj: Any) -> bool:
"""
Return true if a given object is a DB engine spec.
"""
return (
inspect.isclass(attr)
and issubclass(attr, BaseEngineSpec)
and attr != BaseEngineSpec
inspect.isclass(obj)
and issubclass(obj, BaseEngineSpec)
and obj != BaseEngineSpec
)


def load_engine_specs() -> List[Type[BaseEngineSpec]]:
"""
Load all engine specs, native and 3rd party.
"""
engine_specs: List[Type[BaseEngineSpec]] = []

# load standard engines
Expand All @@ -78,20 +85,18 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
return engine_specs


def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
"""
Return the DB engine spec associated with a given SQLAlchemy URL.
"""
engine_specs = load_engine_specs()

# build map from name/alias -> spec
engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
for engine_spec in engine_specs:
names = [engine_spec.engine]
if engine_spec.engine_aliases:
names.extend(engine_spec.engine_aliases)

for name in names:
engine_specs_map[name] = engine_spec
if engine_spec.supports_backend(backend, driver):
return engine_spec

return engine_specs_map
# default to the generic DB engine spec
return BaseEngineSpec


# there's a mismatch between the dialect name reported by the driver in these
Expand Down
58 changes: 57 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,39 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
having to add the same aggregation in SELECT.
"""

engine_name: Optional[str] = None # for user messages, overridden in child classes

# Associate the DB engine spec to one or more SQLAlchemy dialects/drivers. For
# example, if a given DB engine spec has:
#
# class PostgresDBEngineSpec:
# engine = 'postgresql'
# engine_aliases = 'postgres'
# drivers = {'psycopg2', 'asyncpg'}
#
# It would be used for all the following SQLALchemy URIs:
#
# - postgres://user:password@host/db
# - postgresql://user:password@host/db
# - postgres+asyncpg://user:password@host/db
# - postgres+psycopg2://user:password@host/db
# - postgresql+asyncpg://user:password@host/db
# - postgresql+psycopg2://user:password@host/db
#
# Note that SQLAlchemy has a default driver when one is not specified:
#
# >>> from sqlalchemy.engine.url import make_url
# >>> make_url('postgres://').get_driver_name()
# 'psycopg2'
#
# The ``default_driver`` should point to the recomended driver, and is used by
# database creation modals where the user provides parameters to connect to the
# database, instead of providing the SQLAlchemy URI.
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set()
engine_name: Optional[str] = None # for user messages, overridden in child classes
drivers: Dict[str, str] = {}
default_driver: Optional[str] = None

_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
Expand Down Expand Up @@ -355,6 +385,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
] = {}

@classmethod
def supports_url(cls, url: URL) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy URL.
"""
backend = url.get_backend_name()
driver = url.get_driver_name()
return cls.supports_backend(backend, driver)

@classmethod
def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
"""
# check the backend first
if backend != cls.engine and backend not in cls.engine_aliases:
return False

# originally DB engine specs didn't declare any drivers and the check was made
# only on the engine; if that's the case, ignore the driver for backwards
# compatibility
if not cls.drivers or driver is None:
return True

return driver in cls.drivers

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
Expand Down
19 changes: 13 additions & 6 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,23 @@


class DatabricksHiveEngineSpec(HiveEngineSpec):
engine = "databricks"
engine_name = "Databricks Interactive Cluster"
driver = "pyhive"

engine = "databricks"
drivers = {"pyhive": "Hive driver for Interactive Cluster"}
default_driver = "pyhive"

_show_functions_column = "function"

_time_grain_expressions = time_grain_expressions


class DatabricksODBCEngineSpec(BaseEngineSpec):
engine = "databricks"
engine_name = "Databricks SQL Endpoint"
driver = "pyodbc"

engine = "databricks"
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "pyodbc"

_time_grain_expressions = time_grain_expressions

Expand All @@ -74,9 +79,11 @@ def epoch_to_dttm(cls) -> str:


class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
engine = "databricks"
engine_name = "Databricks Native Connector"
driver = "connector"

engine = "databricks"
drivers = {"connector": "Native all-purpose driver"}
default_driver = "connector"

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
Expand Down
6 changes: 5 additions & 1 deletion superset/db_engine_specs/shillelagh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
class ShillelaghEngineSpec(SqliteEngineSpec):
"""Engine for shillelagh"""

engine = "shillelagh"
engine_name = "Shillelagh"
engine = "shillelagh"
drivers = {"apsw": "SQLite driver"}
default_driver = "apsw"
sqlalchemy_uri_placeholder = "shillelagh://"

allows_joins = True
allows_subqueries = True
13 changes: 7 additions & 6 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,15 +635,16 @@ def get_all_schema_names( # pylint: disable=unused-argument

@property
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
return self.get_db_engine_spec_for_backend(self.backend)
url = make_url_safe(self.sqlalchemy_uri_decrypted)
return self.get_db_engine_spec(url)

@classmethod
@memoized
def get_db_engine_spec_for_backend(
cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
engines = db_engine_specs.get_engine_specs()
return engines.get(backend, db_engine_specs.BaseEngineSpec)
def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
backend = url.get_backend_name()
driver = url.get_driver_name()

return db_engine_specs.get_engine_spec(backend, driver)

def grains(self) -> Tuple[TimeGrain, ...]:
"""Defines time granularity database-specific expressions.
Expand Down
Loading

0 comments on commit b2e8c66

Please sign in to comment.