From 811bb5d83b24e90d2ddd905e618679e7e59624b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Sat, 15 Jan 2022 12:02:47 +0700 Subject: [PATCH] feat: Trino Authentications (#17593) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support Trino Authentications Signed-off-by: Đặng Minh Dũng * docs: Trino Authentications Signed-off-by: Đặng Minh Dũng --- .../docs/Connecting to Databases/trino.mdx | 91 ++++++++++++++++-- requirements/base.txt | 2 + requirements/testing.in | 1 + requirements/testing.txt | 10 +- superset/config.py | 24 ++++- superset/db_engine_specs/base.py | 20 ++++ superset/db_engine_specs/trino.py | 43 +++++++++ superset/models/core.py | 25 +++-- .../db_engine_specs/trino_tests.py | 93 +++++++++++++++++++ 9 files changed, 286 insertions(+), 23 deletions(-) diff --git a/docs/src/pages/docs/Connecting to Databases/trino.mdx b/docs/src/pages/docs/Connecting to Databases/trino.mdx index 418956c2bacd3..6f0e132b8712d 100644 --- a/docs/src/pages/docs/Connecting to Databases/trino.mdx +++ b/docs/src/pages/docs/Connecting to Databases/trino.mdx @@ -8,20 +8,95 @@ version: 1 ## Trino -Supported trino version 352 and higher - -The [sqlalchemy-trino](https://pypi.org/project/sqlalchemy-trino/) library is the recommended way to connect to Trino through SQLAlchemy. - -The expected connection string is formatted as follows: +Superset supports Trino >=352 via SQLAlchemy by using the [sqlalchemy-trino](https://pypi.org/project/sqlalchemy-trino/) library. 
+### Connection String +The connection string format is as follows: ``` trino://{username}:{password}@{hostname}:{port}/{catalog} ``` -If you are running trino with docker on local machine please use the following connection URL +If you are running Trino with Docker on your local machine, please use the following connection URL: ``` trino://trino@host.docker.internal:8080 ``` -Reference: -[Trino-Superset-Podcast](https://trino.io/episodes/12.html) +### Authentications +#### 1. Basic Authentication +You can provide `username`/`password` in the connection string or in the `Secure Extra` field under `Advanced / Security`: +* In Connection String + ``` + trino://{username}:{password}@{hostname}:{port}/{catalog} + ``` + +* In `Secure Extra` field + ```json + { + "auth_method": "basic", + "auth_params": { + "username": "", + "password": "" + } + } + ``` + +NOTE: if both are provided, `Secure Extra` always takes higher priority. + +#### 2. Kerberos Authentication +In the `Secure Extra` field, configure authentication as in the following example: +```json +{ + "auth_method": "kerberos", + "auth_params": { + "service_name": "superset", + "config": "/path/to/krb5.config", + ... + } +} +``` + +All fields in `auth_params` are passed directly to the [`KerberosAuthentication`](https://github.com/trinodb/trino-python-client/blob/0.306.0/trino/auth.py#L40) class. + +#### 3. JWT Authentication +Set `auth_method` to `jwt` and provide the token in the `Secure Extra` field: +```json +{ + "auth_method": "jwt", + "auth_params": { + "token": "" + } +} +``` + +#### 4. 
Custom Authentication +To use custom authentication, first add it to the +`ALLOWED_EXTRA_AUTHENTICATIONS` allow list in the Superset config file: +```python +from your.module import AuthClass +from another.extra import auth_method + +ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = { + "trino": { + "custom_auth": AuthClass, + "another_auth_method": auth_method, + }, +} +``` + +Then in the `Secure Extra` field: +```json +{ + "auth_method": "custom_auth", + "auth_params": { + ... + } +} +``` + +Each allow-list entry can be your `trino.auth.Authentication` class +or a factory function (which returns an `Authentication` instance); `auth_method` must be the key under which it is registered. + +All fields in `auth_params` are passed directly to your class/function. + +**Reference**: +* [Trino-Superset-Podcast](https://trino.io/episodes/12.html) diff --git a/requirements/base.txt b/requirements/base.txt index e5804425675eb..7fede6960464a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -283,6 +283,8 @@ werkzeug==1.0.1 # via # flask # flask-jwt-extended +wrapt==1.12.1 + # via -r requirements/base.in wtforms==2.3.3 # via # flask-appbuilder diff --git a/requirements/testing.in b/requirements/testing.in index 575016ff23925..c33f245280bb0 100644 --- a/requirements/testing.in +++ b/requirements/testing.in @@ -38,3 +38,4 @@ statsd pytest-mock # DB dependencies -e file:.[bigquery] +-e file:.[trino] diff --git a/requirements/testing.txt b/requirements/testing.txt index 806b186ea2bd3..8eb0d4f3ec3a4 100644 --- a/requirements/testing.txt +++ b/requirements/testing.txt @@ -1,4 +1,4 @@ -# SHA1:9658361c2ab00a6b27c5875b7b3557c2999854ba +# SHA1:7a8e256097b4758bdeda2529d3d4d31e421e1a3c # # This file is autogenerated by pip-compile-multi # To update, run: @@ -11,8 +11,6 @@ # via # -r requirements/base.in # -r requirements/testing.in -appnope==0.1.2 - # via ipython astroid==2.6.6 # via pylint backcall==0.2.0 @@ -166,20 +164,22 @@ requests-oauthlib==1.3.0 # 
via google-auth-oauthlib rsa==4.7.2 # via google-auth +sqlalchemy-trino==0.4.1 + # via apache-superset statsd==3.3.0 # via -r requirements/testing.in traitlets==5.0.5 # via # ipython # matplotlib-inline +trino==0.306 + # via sqlalchemy-trino typing-inspect==0.7.1 # via libcst wcwidth==0.2.5 # via prompt-toolkit websocket-client==1.2.0 # via docker -wrapt==1.12.1 - # via astroid # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/superset/config.py b/superset/config.py index b5d39abb93ee4..7db2978514b16 100644 --- a/superset/config.py +++ b/superset/config.py @@ -723,6 +723,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # Force refresh while auto-refresh in dashboard DASHBOARD_AUTO_REFRESH_MODE: Literal["fetch", "force"] = "force" + # Default celery config is to use SQLA as a broker, in a production setting # you'll want to use a proper broker as specified here: # http://docs.celeryproject.org/en/latest/getting-started/brokers/index.html @@ -872,6 +873,8 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # The directory within the bucket specified above that will # contain all the external tables CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/" + + # Function that creates upload directory dynamically based on the # database used, user and schema provided. def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name @@ -986,6 +989,19 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds()) +# Allow list of custom authentications for each DB engine. 
+# Example: +# from your.module import AuthClass +# from another.extra import auth_method +# +# ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = { +# "trino": { +# "custom_auth": AuthClass, +# "another_auth_method": auth_method, +# }, +# } +ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {} + # Allow for javascript controls components # this enables programmers to customize certain charts (like the # geospatial ones) by inputing javascript in controls. This exposes @@ -1012,6 +1028,7 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # as such `create_engine(url, **params)` DB_CONNECTION_MUTATOR = None + # A function that intercepts the SQL to be executed and can alter it. # The use case is can be around adding some sort of comment header # with information such as the username and worker node information @@ -1323,8 +1340,8 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument if CONFIG_PATH_ENV_VAR in os.environ: # Explicitly import config module that is not necessarily in pythonpath; useful # for case where app is being executed via pex. 
+ cfg_path = os.environ[CONFIG_PATH_ENV_VAR] try: - cfg_path = os.environ[CONFIG_PATH_ENV_VAR] module = sys.modules[__name__] override_conf = imp.load_source("superset_config", cfg_path) for key in dir(override_conf): @@ -1339,8 +1356,9 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument raise elif importlib.util.find_spec("superset_config") and not is_test(): try: - import superset_config # pylint: disable=import-error - from superset_config import * # type: ignore # pylint: disable=import-error,wildcard-import,unused-wildcard-import + # pylint: disable=import-error,wildcard-import,unused-wildcard-import + import superset_config + from superset_config import * # type:ignore print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]") except Exception: diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index f579fc2502768..bdd1922d2ca56 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1273,6 +1273,26 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: raise ex return extra + @staticmethod + def update_encrypted_extra_params( + database: "Database", params: Dict[str, Any] + ) -> None: + """ + Some databases require some sensitive information which do not conform to + the username:password syntax normally used by SQLAlchemy. 
+ + :param database: database instance from which to extract extras + :param params: params to be updated + """ + if not database.encrypted_extra: + return + try: + encrypted_extra = json.loads(database.encrypted_extra) + params.update(encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise ex + @classmethod def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: """Pessimistic readonly, 100% sure statement won't mutate anything""" diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 8d294bfa89e4d..778845d132e95 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -14,11 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging from datetime import datetime from typing import Any, Dict, List, Optional, TYPE_CHECKING from urllib import parse import simplejson as json +from flask import current_app from sqlalchemy.engine.url import make_url, URL from superset.db_engine_specs.base import BaseEngineSpec @@ -27,6 +29,8 @@ if TYPE_CHECKING: from superset.models.core import Database +logger = logging.getLogger(__name__) + class TrinoEngineSpec(BaseEngineSpec): engine = "trino" @@ -202,3 +206,42 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert) return extra + + @staticmethod + def update_encrypted_extra_params( + database: "Database", params: Dict[str, Any] + ) -> None: + if not database.encrypted_extra: + return + try: + encrypted_extra = json.loads(database.encrypted_extra) + auth_method = encrypted_extra.pop("auth_method", None) + auth_params = encrypted_extra.pop("auth_params", {}) + if not auth_method: + return + + connect_args = params.setdefault("connect_args", {}) + connect_args["http_scheme"] = "https" + # pylint: disable=import-outside-toplevel + if 
auth_method == "basic": + from trino.auth import BasicAuthentication as trino_auth # noqa + elif auth_method == "kerberos": + from trino.auth import KerberosAuthentication as trino_auth # noqa + elif auth_method == "jwt": + from trino.auth import JWTAuthentication as trino_auth # noqa + else: + allowed_extra_auths = current_app.config[ + "ALLOWED_EXTRA_AUTHENTICATIONS" + ].get("trino", {}) + if auth_method in allowed_extra_auths: + trino_auth = allowed_extra_auths.get(auth_method) + else: + raise ValueError( + f"For security reason, custom authentication '{auth_method}' " + f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" + ) + + connect_args["auth"] = trino_auth(**auth_params) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise ex diff --git a/superset/models/core.py b/superset/models/core.py index d9b7909261f50..4b7adad2d2b29 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -91,7 +91,6 @@ class KeyValue(Model): # pylint: disable=too-few-public-methods class CssTemplate(Model, AuditMixinNullable): - """CSS templates for dashboards""" __tablename__ = "css_templates" @@ -244,7 +243,10 @@ def parameters(self) -> Dict[str, Any]: uri = make_url(self.sqlalchemy_uri_decrypted) encrypted_extra = self.get_encrypted_extra() try: - parameters = self.db_engine_spec.get_parameters_from_uri(uri, encrypted_extra=encrypted_extra) # type: ignore # pylint: disable=line-too-long,useless-suppression + # pylint: disable=useless-suppression + parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore + uri, encrypted_extra=encrypted_extra + ) except Exception: # pylint: disable=broad-except parameters = {} @@ -330,7 +332,14 @@ def get_effective_user( effective_username = g.user.username return effective_username - @memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra")) + @memoized( + watch=( + "impersonate_user", + "sqlalchemy_uri_decrypted", + "extra", + "encrypted_extra", + ) + ) def 
get_sqla_engine( self, schema: Optional[str] = None, @@ -365,7 +374,7 @@ def get_sqla_engine( if connect_args: params["connect_args"] = connect_args - params.update(self.get_encrypted_extra()) + self.update_encrypted_extra_params(params) if DB_CONNECTION_MUTATOR: if not source and request and request.referrer: @@ -443,9 +452,8 @@ def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) - if ( - engine.dialect.identifier_preparer._double_percents # pylint: disable=protected-access - ): + # pylint: disable=protected-access + if engine.dialect.identifier_preparer._double_percents: # noqa sql = sql.replace("%%", "%") return sql @@ -639,6 +647,9 @@ def get_encrypted_extra(self) -> Dict[str, Any]: raise ex return encrypted_extra + def update_encrypted_extra_params(self, params: Dict[str, Any]) -> None: + self.db_engine_spec.update_encrypted_extra_params(self, params) + def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py index e77e91603540f..973eb17159ae7 100644 --- a/tests/integration_tests/db_engine_specs/trino_tests.py +++ b/tests/integration_tests/db_engine_specs/trino_tests.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. 
import json +from typing import Any, Dict from unittest.mock import Mock, patch +import pytest from sqlalchemy.engine.url import URL +import superset.config from superset.db_engine_specs.trino import TrinoEngineSpec from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -87,3 +90,93 @@ def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock self.assertEqual(connect_args.get("http_scheme"), "https") self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt") create_ssl_cert_file_func.assert_called_once_with(database.server_cert) + + @patch("trino.auth.BasicAuthentication") + def test_auth_basic(self, auth: Mock): + database = Mock() + + auth_params = {"username": "username", "password": "password"} + database.encrypted_extra = json.dumps( + {"auth_method": "basic", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + TrinoEngineSpec.update_encrypted_extra_params(database, params) + connect_args = params.setdefault("connect_args", {}) + self.assertEqual(connect_args.get("http_scheme"), "https") + auth.assert_called_once_with(**auth_params) + + @patch("trino.auth.KerberosAuthentication") + def test_auth_kerberos(self, auth: Mock): + database = Mock() + + auth_params = { + "service_name": "superset", + "mutual_authentication": False, + "delegate": True, + } + database.encrypted_extra = json.dumps( + {"auth_method": "kerberos", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + TrinoEngineSpec.update_encrypted_extra_params(database, params) + connect_args = params.setdefault("connect_args", {}) + self.assertEqual(connect_args.get("http_scheme"), "https") + auth.assert_called_once_with(**auth_params) + + @patch("trino.auth.JWTAuthentication") + def test_auth_jwt(self, auth: Mock): + database = Mock() + + auth_params = {"token": "jwt-token-string"} + database.encrypted_extra = json.dumps( + {"auth_method": "jwt", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + 
TrinoEngineSpec.update_encrypted_extra_params(database, params) + connect_args = params.setdefault("connect_args", {}) + self.assertEqual(connect_args.get("http_scheme"), "https") + auth.assert_called_once_with(**auth_params) + + def test_auth_custom_auth(self): + database = Mock() + auth_class = Mock() + + auth_method = "custom_auth" + auth_params = {"params1": "params1", "params2": "params2"} + database.encrypted_extra = json.dumps( + {"auth_method": auth_method, "auth_params": auth_params} + ) + + with patch.dict( + "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS", + {"trino": {"custom_auth": auth_class}}, + clear=True, + ): + params: Dict[str, Any] = {} + TrinoEngineSpec.update_encrypted_extra_params(database, params) + + connect_args = params.setdefault("connect_args", {}) + self.assertEqual(connect_args.get("http_scheme"), "https") + + auth_class.assert_called_once_with(**auth_params) + + def test_auth_custom_auth_denied(self): + database = Mock() + auth_method = "my.module:TrinoAuthClass" + auth_params = {"params1": "params1", "params2": "params2"} + database.encrypted_extra = json.dumps( + {"auth_method": auth_method, "auth_params": auth_params} + ) + + superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {} + + with pytest.raises(ValueError) as excinfo: + TrinoEngineSpec.update_encrypted_extra_params(database, {}) + + assert str(excinfo.value) == ( + f"For security reason, custom authentication '{auth_method}' " + f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" + )