diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index d6a7d3e3a2645..aebfad0781bd5 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -335,6 +335,7 @@ def get_timestamp_expression( :param time_grain: Optional time grain, e.g. P1Y :param label: alias/label that column is expected to have + :param template_processor: template processor :return: A TimeExpression object wrapped in a Label if supported by db """ label = label or utils.DTTM_ALIAS @@ -488,6 +489,27 @@ def data(self) -> Dict[str, Any]: ) +def _process_sql_expression( + expression: Optional[str], + database_id: int, + schema: str, + template_processor: Optional[BaseTemplateProcessor], +) -> Optional[str]: + if template_processor and expression: + expression = template_processor.process_template(expression) + if expression: + expression = validate_adhoc_subquery( + expression, + database_id, + schema, + ) + try: + expression = sanitize_clause(expression) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex + return expression + + class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods """An ORM object for SqlAlchemy table references""" @@ -875,13 +897,17 @@ def get_rendered_sql( return sql def adhoc_metric_to_sqla( - self, metric: AdhocMetric, columns_by_name: Dict[str, TableColumn] + self, + metric: AdhocMetric, + columns_by_name: Dict[str, TableColumn], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition :param dict columns_by_name: Columns for the current table + :param template_processor: template_processor instance :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -898,17 +924,12 @@ def adhoc_metric_to_sqla( sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) elif expression_type == utils.AdhocMetricExpressionType.SQL: - tp = self.get_template_processor() - expression = tp.process_template(cast(str, metric["sqlExpression"])) - expression = validate_adhoc_subquery( - expression, - self.database_id, - self.schema, + expression = _process_sql_expression( + expression=metric["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, ) - try: - expression = sanitize_clause(expression) - except QueryClauseValidationException as ex: - raise QueryObjectValidationError(ex.message) from ex sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") @@ -929,21 +950,14 @@ def adhoc_column_to_sqla( :rtype: sqlalchemy.sql.column """ label = utils.get_column_name(col) - expression = col["sqlExpression"] - if template_processor and expression: - expression = template_processor.process_template(expression) - if expression: - expression = validate_adhoc_subquery( - expression, - self.database_id, - self.schema, - ) - try: - expression = sanitize_clause(expression) - except QueryClauseValidationException as ex: - raise QueryObjectValidationError(ex.message) from ex - sqla_metric = literal_column(expression) - return self.make_sqla_column_compatible(sqla_metric, label) + expression = _process_sql_expression( + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + sqla_column = literal_column(expression) + return self.make_sqla_column_compatible(sqla_column, label) def make_sqla_column_compatible( self, sqla_col: ColumnElement, label: Optional[str] = None @@ -1127,7 +1141,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma for metric in metrics: if utils.is_adhoc_metric(metric): assert isinstance(metric, dict) - metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name)) + metrics_exprs.append( + self.adhoc_metric_to_sqla( + metric=metric, + columns_by_name=columns_by_name, + template_processor=template_processor, + ) + ) elif isinstance(metric, str) and metric in metrics_by_name: metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) else: @@ -1154,10 +1174,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if isinstance(col, dict): col = cast(AdhocMetric, col) if col.get("sqlExpression"): - col["sqlExpression"] = validate_adhoc_subquery( - cast(str, col["sqlExpression"]), - self.database_id, - self.schema, + col["sqlExpression"] = _process_sql_expression( + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists