diff --git a/superset/databases/api.py b/superset/databases/api.py index edf63392ee6f6..6d0e335cfec08 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -120,6 +120,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "allow_cvas", "allow_dml", "backend", + "driver", "force_ctas_schema", "impersonate_user", "masked_encrypted_extra", @@ -269,6 +270,9 @@ def post(self) -> Response: if new_model.parameters: item["parameters"] = new_model.parameters + if new_model.driver: + item["driver"] = new_model.driver + return self.response(201, id=new_model.id, result=item) except DatabaseInvalidError as ex: return self.response_422(message=ex.normalized_messages()) diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index ddb45c4465a89..a8956257fa28a 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -38,8 +38,8 @@ class ValidateDatabaseParametersCommand(BaseCommand): - def __init__(self, parameters: Dict[str, Any]): - self._properties = parameters.copy() + def __init__(self, properties: Dict[str, Any]): + self._properties = properties.copy() self._model: Optional[Database] = None def run(self) -> None: @@ -66,9 +66,7 @@ def run(self) -> None: ) # perform initial validation - errors = engine_spec.validate_parameters( # type: ignore - self._properties.get("parameters", {}) - ) + errors = engine_spec.validate_parameters(self._properties) # type: ignore if errors: event_logger.log_with_context(action="validation_error", engine=engine) raise InvalidParametersError(errors) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index d80625dc0c385..dabed0c7aeeec 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1685,6 +1685,10 @@ class BasicParametersType(TypedDict, total=False): encryption: bool +class BasicPropertiesType(TypedDict): + parameters: BasicParametersType + + class BasicParametersMixin: """ Mixin for configuring DB engine specs via a dictionary. @@ -1762,7 +1766,7 @@ def get_parameters_from_uri( # pylint: disable=unused-argument @classmethod def validate_parameters( - cls, parameters: BasicParametersType + cls, properties: BasicPropertiesType ) -> List[SupersetError]: """ Validates any number of parameters, for progressive validation. @@ -1773,6 +1777,7 @@ def validate_parameters( errors: List[SupersetError] = [] required = {"host", "port", "username", "database"} + parameters = properties.get("parameters", {}) present = {key for key in parameters if parameters.get(key, ())} missing = sorted(required - present) diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 38b1f92023a10..8f59af945d28c 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -34,7 +34,7 @@ from superset.constants import PASSWORD_MASK from superset.databases.schemas import encrypted_field_properties, EncryptedString from superset.databases.utils import make_url_safe -from superset.db_engine_specs.base import BaseEngineSpec +from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.exceptions import SupersetDBAPIDisconnectionError from superset.errors import SupersetError, SupersetErrorType from superset.sql_parse import Table @@ -450,7 +450,8 @@ def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: @classmethod def validate_parameters( - cls, parameters: BigQueryParametersType # pylint: disable=unused-argument + cls, + properties: BasicPropertiesType, # pylint: disable=unused-argument ) -> List[SupersetError]: return [] diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 3562e0d0b1384..78b42d2b3a999 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -58,6 +58,10 @@ class GSheetsParametersType(TypedDict): catalog: Dict[str, str] +class GSheetsPropertiesType(TypedDict): + parameters: GSheetsParametersType + + class GSheetsEngineSpec(SqliteEngineSpec): """Engine for Google spreadsheets""" @@ -208,9 +212,10 @@ def parameters_json_schema(cls) -> Any: @classmethod def validate_parameters( cls, - parameters: GSheetsParametersType, + properties: GSheetsPropertiesType, ) -> List[SupersetError]: errors: List[SupersetError] = [] + parameters = properties.get("parameters", {}) encrypted_credentials = parameters.get("service_account_info") or "{}" # On create the encrypted credentials are a string, diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 6ead37ded373f..0704712d6515b 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -32,6 +32,7 @@ from typing_extensions import TypedDict from superset.databases.utils import make_url_safe +from superset.db_engine_specs.base import BasicPropertiesType from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query @@ -242,7 +243,7 @@ def get_parameters_from_uri( @classmethod def validate_parameters( - cls, parameters: SnowflakeParametersType + cls, properties: BasicPropertiesType ) -> List[SupersetError]: errors: List[SupersetError] = [] required = { @@ -253,6 +254,7 @@ def validate_parameters( "role", "password", } + parameters = properties.get("parameters", {}) present = {key for key in parameters if parameters.get(key, ())} missing = sorted(required - present) diff --git a/superset/models/core.py b/superset/models/core.py index ca15137935ef9..008230ef4874f 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -244,8 +244,11 @@ def url_object(self) -> URL: @property def backend(self) -> str: - sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted) - return sqlalchemy_url.get_backend_name() + return self.url_object.get_backend_name() + + @property + def driver(self) -> str: + return self.url_object.get_driver_name() @property def masked_encrypted_extra(self) -> Optional[str]: @@ -253,14 +256,12 @@ def masked_encrypted_extra(self) -> Optional[str]: @property def parameters(self) -> Dict[str, Any]: - db_engine_spec = self.db_engine_spec - + # Database parameters are a dictionary of values that are used to make up + # the sqlalchemy_uri # When returning the parameters we should use the masked SQLAlchemy URI and the # masked ``encrypted_extra`` to prevent exposing sensitive credentials. masked_uri = make_url_safe(self.sqlalchemy_uri) - masked_encrypted_extra = db_engine_spec.mask_encrypted_extra( - self.encrypted_extra - ) + masked_encrypted_extra = self.masked_encrypted_extra encrypted_config = {} if masked_encrypted_extra is not None: try: @@ -270,7 +271,7 @@ def parameters(self) -> Dict[str, Any]: try: # pylint: disable=useless-suppression - parameters = db_engine_spec.get_parameters_from_uri( # type: ignore + parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore masked_uri, encrypted_extra=encrypted_config, ) diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index f998444f31895..c31a501487dda 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -421,28 +421,32 @@ def test_validate(is_port_open, is_hostname_valid): is_hostname_valid.return_value = True is_port_open.return_value = True - parameters = { - "host": "localhost", - "port": 5432, - "username": "username", - "password": "password", - "database": "dbname", - "query": {"sslmode": "verify-full"}, + properties = { + "parameters": { + "host": "localhost", + "port": 5432, + "username": "username", + "password": "password", + "database": "dbname", + "query": {"sslmode": "verify-full"}, + } } - errors = BasicParametersMixin.validate_parameters(parameters) + errors = BasicParametersMixin.validate_parameters(properties) assert errors == [] def test_validate_parameters_missing(): - parameters = { - "host": "", - "port": None, - "username": "", - "password": "", - "database": "", - "query": {}, + properties = { + "parameters": { + "host": "", + "port": None, + "username": "", + "password": "", + "database": "", + "query": {}, + } } - errors = BasicParametersMixin.validate_parameters(parameters) + errors = BasicParametersMixin.validate_parameters(properties) assert errors == [ SupersetError( message=( @@ -459,15 +463,17 @@ def test_validate_parameters_missing(): def test_validate_parameters_invalid_host(is_hostname_valid): is_hostname_valid.return_value = False - parameters = { - "host": "localhost", - "port": None, - "username": "username", - "password": "password", - "database": "dbname", - "query": {"sslmode": "verify-full"}, + properties = { + "parameters": { + "host": "localhost", + "port": None, + "username": "username", + "password": "password", + "database": "dbname", + "query": {"sslmode": "verify-full"}, + } } - errors = BasicParametersMixin.validate_parameters(parameters) + errors = BasicParametersMixin.validate_parameters(properties) assert errors == [ SupersetError( message="One or more parameters are missing: port", @@ -490,15 +496,17 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid): is_hostname_valid.return_value = True is_port_open.return_value = False - parameters = { - "host": "localhost", - "port": 5432, - "username": "username", - "password": "password", - "database": "dbname", - "query": {"sslmode": "verify-full"}, + properties = { + "parameters": { + "host": "localhost", + "port": 5432, + "username": "username", + "password": "password", + "database": "dbname", + "query": {"sslmode": "verify-full"}, + } } - errors = BasicParametersMixin.validate_parameters(parameters) + errors = BasicParametersMixin.validate_parameters(properties) assert errors == [ SupersetError( message="The port is closed.", diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 006c57e01d23d..c0df8bb4d1bb7 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -22,6 +22,7 @@ from uuid import UUID import pytest +from pytest_mock import MockFixture from sqlalchemy.orm.session import Session @@ -53,6 +54,7 @@ def test_post_with_uuid( def test_password_mask( + mocker: MockFixture, app: Any, session: Session, client: Any, @@ -92,6 +94,10 @@ def test_password_mask( session.add(database) session.commit() + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + response = client.get("/api/v1/database/1") assert ( response.json["result"]["parameters"]["service_account_info"]["private_key"] diff --git a/tests/unit_tests/databases/schema_tests.py b/tests/unit_tests/databases/schema_tests.py index 58a1f6389d4c1..a8e96c614416c 100644 --- a/tests/unit_tests/databases/schema_tests.py +++ b/tests/unit_tests/databases/schema_tests.py @@ -134,7 +134,6 @@ def test_database_parameters_schema_mixin_invalid_engine( try: dummy_schema.load(payload) except ValidationError as err: - print(err.messages) assert err.messages == { "_schema": ['Engine "dummy_engine" is not a valid engine.'] } diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index a226653ed57e0..487e1eff695f8 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -33,14 +33,16 @@ class ProgrammingError(Exception): def test_validate_parameters_simple() -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, - GSheetsParametersType, + GSheetsPropertiesType, ) - parameters: GSheetsParametersType = { - "service_account_info": "", - "catalog": {}, + properties: GSheetsPropertiesType = { + "parameters": { + "service_account_info": "", + "catalog": {}, + } } - errors = GSheetsEngineSpec.validate_parameters(parameters) + errors = GSheetsEngineSpec.validate_parameters(properties) assert errors == [ SupersetError( message="Sheet name is required", @@ -56,7 +58,7 @@ def test_validate_parameters_catalog( ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, - GSheetsParametersType, + GSheetsPropertiesType, ) g = mocker.patch("superset.db_engine_specs.gsheets.g") @@ -71,15 +73,17 @@ def test_validate_parameters_catalog( ProgrammingError("Unsupported table: https://www.google.com/"), ] - parameters: GSheetsParametersType = { - "service_account_info": "", - "catalog": { - "private_sheet": "https://docs.google.com/spreadsheets/d/1/edit", - "public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1", - "not_a_sheet": "https://www.google.com/", - }, + properties: GSheetsPropertiesType = { + "parameters": { + "service_account_info": "", + "catalog": { + "private_sheet": "https://docs.google.com/spreadsheets/d/1/edit", + "public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1", + "not_a_sheet": "https://www.google.com/", + }, + } } - errors = GSheetsEngineSpec.validate_parameters(parameters) # ignore: type + errors = GSheetsEngineSpec.validate_parameters(properties) # ignore: type assert errors == [ SupersetError( @@ -146,7 +150,7 @@ def test_validate_parameters_catalog_and_credentials( ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, - GSheetsParametersType, + GSheetsPropertiesType, ) g = mocker.patch("superset.db_engine_specs.gsheets.g") @@ -161,15 +165,17 @@ def test_validate_parameters_catalog_and_credentials( ProgrammingError("Unsupported table: https://www.google.com/"), ] - parameters: GSheetsParametersType = { - "service_account_info": "", - "catalog": { - "private_sheet": "https://docs.google.com/spreadsheets/d/1/edit", - "public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1", - "not_a_sheet": "https://www.google.com/", - }, + properties: GSheetsPropertiesType = { + "parameters": { + "service_account_info": "", + "catalog": { + "private_sheet": "https://docs.google.com/spreadsheets/d/1/edit", + "public_sheet": "https://docs.google.com/spreadsheets/d/1/edit#gid=1", + "not_a_sheet": "https://www.google.com/", + }, + } } - errors = GSheetsEngineSpec.validate_parameters(parameters) # ignore: type + errors = GSheetsEngineSpec.validate_parameters(properties) # ignore: type assert errors == [ SupersetError( message=(