diff --git a/docs/docs/installation/installing-superset-from-scratch.mdx b/docs/docs/installation/installing-superset-from-scratch.mdx index 3a12c9db3ac13..5efdb3e8f1f28 100644 --- a/docs/docs/installation/installing-superset-from-scratch.mdx +++ b/docs/docs/installation/installing-superset-from-scratch.mdx @@ -64,7 +64,7 @@ We don't recommend using the system installed Python. Instead, first install the brew install readline pkg-config libffi openssl mysql postgres ``` -You should install a recent version of Python (the official docker image uses 3.8.12). We'd recommend using a Python version manager like [pyenv](https://github.com/pyenv/pyenv) (and also [pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv)). +You should install a recent version of Python (the official docker image uses 3.8.13). We'd recommend using a Python version manager like [pyenv](https://github.com/pyenv/pyenv) (and also [pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv)). Let's also make sure we have the latest version of `pip` and `setuptools`: diff --git a/setup.py b/setup.py index 5f1b7b7fd2692..ba16a2e58f67a 100644 --- a/setup.py +++ b/setup.py @@ -167,9 +167,7 @@ def get_git_sha() -> str: "shillelagh": [ "shillelagh[datasetteapi,gsheetsapi,socrata,weatherapi]>=1.0.3, <2" ], - "snowflake": [ - "snowflake-sqlalchemy==1.2.4" - ], # PINNED! 1.2.5 introduced breaking changes requiring sqlalchemy>=1.4.0 + "snowflake": ["snowflake-sqlalchemy>=1.2.4, <2"], "spark": ["pyhive[hive]>=0.6.5", "tableschema", "thrift>=0.11.0, <1.0.0"], "teradata": ["teradatasql>=16.20.0.23"], "thumbnails": ["Pillow>=9.1.1, <10.0.0"], diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx index 5d30ee37e0a71..d5e8fcdee90af 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx @@ -328,6 +328,35 @@ const config: ControlPanelConfig = { row_limit: { default: rowLimit, }, + limit: { + rerender: ['timeseries_limit_metric', 'order_desc'], + }, + timeseries_limit_metric: { + label: t('Series Limit Sort By'), + description: t( + 'Metric used to order the limit if a series limit is present. ' + + 'If undefined reverts to the first metric (where appropriate).', + ), + visibility: ({ controls }) => Boolean(controls?.limit.value), + mapStateToProps: (state, controlState) => { + const timeserieslimitProps = + sharedControls.timeseries_limit_metric.mapStateToProps?.( + state, + controlState, + ) || {}; + timeserieslimitProps.value = state.controls?.limit?.value + ? 
controlState.value
+          : [];
+        return timeserieslimitProps;
+      },
+    },
+    order_desc: {
+      label: t('Series Limit Sort Descending'),
+      default: false,
+      description: t(
+        'Whether to sort descending or ascending if a series limit is present',
+      ),
+    },
   },
   formDataOverrides: formData => ({
     ...formData,
diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts
index b6b23fc70141c..07f18c3d5e62f 100644
--- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts
+++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts
@@ -41,6 +41,7 @@ import {
   EchartsTimeseriesSeriesType,
   TimeseriesChartTransformedProps,
   OrientationType,
+  AxisType,
 } from './types';
 import { DEFAULT_FORM_DATA } from './constants';
 import { ForecastSeriesEnum, ForecastValue } from '../types';
@@ -337,13 +338,23 @@ export default function transformProps(
       rotate: xAxisLabelRotation,
     },
     minInterval:
-      xAxisType === 'time' && timeGrainSqla
+      xAxisType === AxisType.time && timeGrainSqla
         ? TIMEGRAIN_TO_TIMESTAMP[timeGrainSqla]
         : 0,
   };
+
+  if (xAxisType === AxisType.time) {
+    /**
+     * Overriding default behavior (false) for time axis regardless of the granularity.
+     * Not including this in the initial declaration above, so if echarts changes the
+     * default behavior for other axis types we won't unintentionally override it.
+     */
+    xAxis.axisLabel.showMaxLabel = null;
+  }
+
   let yAxis: any = {
     ...defaultYAxis,
-    type: logAxis ? 'log' : 'value',
+    type: logAxis ? AxisType.log : AxisType.value,
     min,
     max,
     minorTick: { show: true },
diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts
index 946d41ec164d8..71729bd0f11d0 100644
--- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts
+++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts
@@ -94,3 +94,10 @@ export interface EchartsTimeseriesChartProps
 export type TimeseriesChartTransformedProps = EChartTransformedProps;
+
+export enum AxisType {
+  category = 'category',
+  value = 'value',
+  time = 'time',
+  log = 'log',
+}
diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/constants.ts b/superset-frontend/plugins/plugin-chart-echarts/src/constants.ts
index 7dd823f644792..ec956a9591764 100644
--- a/superset-frontend/plugins/plugin-chart-echarts/src/constants.ts
+++ b/superset-frontend/plugins/plugin-chart-echarts/src/constants.ts
@@ -31,7 +31,7 @@ import {
 export const NULL_STRING = '';
 
 export const TIMESERIES_CONSTANTS = {
-  gridOffsetRight: 40,
+  gridOffsetRight: 20,
   gridOffsetLeft: 20,
   gridOffsetTop: 20,
   gridOffsetBottom: 20,
diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
index ebb8e4d3f030c..9e1df28b4b28f 100644
--- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
+++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx
@@ -78,9 +78,12 @@ interface ResultSetState {
   alertIsOpen: boolean;
 }
 
-const Styles = styled.div`
+const ResultlessStyles = styled.div`
   position: relative;
   minheight: 100px;
+  [role='alert'] {
+    margin-top: ${({ theme }) => theme.gridUnit * 2}px;
+  }
   .sql-result-track-job {
     margin-top: ${({ theme }) => theme.gridUnit * 2}px;
   }
@@ -113,10 +116,6 @@ const ResultSetButtons = styled.div`
   padding-right: ${({ theme }) => 2 * theme.gridUnit}px;
 `;
 
-const
ResultSetErrorMessage = styled.div` - padding-top: ${({ theme }) => 4 * theme.gridUnit}px; -`; - export default class ResultSet extends React.PureComponent< ResultSetProps, ResultSetState @@ -445,7 +444,7 @@ export default class ResultSet extends React.PureComponent< } if (query.state === 'failed') { return ( - + {trackingUrl} - + ); } if (query.state === 'success' && query.ctas) { @@ -586,7 +585,7 @@ export default class ResultSet extends React.PureComponent< : null; return ( - +
{!progressBar && }
{/* show loading bar whenever progress bar is completed but needs time to render */}
{query.progress === 100 && }
@@ -596,7 +595,7 @@ export default class ResultSet extends React.PureComponent<
{query.progress !== 100 && progressBar}
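{/* tracking link, shown only when the query exposes a trackingUrl */}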
{trackingUrl &&
{trackingUrl}
} -
+ ); } } diff --git a/superset-frontend/src/components/Datasource/CollectionTable.tsx b/superset-frontend/src/components/Datasource/CollectionTable.tsx index 194d3765792c9..e0eb44e8453d2 100644 --- a/superset-frontend/src/components/Datasource/CollectionTable.tsx +++ b/superset-frontend/src/components/Datasource/CollectionTable.tsx @@ -33,6 +33,14 @@ interface CRUDCollectionProps { expandFieldset?: ReactNode; extraButtons?: ReactNode; itemGenerator?: () => any; + itemCellProps?: (( + val: unknown, + label: string, + record: any, + ) => React.DetailedHTMLProps< + React.TdHTMLAttributes, + HTMLTableCellElement + >)[]; itemRenderers?: (( val: unknown, onChange: () => void, @@ -335,6 +343,12 @@ export default class CRUDCollection extends React.PureComponent< ); } + getCellProps(record: any, col: any) { + const cellPropsFn = this.props.itemCellProps?.[col]; + const val = record[col]; + return cellPropsFn ? cellPropsFn(val, this.getLabel(col), record) : {}; + } + renderCell(record: any, col: any) { const renderer = this.props.itemRenderers && this.props.itemRenderers[col]; const val = record[col]; @@ -366,7 +380,9 @@ export default class CRUDCollection extends React.PureComponent< } tds = tds.concat( tableColumns.map(col => ( - {this.renderCell(record, col)} + + {this.renderCell(record, col)} + )), ); if (allowAddItem) { diff --git a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx index a64abaf0dc80a..bf524c22101f5 100644 --- a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx +++ b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx @@ -235,6 +235,7 @@ function ColumnCollectionTable({ } /> @@ -848,7 +849,11 @@ class DatasourceEditor extends React.PureComponent { fieldKey="description" label={t('Description')} control={ - + } /> } /> @@ -901,6 +907,7 @@ class DatasourceEditor extends React.PureComponent { controlId="extra" language="json" offerEditInModal={false} + resize="vertical" /> } /> @@ -1081,6 +1088,7 @@ class DatasourceEditor extends React.PureComponent { minLines={20} maxLines={20} readOnly={!this.state.isEditMode} + resize="both" /> } /> @@ -1233,6 +1241,7 @@ class DatasourceEditor extends React.PureComponent { controlId="warning_markdown" language="markdown" offerEditInModal={false} + resize="vertical" /> } /> @@ -1247,6 +1256,11 @@ class DatasourceEditor extends React.PureComponent { verbose_name: '', expression: '', })} + itemCellProps={{ + expression: () => ({ + width: '240px', + }), + }} itemRenderers={{ metric_name: (v, onChange, _, record) => ( @@ -1276,6 +1290,8 @@ class DatasourceEditor extends React.PureComponent { language="sql" offerEditInModal={false} minLines={5} + textAreaStyles={{ minWidth: '200px', maxWidth: '450px' }} + resize="both" /> ), description: (v, onChange, label) => ( diff --git a/superset-frontend/src/explore/components/controls/AnnotationLayerControl/index.jsx b/superset-frontend/src/explore/components/controls/AnnotationLayerControl/index.jsx index f1381abee1aa4..db3bbca272319 100644 --- a/superset-frontend/src/explore/components/controls/AnnotationLayerControl/index.jsx +++ b/superset-frontend/src/explore/components/controls/AnnotationLayerControl/index.jsx @@ -26,7 +26,9 @@ import AsyncEsmComponent from 'src/components/AsyncEsmComponent'; import { getChartKey } from 'src/explore/exploreUtils'; import { runAnnotationQuery } from 'src/components/Chart/chartAction'; import CustomListItem from 'src/explore/components/controls/CustomListItem'; 
-import ControlPopover from '../ControlPopover/ControlPopover';
+import ControlPopover, {
+  getSectionContainerElement,
+} from '../ControlPopover/ControlPopover';
 
 const AnnotationLayer = AsyncEsmComponent(
   () => import('./AnnotationLayer'),
@@ -114,6 +116,11 @@ class AnnotationLayerControl extends React.PureComponent {
   removeAnnotationLayer(annotation) {
     const annotations = this.props.value.filter(anno => anno !== annotation);
+    // So the scrollbar doesn't get stuck on hidden
+    const element = getSectionContainerElement();
+    if (element) {
+      element.style.setProperty('overflow-y', 'auto', 'important');
+    }
     this.props.onChange(annotations);
   }
 
diff --git a/superset-frontend/src/explore/components/controls/ControlPopover/ControlPopover.tsx b/superset-frontend/src/explore/components/controls/ControlPopover/ControlPopover.tsx
index f84194c43caaa..28dd6e2bd2a23 100644
--- a/superset-frontend/src/explore/components/controls/ControlPopover/ControlPopover.tsx
+++ b/superset-frontend/src/explore/components/controls/ControlPopover/ControlPopover.tsx
@@ -24,7 +24,7 @@ import Popover, {
 } from 'src/components/Popover';
 
 const sectionContainerId = 'controlSections';
-const getSectionContainerElement = () =>
+export const getSectionContainerElement = () =>
   document.getElementById(sectionContainerId)?.lastElementChild as HTMLElement;
 
 const getElementYVisibilityRatioOnContainer = (node: HTMLElement) => {
diff --git a/superset-frontend/src/explore/components/controls/TextAreaControl.jsx b/superset-frontend/src/explore/components/controls/TextAreaControl.jsx
index e371061fbe556..48582c4bc757c 100644
--- a/superset-frontend/src/explore/components/controls/TextAreaControl.jsx
+++ b/superset-frontend/src/explore/components/controls/TextAreaControl.jsx
@@ -45,6 +45,16 @@ const propTypes = {
   ]),
   aboveEditorSection: PropTypes.node,
   readOnly: PropTypes.bool,
+  resize: PropTypes.oneOf([
+    null,
+    'block',
+    'both',
+    'horizontal',
+    'inline',
+    'none',
+    'vertical',
+  ]),
+  textAreaStyles: PropTypes.object,
 };
 
 const defaultProps = {
@@ -55,6 +65,8 @@ const defaultProps = {
   maxLines: 10,
   offerEditInModal: true,
   readOnly: false,
+  resize: null,
+  textAreaStyles: {},
 };
 
 class TextAreaControl extends React.Component {
@@ -72,18 +84,23 @@ class TextAreaControl extends React.Component {
     if (this.props.language) {
       const style = {
         border: `1px solid ${this.props.theme.colors.grayscale.light1}`,
+        minHeight: `${minLines}em`,
+        width: 'auto',
+        ...this.props.textAreaStyles,
       };
+      if (this.props.resize) {
+        style.resize = this.props.resize;
+      }
       if (this.props.readOnly) {
         style.backgroundColor = '#f2f2f2';
       }
+
     return (
+      <>
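+        {/* optional content passed in via the aboveEditorSection prop */}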
{this.props.aboveEditorSection}
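{/* the editor itself, rendered by renderEditor; the flag appears to select its modal-sized variant */}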
{this.renderEditor(true)} - + ); } diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index 658dbb535e9ee..c85f5f1b47828 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -63,7 +63,6 @@ def run(self) -> Model: security_manager.add_permission_view_menu( "schema_access", security_manager.get_schema_perm(database, schema) ) - security_manager.add_permission_view_menu("database_access", database.perm) db.session.commit() except DAOCreateFailedError as ex: db.session.rollback() diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index fadc8ba254c20..d0e50bbe2945e 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -44,10 +44,13 @@ def __init__(self, model_id: int, data: Dict[str, Any]): def run(self) -> Model: self.validate() + if not self._model: + raise DatabaseNotFoundError() + old_database_name = self._model.database_name + try: database = DatabaseDAO.update(self._model, self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - security_manager.add_permission_view_menu("database_access", database.perm) # adding a new database we always want to force refresh schema list # TODO Improve this simplistic implementation for catching DB conn fails try: @@ -55,7 +58,24 @@ def run(self) -> Model: except Exception as ex: db.session.rollback() raise DatabaseConnectionFailedError() from ex + # Update database schema permissions + new_schemas: List[str] = [] for schema in schemas: + old_view_menu_name = security_manager.get_schema_perm( + old_database_name, schema + ) + new_view_menu_name = security_manager.get_schema_perm( + database.database_name, schema + ) + schema_pvm = security_manager.find_permission_view_menu( + "schema_access", old_view_menu_name + ) + # Update the schema permission if the database name changed + if schema_pvm and old_database_name != database.database_name: + schema_pvm.view_menu.name = new_view_menu_name + else: + new_schemas.append(schema) + for schema in new_schemas: security_manager.add_permission_view_menu( "schema_access", security_manager.get_schema_perm(database, schema) ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index bd30ad4ec5bd1..aa88822a854df 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -626,6 +626,7 @@ def fix_schemas_allowed_for_csv_upload( cost_estimate_enabled = fields.Boolean() allows_virtual_table_explore = fields.Boolean(required=False) cancel_query_on_windows_unload = fields.Boolean(required=False) + disable_data_preview = fields.Boolean(required=False) class ImportV1DatabaseSchema(Schema): diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 8a44da458f564..38646471d03c3 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -144,7 +144,7 @@ class ImportV1ColumnSchema(Schema): @pre_load def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: """ - Fix for extra initially beeing exported as a string. + Fix for extra initially being exported as a string. """ if isinstance(data.get("extra"), str): data["extra"] = json.loads(data["extra"]) @@ -170,7 +170,7 @@ class ImportV1MetricSchema(Schema): @pre_load def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: """ - Fix for extra initially beeing exported as a string. + Fix for extra initially being exported as a string. 
""" if isinstance(data.get("extra"), str): data["extra"] = json.loads(data["extra"]) @@ -192,7 +192,7 @@ class ImportV1DatasetSchema(Schema): @pre_load def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: """ - Fix for extra initially beeing exported as a string. + Fix for extra initially being exported as a string. """ if isinstance(data.get("extra"), str): data["extra"] = json.loads(data["extra"]) diff --git a/superset/migrations/versions/2022-07-05_15-48_409c7b420ab0_add_created_by_fk_as_owner.py b/superset/migrations/versions/2022-07-05_15-48_409c7b420ab0_add_created_by_fk_as_owner.py index 8992af1b57098..6cdf9f6891cbd 100644 --- a/superset/migrations/versions/2022-07-05_15-48_409c7b420ab0_add_created_by_fk_as_owner.py +++ b/superset/migrations/versions/2022-07-05_15-48_409c7b420ab0_add_created_by_fk_as_owner.py @@ -22,16 +22,17 @@ """ -# revision identifiers, used by Alembic. -revision = "409c7b420ab0" -down_revision = "a39867932713" - from alembic import op from sqlalchemy import and_, Column, insert, Integer from sqlalchemy.ext.declarative import declarative_base +# revision identifiers, used by Alembic. from superset import db +revision = "409c7b420ab0" +down_revision = "a39867932713" + + Base = declarative_base() @@ -95,7 +96,7 @@ def upgrade(): DatasetUser.user_id == Dataset.created_by_fk, ), ) - .filter(DatasetUser.dataset_id == None), + .filter(DatasetUser.dataset_id == None, Dataset.created_by_fk != None), ) ) diff --git a/superset/models/core.py b/superset/models/core.py index 3f5a5cb2f3ec8..617c23ef9edb1 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -795,7 +795,8 @@ def get_dialect(self) -> Dialect: sqla.event.listen(Database, "after_insert", security_manager.set_perm) -sqla.event.listen(Database, "after_update", security_manager.set_perm) +sqla.event.listen(Database, "after_update", security_manager.database_after_update) +sqla.event.listen(Database, "after_delete", security_manager.database_after_delete) class Log(Model): # pylint: disable=too-few-public-methods diff --git a/superset/security/manager.py b/superset/security/manager.py index dfa6ce1dd2123..a66e35e2d845a 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -56,7 +56,7 @@ from flask_appbuilder.widgets import ListWidget from flask_login import AnonymousUserMixin, LoginManager from jwt.api_jwt import _jwt_global_obj -from sqlalchemy import and_, or_ +from sqlalchemy import and_, inspect, or_ from sqlalchemy.engine.base import Connection from sqlalchemy.orm import Session from sqlalchemy.orm.mapper import Mapper @@ -270,6 +270,14 @@ def get_schema_perm( # pylint: disable=no-self-use return None + @staticmethod + def get_database_perm(database_id: int, database_name: str) -> str: + return f"[{database_name}].(id:{database_id})" + + @staticmethod + def get_dataset_perm(dataset_id: int, dataset_name: str, database_name: str) -> str: + return f"[{database_name}].[{dataset_name}](id:{dataset_id})" + def unpack_database_and_schema( # pylint: disable=no-self-use self, schema_permission: str ) -> DatabaseAndSchema: @@ -933,6 +941,222 @@ def _is_granter_pvm( # pylint: disable=no-self-use return pvm.permission.name in {"can_override_role_permissions", "can_approve"} + def database_after_delete( + self, + mapper: Mapper, + connection: Connection, + target: "Database", + ) -> None: + self._delete_vm_database_access( + mapper, connection, target.id, target.database_name + ) + + def _delete_vm_database_access( + self, + mapper: Mapper, + connection: 
Connection,
+        database_id: int,
+        database_name: str,
+    ) -> None:
+        view_menu_table = self.viewmenu_model.__table__  # pylint: disable=no-member
+        permission_view_menu_table = (
+            self.permissionview_model.__table__  # pylint: disable=no-member
+        )
+        view_menu_name = self.get_database_perm(database_id, database_name)
+        # Clean database access permission
+        db_pvm = self.find_permission_view_menu("database_access", view_menu_name)
+        if not db_pvm:
+            logger.warning(
+                "Could not find previous database permission %s",
+                view_menu_name,
+            )
+            return
+        connection.execute(
+            permission_view_menu_table.delete().where(
+                permission_view_menu_table.c.id == db_pvm.id
+            )
+        )
+        self.on_permission_after_delete(mapper, connection, db_pvm)
+        connection.execute(
+            view_menu_table.delete().where(view_menu_table.c.id == db_pvm.view_menu_id)
+        )
+
+        # Clean database schema permissions
+        schema_pvms = (
+            self.get_session.query(self.permissionview_model)
+            .join(self.permission_model)
+            .join(self.viewmenu_model)
+            .filter(self.permission_model.name == "schema_access")
+            .filter(self.viewmenu_model.name.like(f"[{database_name}].[%]"))
+            .all()
+        )
+        for schema_pvm in schema_pvms:
+            connection.execute(
+                permission_view_menu_table.delete().where(
+                    permission_view_menu_table.c.id == schema_pvm.id
+                )
+            )
+            self.on_permission_after_delete(mapper, connection, schema_pvm)
+            connection.execute(
+                view_menu_table.delete().where(
+                    view_menu_table.c.id == schema_pvm.view_menu_id
+                )
+            )
+
+    def _update_vm_database_access(
+        self,
+        mapper: Mapper,
+        connection: Connection,
+        old_database_name: str,
+        target: "Database",
+    ) -> Optional[ViewMenu]:
+        view_menu_table = self.viewmenu_model.__table__  # pylint: disable=no-member
+        new_database_name = target.database_name
+        old_view_menu_name = self.get_database_perm(target.id, old_database_name)
+        new_view_menu_name = self.get_database_perm(target.id, new_database_name)
+        db_pvm = self.find_permission_view_menu("database_access", old_view_menu_name)
+        if not db_pvm:
+            logger.warning(
+                "Could not find previous database permission %s",
+                old_view_menu_name,
+            )
+            return None
+        new_updated_pvm = self.find_permission_view_menu(
+            "database_access", new_view_menu_name
+        )
+        if new_updated_pvm:
+            logger.info(
+                "New permission [%s] already exists, deleting the previous",
+                new_view_menu_name,
+            )
+            self._delete_vm_database_access(
+                mapper, connection, target.id, old_database_name
+            )
+            return None
+        connection.execute(
+            view_menu_table.update()
+            .where(view_menu_table.c.id == db_pvm.view_menu_id)
+            .values(name=new_view_menu_name)
+        )
+        new_db_view_menu = self.find_view_menu(new_view_menu_name)
+
+        self.on_view_menu_after_update(mapper, connection, new_db_view_menu)
+        return new_db_view_menu
+
+    def _update_vm_datasources_access(  # pylint: disable=too-many-locals
+        self,
+        mapper: Mapper,
+        connection: Connection,
+        old_database_name: str,
+        target: "Database",
+    ) -> List[ViewMenu]:
+        """
+        Updates all datasource access permissions when a database name changes
+
+        :param connection: Current connection (called on SQLAlchemy event listener scope)
+        :param old_database_name: the old database name
+        :param target: The Database instance, carrying the new database name
+        :return: A list of changed view menus (permission resource names)
+        """
+        from superset.connectors.sqla.models import (  # pylint: disable=import-outside-toplevel
+            SqlaTable,
+        )
+        from superset.models.slice import (  # pylint: disable=import-outside-toplevel
+            Slice,
+        )
+
+        view_menu_table = self.viewmenu_model.__table__  # pylint: disable=no-member
+        sqlatable_table = SqlaTable.__table__  # pylint: disable=no-member
+        chart_table = Slice.__table__  # pylint: disable=no-member
+        new_database_name = target.database_name
+        datasets = (
+            self.get_session.query(SqlaTable)
+            .filter(SqlaTable.database_id == target.id)
+            .all()
+        )
+        updated_view_menus: List[ViewMenu] = []
+        for dataset in datasets:
+            old_dataset_vm_name = self.get_dataset_perm(
+                dataset.id, dataset.table_name, old_database_name
+            )
+            new_dataset_vm_name = self.get_dataset_perm(
+                dataset.id, dataset.table_name, new_database_name
+            )
+            new_dataset_view_menu = self.find_view_menu(new_dataset_vm_name)
+            if new_dataset_view_menu:
+                continue
+            connection.execute(
+                view_menu_table.update()
+                .where(view_menu_table.c.name == old_dataset_vm_name)
+                .values(name=new_dataset_vm_name)
+            )
+            # Update dataset (SqlaTable perm field)
+            connection.execute(
+                sqlatable_table.update()
+                .where(
+                    sqlatable_table.c.id == dataset.id,
+                    sqlatable_table.c.perm == old_dataset_vm_name,
+                )
+                .values(perm=new_dataset_vm_name)
+            )
+            # Update charts (Slice perm field)
+            connection.execute(
+                chart_table.update()
+                .where(chart_table.c.perm == old_dataset_vm_name)
+                .values(perm=new_dataset_vm_name)
+            )
+            self.on_view_menu_after_update(mapper, connection, new_dataset_view_menu)
+            updated_view_menus.append(self.find_view_menu(new_dataset_view_menu))
+        return updated_view_menus
+
+    def database_after_update(
+        self,
+        mapper: Mapper,
+        connection: Connection,
+        target: "Database",
+    ) -> None:
+        # Check if database name has changed
+        state = inspect(target)
+        history = state.get_history("database_name", True)
+        if not history.has_changes() or not history.deleted:
+            return
+
+        old_database_name = history.deleted[0]
+        # update database access permission
+        self._update_vm_database_access(mapper, connection, old_database_name, target)
+        # update datasource access
+        self._update_vm_datasources_access(
+            mapper, connection, old_database_name, target
+        )
+
+    def on_view_menu_after_update(
+        self, mapper: Mapper, connection: Connection, target: ViewMenu
+    ) -> None:
+        """
+        Hook that allows for further custom operations when a ViewMenu
+        is updated
+
+        Since the update may be performed in an after_update event, we cannot
+        update ViewMenus using a session, so any SQLAlchemy events hooked to
+        `ViewMenu` will not trigger an after_update.
+
+        :param mapper: The table mapper
+        :param connection: The DB-API connection
+        :param target: The mapped instance being persisted
+        """
+
+    def on_permission_after_delete(
+        self, mapper: Mapper, connection: Connection, target: Permission
+    ) -> None:
+        """
+        Hook that allows for further custom operations when a permission
+        is deleted by SQLAlchemy events.
+ + :param mapper: The table mapper + :param connection: The DB-API connection + :param target: The mapped instance being persisted + """ + def on_permission_after_insert( self, mapper: Mapper, connection: Connection, target: Permission ) -> None: @@ -996,6 +1220,8 @@ def set_perm( except DatasetInvalidPermissionEvaluationException: logger.warning("Dataset has no database refusing to set permission") return + permission_table = self.permission_model.__table__ # pylint: disable=no-member + view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member link_table = target.__table__ if target.perm != target_get_perm: connection.execute( @@ -1003,8 +1229,19 @@ def set_perm( .where(link_table.c.id == target.id) .values(perm=target_get_perm) ) + connection.execute( + permission_table.update() + .where(permission_table.c.name == target.perm) + .values(name=target_get_perm) + ) + connection.execute( + view_menu_table.update() + .where(view_menu_table.c.name == target.perm) + .values(name=target_get_perm) + ) target.perm = target_get_perm + # check schema perm for datasets if ( hasattr(target, "schema_perm") and target.schema_perm != target.get_schema_perm() @@ -1031,18 +1268,12 @@ def set_perm( pv = None if not permission: - permission_table = ( - self.permission_model.__table__ # pylint: disable=no-member - ) connection.execute( permission_table.insert().values(name=permission_name) ) permission = self.find_permission(permission_name) self.on_permission_after_insert(mapper, connection, permission) if not view_menu: - view_menu_table = ( - self.viewmenu_model.__table__ # pylint: disable=no-member - ) connection.execute(view_menu_table.insert().values(name=view_menu_name)) view_menu = self.find_view_menu(view_menu_name) self.on_view_menu_after_insert(mapper, connection, view_menu) diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab.py index f83c4521e707a..1042b8f920ac2 100644 --- a/superset/views/sql_lab.py +++ b/superset/views/sql_lab.py @@ -139,7 +139,9 @@ def post(self) -> FlaskResponse: # pylint: disable=no-self-use query_editor = json.loads(request.form["queryEditor"]) tab_state = TabState( user_id=get_user_id(), - label=query_editor.get("title", "Untitled Query"), + # This is for backward compatibility + label=query_editor.get("name") + or query_editor.get("title", "Untitled Query"), active=True, database_id=query_editor["dbId"], schema=query_editor.get("schema"), diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 476cf27aab223..ebb1e65e36f48 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -206,6 +206,7 @@ def test_set_perm_sqla_table(self): ) # table name change + orig_table_perm = stored_table.perm stored_table.table_name = "tmp_perm_table_v2" session.commit() stored_table = ( @@ -214,6 +215,11 @@ def test_set_perm_sqla_table(self): self.assertEqual( stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})" ) + self.assertIsNone( + security_manager.find_permission_view_menu( + "datasource_access", orig_table_perm + ) + ) self.assertIsNotNone( security_manager.find_permission_view_menu( "datasource_access", stored_table.perm @@ -354,6 +360,186 @@ def test_set_perm_database(self): session.delete(stored_db) session.commit() + def test_after_update_database__perm_database_access(self): + session = db.session + database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") + session.add(database) + session.commit() + stored_db = 
( + session.query(Database).filter_by(database_name="tmp_database").one() + ) + + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", stored_db.perm + ) + ) + + stored_db.database_name = "tmp_database2" + session.commit() + + # Assert that the old permission was updated + self.assertIsNone( + security_manager.find_permission_view_menu( + "database_access", f"[tmp_database].(id:{stored_db.id})" + ) + ) + # Assert that the db permission was updated + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", f"[tmp_database2].(id:{stored_db.id})" + ) + ) + session.delete(stored_db) + session.commit() + + def test_after_update_database__perm_database_access_exists(self): + session = db.session + # Add a bogus existing permission before the change + + database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") + session.add(database) + session.commit() + stored_db = ( + session.query(Database).filter_by(database_name="tmp_database").one() + ) + security_manager.add_permission_view_menu( + "database_access", f"[tmp_database2].(id:{stored_db.id})" + ) + + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", stored_db.perm + ) + ) + + stored_db.database_name = "tmp_database2" + session.commit() + + # Assert that the old permission was updated + self.assertIsNone( + security_manager.find_permission_view_menu( + "database_access", f"[tmp_database].(id:{stored_db.id})" + ) + ) + # Assert that the db permission was updated + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", f"[tmp_database2].(id:{stored_db.id})" + ) + ) + session.delete(stored_db) + session.commit() + + def test_after_update_database__perm_datasource_access(self): + session = db.session + database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") + session.add(database) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=database, + ) + session.add(table1) + table2 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table2", + database=database, + ) + session.add(table2) + session.commit() + slice1 = Slice( + datasource_id=table1.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_table1", + slice_name="tmp_slice1", + ) + session.add(slice1) + session.commit() + slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one() + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + table2 = session.query(SqlaTable).filter_by(table_name="tmp_table2").one() + + # assert initial perms + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_database].[tmp_table1](id:{table1.id})" + ) + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_database].[tmp_table2](id:{table2.id})" + ) + ) + self.assertEqual(slice1.perm, f"[tmp_database].[tmp_table1](id:{table1.id})") + self.assertEqual(table1.perm, f"[tmp_database].[tmp_table1](id:{table1.id})") + self.assertEqual(table2.perm, f"[tmp_database].[tmp_table2](id:{table2.id})") + + stored_db = ( + session.query(Database).filter_by(database_name="tmp_database").one() + ) + stored_db.database_name = "tmp_database2" + session.commit() + + # Assert that the old permissions were updated + self.assertIsNone( + security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_database].[tmp_table1](id:{table1.id})" + ) + ) + 
self.assertIsNone(
+            security_manager.find_permission_view_menu(
+                "datasource_access", f"[tmp_database].[tmp_table2](id:{table2.id})"
+            )
+        )
+
+        # Assert that the db permission was updated
+        self.assertIsNotNone(
+            security_manager.find_permission_view_menu(
+                "datasource_access", f"[tmp_database2].[tmp_table1](id:{table1.id})"
+            )
+        )
+        self.assertIsNotNone(
+            security_manager.find_permission_view_menu(
+                "datasource_access", f"[tmp_database2].[tmp_table2](id:{table2.id})"
+            )
+        )
+        self.assertEqual(slice1.perm, f"[tmp_database2].[tmp_table1](id:{table1.id})")
+        self.assertEqual(table1.perm, f"[tmp_database2].[tmp_table1](id:{table1.id})")
+        self.assertEqual(table2.perm, f"[tmp_database2].[tmp_table2](id:{table2.id})")
+
+        session.delete(slice1)
+        session.delete(table1)
+        session.delete(table2)
+        session.delete(stored_db)
+        session.commit()
+
+    def test_after_delete_database__perm_database_access(self):
+        session = db.session
+        database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://")
+        session.add(database)
+        session.commit()
+        stored_db = (
+            session.query(Database).filter_by(database_name="tmp_database").one()
+        )
+
+        self.assertIsNotNone(
+            security_manager.find_permission_view_menu(
+                "database_access", stored_db.perm
+            )
+        )
+        session.delete(stored_db)
+        session.commit()
+
+        # Assert that the old permission was deleted
+        self.assertIsNone(
+            security_manager.find_permission_view_menu(
+                "database_access", f"[tmp_database].(id:{stored_db.id})"
+            )
+        )
+
     def test_hybrid_perm_database(self):
 
         database = Database(database_name="tmp_database3", sqlalchemy_uri="sqlite://")
diff --git a/tests/unit_tests/advanced_data_type/types_tests.py b/tests/unit_tests/advanced_data_type/types_tests.py
index 82c9d8b29ad9c..189b9e1aab22d 100644
--- a/tests/unit_tests/advanced_data_type/types_tests.py
+++ b/tests/unit_tests/advanced_data_type/types_tests.py
@@ -17,11 +17,8 @@
 # isort:skip_file
 """Unit tests for Superset"""
-from ipaddress import ip_address
 import sqlalchemy
-from flask.ctx import AppContext
 from sqlalchemy import Column, Integer
-from tests.integration_tests.base_tests import SupersetTestCase
 from superset.advanced_data_type.types import (
     AdvancedDataTypeRequest,
     AdvancedDataTypeResponse,
@@ -36,7 +33,7 @@
 # tox -e py38 -- tests/unit_tests/advanced_data_type/types_tests.py
 
-def test_ip_func_valid_ip(app_context: None):
+def test_ip_func_valid_ip():
     """Test to see if the cidr_func behaves as expected when a valid IP is passed in"""
     cidr_request: AdvancedDataTypeRequest = {
         "advanced_data_type": "cidr",
@@ -59,7 +56,7 @@ def test_ip_func_valid_ip(app_context: None):
     assert internet_address.translate_type(cidr_request) == cidr_response
 
-def test_cidr_func_invalid_ip(app_context: None):
+def test_cidr_func_invalid_ip():
     """Test to see if the cidr_func behaves as expected when an invalid IP is passed in"""
     cidr_request: AdvancedDataTypeRequest = {
         "advanced_data_type": "cidr",
@@ -82,7 +79,7 @@ def test_cidr_func_invalid_ip(app_context: None):
     assert internet_address.translate_type(cidr_request) == cidr_response
 
-def test_port_translation_func_valid_port_number(app_context: None):
+def test_port_translation_func_valid_port_number():
     """Test to see if the port_translation_func behaves as expected when a valid port
     number is passed in"""
     port_request: AdvancedDataTypeRequest = {
@@ -106,7 +103,7 @@ def test_port_translation_func_valid_port_number(app_context: None):
     assert port.translate_type(port_request) == port_response
 
-def
test_port_translation_func_valid_port_name(app_context: None): +def test_port_translation_func_valid_port_name(): """Test to see if the port_translation_func behaves as expected when a valid port name is passed in""" port_request: AdvancedDataTypeRequest = { @@ -130,7 +127,7 @@ def test_port_translation_func_valid_port_name(app_context: None): assert port.translate_type(port_request) == port_response -def test_port_translation_func_invalid_port_name(app_context: None): +def test_port_translation_func_invalid_port_name(): """Test to see if the port_translation_func behaves as expected when an invalid port name is passed in""" port_request: AdvancedDataTypeRequest = { @@ -154,7 +151,7 @@ def test_port_translation_func_invalid_port_name(app_context: None): assert port.translate_type(port_request) == port_response -def test_port_translation_func_invalid_port_number(app_context: None): +def test_port_translation_func_invalid_port_number(): """Test to see if the port_translation_func behaves as expected when an invalid port number is passed in""" port_request: AdvancedDataTypeRequest = { @@ -178,7 +175,7 @@ def test_port_translation_func_invalid_port_number(app_context: None): assert port.translate_type(port_request) == port_response -def test_cidr_translate_filter_func_equals(app_context: None): +def test_cidr_translate_filter_func_equals(): """Test to see if the cidr_translate_filter_func behaves as expected when the EQUALS operator is used""" @@ -193,7 +190,7 @@ def test_cidr_translate_filter_func_equals(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_not_equals(app_context: None): +def test_cidr_translate_filter_func_not_equals(): """Test to see if the cidr_translate_filter_func behaves as expected when the NOT_EQUALS operator is used""" @@ -208,7 +205,7 @@ def test_cidr_translate_filter_func_not_equals(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_greater_than_or_equals(app_context: None): +def test_cidr_translate_filter_func_greater_than_or_equals(): """Test to see if the cidr_translate_filter_func behaves as expected when the GREATER_THAN_OR_EQUALS operator is used""" @@ -225,7 +222,7 @@ def test_cidr_translate_filter_func_greater_than_or_equals(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_greater_than(app_context: None): +def test_cidr_translate_filter_func_greater_than(): """Test to see if the cidr_translate_filter_func behaves as expected when the GREATER_THAN operator is used""" @@ -242,7 +239,7 @@ def test_cidr_translate_filter_func_greater_than(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_less_than(app_context: None): +def test_cidr_translate_filter_func_less_than(): """Test to see if the cidr_translate_filter_func behaves as expected when the LESS_THAN operator is used""" @@ -259,7 +256,7 @@ def test_cidr_translate_filter_func_less_than(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_less_than_or_equals(app_context: None): +def test_cidr_translate_filter_func_less_than_or_equals(): """Test to see if the cidr_translate_filter_func behaves as expected when the LESS_THAN_OR_EQUALS operator is used""" @@ -276,7 +273,7 @@ def test_cidr_translate_filter_func_less_than_or_equals(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_in_single(app_context: None): +def 
test_cidr_translate_filter_func_in_single(): """Test to see if the cidr_translate_filter_func behaves as expected when the IN operator is used with a single IP""" @@ -293,7 +290,7 @@ def test_cidr_translate_filter_func_in_single(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_in_double(app_context: None): +def test_cidr_translate_filter_func_in_double(): """Test to see if the cidr_translate_filter_func behaves as expected when the IN operator is used with two IP's""" @@ -312,7 +309,7 @@ def test_cidr_translate_filter_func_in_double(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_not_in_single(app_context: None): +def test_cidr_translate_filter_func_not_in_single(): """Test to see if the cidr_translate_filter_func behaves as expected when the NOT_IN operator is used with a single IP""" @@ -329,7 +326,7 @@ def test_cidr_translate_filter_func_not_in_single(app_context: None): ).compare(cidr_translate_filter_response) -def test_cidr_translate_filter_func_not_in_double(app_context: None): +def test_cidr_translate_filter_func_not_in_double(): """Test to see if the cidr_translate_filter_func behaves as expected when the NOT_IN operator is used with two IP's""" @@ -348,7 +345,7 @@ def test_cidr_translate_filter_func_not_in_double(app_context: None): ).compare(cidr_translate_filter_response) -def test_port_translate_filter_func_equals(app_context: None): +def test_port_translate_filter_func_equals(): """Test to see if the port_translate_filter_func behaves as expected when the EQUALS operator is used""" @@ -365,7 +362,7 @@ def test_port_translate_filter_func_equals(app_context: None): ) -def test_port_translate_filter_func_not_equals(app_context: None): +def test_port_translate_filter_func_not_equals(): """Test to see if the port_translate_filter_func behaves as expected when the NOT_EQUALS operator is used""" @@ -382,7 +379,7 @@ def test_port_translate_filter_func_not_equals(app_context: None): ) -def test_port_translate_filter_func_greater_than_or_equals(app_context: None): +def test_port_translate_filter_func_greater_than_or_equals(): """Test to see if the port_translate_filter_func behaves as expected when the GREATER_THAN_OR_EQUALS operator is used""" @@ -399,7 +396,7 @@ def test_port_translate_filter_func_greater_than_or_equals(app_context: None): ) -def test_port_translate_filter_func_greater_than(app_context: None): +def test_port_translate_filter_func_greater_than(): """Test to see if the port_translate_filter_func behaves as expected when the GREATER_THAN operator is used""" @@ -416,7 +413,7 @@ def test_port_translate_filter_func_greater_than(app_context: None): ) -def test_port_translate_filter_func_less_than_or_equals(app_context: None): +def test_port_translate_filter_func_less_than_or_equals(): """Test to see if the port_translate_filter_func behaves as expected when the LESS_THAN_OR_EQUALS operator is used""" @@ -433,7 +430,7 @@ def test_port_translate_filter_func_less_than_or_equals(app_context: None): ) -def test_port_translate_filter_func_less_than(app_context: None): +def test_port_translate_filter_func_less_than(): """Test to see if the port_translate_filter_func behaves as expected when the LESS_THAN operator is used""" @@ -450,7 +447,7 @@ def test_port_translate_filter_func_less_than(app_context: None): ) -def test_port_translate_filter_func_in_single(app_context: None): +def test_port_translate_filter_func_in_single(): """Test to see if the port_translate_filter_func behaves 
as expected when the IN operator is used with a single port""" @@ -467,7 +464,7 @@ def test_port_translate_filter_func_in_single(app_context: None): ) -def test_port_translate_filter_func_in_double(app_context: None): +def test_port_translate_filter_func_in_double(): """Test to see if the port_translate_filter_func behaves as expected when the IN operator is used with two ports""" @@ -484,7 +481,7 @@ def test_port_translate_filter_func_in_double(app_context: None): ) -def test_port_translate_filter_func_not_in_single(app_context: None): +def test_port_translate_filter_func_not_in_single(): """Test to see if the port_translate_filter_func behaves as expected when the NOT_IN operator is used with a single port""" @@ -501,7 +498,7 @@ def test_port_translate_filter_func_not_in_single(app_context: None): ) -def test_port_translate_filter_func_not_in_double(app_context: None): +def test_port_translate_filter_func_not_in_double(): """Test to see if the port_translate_filter_func behaves as expected when the NOT_IN operator is used with two ports""" diff --git a/tests/unit_tests/charts/commands/importers/v1/import_test.py b/tests/unit_tests/charts/commands/importers/v1/import_test.py index e8687036394cb..e29fd70fb8a70 100644 --- a/tests/unit_tests/charts/commands/importers/v1/import_test.py +++ b/tests/unit_tests/charts/commands/importers/v1/import_test.py @@ -21,7 +21,7 @@ from sqlalchemy.orm.session import Session -def test_import_chart(app_context: None, session: Session) -> None: +def test_import_chart(session: Session) -> None: """ Test importing a chart. """ @@ -45,7 +45,7 @@ def test_import_chart(app_context: None, session: Session) -> None: assert chart.external_url is None -def test_import_chart_managed_externally(app_context: None, session: Session) -> None: +def test_import_chart_managed_externally(session: Session) -> None: """ Test importing a chart that is managed externally. """ diff --git a/tests/unit_tests/columns/test_models.py b/tests/unit_tests/columns/test_models.py index 40cc2075d380e..068557e7a6a7f 100644 --- a/tests/unit_tests/columns/test_models.py +++ b/tests/unit_tests/columns/test_models.py @@ -20,7 +20,7 @@ from sqlalchemy.orm.session import Session -def test_column_model(app_context: None, session: Session) -> None: +def test_column_model(session: Session) -> None: """ Test basic attributes of a ``Column``. """ diff --git a/tests/unit_tests/commands/export_test.py b/tests/unit_tests/commands/export_test.py index 91aebf1b684eb..24fa491664042 100644 --- a/tests/unit_tests/commands/export_test.py +++ b/tests/unit_tests/commands/export_test.py @@ -20,7 +20,7 @@ from pytest_mock import MockFixture -def test_export_assets_command(mocker: MockFixture, app_context: None) -> None: +def test_export_assets_command(mocker: MockFixture) -> None: """ Test that all assets are exported correctly. """ diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py index 2ec81a2b8e62f..021193a6cd36e 100644 --- a/tests/unit_tests/config_test.py +++ b/tests/unit_tests/config_test.py @@ -74,7 +74,7 @@ def apply_dttm_defaults(table: "SqlaTable", dttm_defaults: Dict[str, Any]) -> No @pytest.fixture -def test_table(app_context: None, session: Session) -> "SqlaTable": +def test_table(session: Session) -> "SqlaTable": """ Fixture that generates an in-memory table. 
""" diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index c98b09ac5af4f..817dc79c56958 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -102,7 +102,7 @@ def client(app: SupersetApp) -> Any: yield client -@pytest.fixture +@pytest.fixture(autouse=True) def app_context(app: SupersetApp) -> Iterator[None]: """ A fixture that yields and application context. diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py index 8df6d2066aaca..8e2a458434cd9 100644 --- a/tests/unit_tests/dao/queries_test.py +++ b/tests/unit_tests/dao/queries_test.py @@ -21,7 +21,7 @@ from sqlalchemy.orm.session import Session -def test_query_dao_save_metadata(app_context: None, session: Session) -> None: +def test_query_dao_save_metadata(session: Session) -> None: from superset.models.core import Database from superset.models.sql_lab import Query diff --git a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py index 651e5dc10b7ca..08f681d916b3c 100644 --- a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py +++ b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py @@ -21,7 +21,7 @@ from sqlalchemy.orm.session import Session -def test_import_dashboard(app_context: None, session: Session) -> None: +def test_import_dashboard(session: Session) -> None: """ Test importing a dashboard. """ @@ -43,9 +43,7 @@ def test_import_dashboard(app_context: None, session: Session) -> None: assert dashboard.external_url is None -def test_import_dashboard_managed_externally( - app_context: None, session: Session -) -> None: +def test_import_dashboard_managed_externally(session: Session) -> None: """ Test importing a dashboard that is managed externally. """ diff --git a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py index bddc96eda36e6..0392acb31596a 100644 --- a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py +++ b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py @@ -82,7 +82,7 @@ def test_update_id_refs_immune_missing( # pylint: disable=invalid-name } -def test_update_native_filter_config_scope_excluded(app_context: None): +def test_update_native_filter_config_scope_excluded(): from superset.dashboards.commands.importers.v1.utils import update_id_refs config = { diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py index 622aa27fc3d56..f482e16d8a685 100644 --- a/tests/unit_tests/databases/commands/importers/v1/import_test.py +++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py @@ -21,7 +21,7 @@ from sqlalchemy.orm.session import Session -def test_import_database(app_context: None, session: Session) -> None: +def test_import_database(session: Session) -> None: """ Test importing a database. """ @@ -48,9 +48,7 @@ def test_import_database(app_context: None, session: Session) -> None: assert database.external_url is None -def test_import_database_managed_externally( - app_context: None, session: Session -) -> None: +def test_import_database_managed_externally(session: Session) -> None: """ Test importing a database that is managed externally. 
""" diff --git a/tests/unit_tests/databases/utils_test.py b/tests/unit_tests/databases/utils_test.py index 8dbc11a3b7a70..e402ced2a529f 100644 --- a/tests/unit_tests/databases/utils_test.py +++ b/tests/unit_tests/databases/utils_test.py @@ -21,7 +21,7 @@ from superset.databases.utils import make_url_safe -def test_make_url_safe_string(app_context: None, session: Session) -> None: +def test_make_url_safe_string(session: Session) -> None: """ Test converting a string to a safe uri """ @@ -31,7 +31,7 @@ def test_make_url_safe_string(app_context: None, session: Session) -> None: assert uri_safe == make_url(uri_string) -def test_make_url_safe_url(app_context: None, session: Session) -> None: +def test_make_url_safe_url(session: Session) -> None: """ Test converting a url to a safe uri """ diff --git a/tests/unit_tests/dataframe_test.py b/tests/unit_tests/dataframe_test.py index 8785ee1d7b3ec..016d2f4d9bae4 100644 --- a/tests/unit_tests/dataframe_test.py +++ b/tests/unit_tests/dataframe_test.py @@ -24,7 +24,7 @@ from superset.superset_typing import DbapiDescription -def test_df_to_records(app_context: None) -> None: +def test_df_to_records() -> None: from superset.db_engine_specs import BaseEngineSpec from superset.result_set import SupersetResultSet @@ -41,7 +41,7 @@ def test_df_to_records(app_context: None) -> None: ] -def test_js_max_int(app_context: None) -> None: +def test_js_max_int() -> None: from superset.db_engine_specs import BaseEngineSpec from superset.result_set import SupersetResultSet diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py index c7e6da8649e7c..aa444b5fd734a 100644 --- a/tests/unit_tests/datasets/commands/export_test.py +++ b/tests/unit_tests/datasets/commands/export_test.py @@ -21,7 +21,7 @@ from sqlalchemy.orm.session import Session -def test_export(app_context: None, session: Session) -> None: +def test_export(session: Session) -> None: """ Test exporting a dataset. """ diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py index 164f7f83e93ea..934712b8c9b7e 100644 --- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -24,7 +24,7 @@ from sqlalchemy.orm.session import Session -def test_import_dataset(app_context: None, session: Session) -> None: +def test_import_dataset(session: Session) -> None: """ Test importing a dataset. """ @@ -137,7 +137,7 @@ def test_import_dataset(app_context: None, session: Session) -> None: assert sqla_table.database.id == database.id -def test_import_dataset_duplicate_column(app_context: None, session: Session) -> None: +def test_import_dataset_duplicate_column(session: Session) -> None: """ Test importing a dataset with a column that already exists. """ @@ -260,7 +260,7 @@ def test_import_dataset_duplicate_column(app_context: None, session: Session) -> assert sqla_table.database.id == database.id -def test_import_column_extra_is_string(app_context: None, session: Session) -> None: +def test_import_column_extra_is_string(session: Session) -> None: """ Test importing a dataset when the column extra is a string. 
""" @@ -340,7 +340,7 @@ def test_import_column_extra_is_string(app_context: None, session: Session) -> N assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}' -def test_import_dataset_managed_externally(app_context: None, session: Session) -> None: +def test_import_dataset_managed_externally(session: Session) -> None: """ Test importing a dataset that is managed externally. """ diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py index 961ee7c543639..771bb0d0e179a 100644 --- a/tests/unit_tests/datasets/test_models.py +++ b/tests/unit_tests/datasets/test_models.py @@ -27,7 +27,7 @@ from superset.connectors.sqla.models import SqlMetric, TableColumn -def test_dataset_model(app_context: None, session: Session) -> None: +def test_dataset_model(session: Session) -> None: """ Test basic attributes of a ``Dataset``. """ @@ -86,7 +86,7 @@ def test_dataset_model(app_context: None, session: Session) -> None: assert [column.name for column in dataset.columns] == ["position"] -def test_cascade_delete_table(app_context: None, session: Session) -> None: +def test_cascade_delete_table(session: Session) -> None: """ Test that deleting ``Table`` also deletes its columns. """ @@ -121,7 +121,7 @@ def test_cascade_delete_table(app_context: None, session: Session) -> None: assert len(columns) == 0 -def test_cascade_delete_dataset(app_context: None, session: Session) -> None: +def test_cascade_delete_dataset(session: Session) -> None: """ Test that deleting ``Dataset`` also deletes its columns. """ @@ -175,7 +175,7 @@ def test_cascade_delete_dataset(app_context: None, session: Session) -> None: assert len(columns) == 2 -def test_dataset_attributes(app_context: None, session: Session) -> None: +def test_dataset_attributes(session: Session) -> None: """ Test that checks attributes in the dataset. @@ -649,7 +649,7 @@ def test_create_virtual_sqlatable( } -def test_delete_sqlatable(app_context: None, session: Session) -> None: +def test_delete_sqlatable(session: Session) -> None: """ Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``. """ @@ -689,7 +689,7 @@ def test_delete_sqlatable(app_context: None, session: Session) -> None: def test_update_physical_sqlatable_columns( - mocker: MockFixture, app_context: None, session: Session + mocker: MockFixture, session: Session ) -> None: """ Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. @@ -765,7 +765,7 @@ def test_update_physical_sqlatable_columns( def test_update_physical_sqlatable_schema( - mocker: MockFixture, app_context: None, session: Session + mocker: MockFixture, session: Session ) -> None: """ Test that updating a ``SqlaTable`` schema also updates the corresponding ``Dataset``. @@ -1046,7 +1046,7 @@ def test_update_physical_sqlatable_database( def test_update_virtual_sqlatable_references( - mocker: MockFixture, app_context: None, session: Session + mocker: MockFixture, session: Session ) -> None: """ Test that changing the SQL of a virtual ``SqlaTable`` updates ``Dataset``. @@ -1122,7 +1122,7 @@ def test_update_virtual_sqlatable_references( assert new_dataset.tables[2].name == "table_c" -def test_quote_expressions(app_context: None, session: Session) -> None: +def test_quote_expressions(session: Session) -> None: """ Test that expressions are quoted appropriately in columns and datasets. 
""" diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py index 0682c19c28756..8647f97f747c2 100644 --- a/tests/unit_tests/datasource/dao_tests.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -99,9 +99,7 @@ def session_with_data(session: Session) -> Iterator[Session]: yield session -def test_get_datasource_sqlatable( - app_context: None, session_with_data: Session -) -> None: +def test_get_datasource_sqlatable(session_with_data: Session) -> None: from superset.connectors.sqla.models import SqlaTable from superset.datasource.dao import DatasourceDAO @@ -116,7 +114,7 @@ def test_get_datasource_sqlatable( assert isinstance(result, SqlaTable) -def test_get_datasource_query(app_context: None, session_with_data: Session) -> None: +def test_get_datasource_query(session_with_data: Session) -> None: from superset.datasource.dao import DatasourceDAO from superset.models.sql_lab import Query @@ -128,9 +126,7 @@ def test_get_datasource_query(app_context: None, session_with_data: Session) -> assert isinstance(result, Query) -def test_get_datasource_saved_query( - app_context: None, session_with_data: Session -) -> None: +def test_get_datasource_saved_query(session_with_data: Session) -> None: from superset.datasource.dao import DatasourceDAO from superset.models.sql_lab import SavedQuery @@ -144,7 +140,7 @@ def test_get_datasource_saved_query( assert isinstance(result, SavedQuery) -def test_get_datasource_sl_table(app_context: None, session_with_data: Session) -> None: +def test_get_datasource_sl_table(session_with_data: Session) -> None: from superset.datasource.dao import DatasourceDAO from superset.tables.models import Table @@ -160,9 +156,7 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session) assert isinstance(result, Table) -def test_get_datasource_sl_dataset( - app_context: None, session_with_data: Session -) -> None: +def test_get_datasource_sl_dataset(session_with_data: Session) -> None: from superset.datasets.models import Dataset from superset.datasource.dao import DatasourceDAO @@ -178,9 +172,7 @@ def test_get_datasource_sl_dataset( assert isinstance(result, Dataset) -def test_get_datasource_w_str_param( - app_context: None, session_with_data: Session -) -> None: +def test_get_datasource_w_str_param(session_with_data: Session) -> None: from superset.connectors.sqla.models import SqlaTable from superset.datasets.models import Dataset from superset.datasource.dao import DatasourceDAO @@ -205,7 +197,7 @@ def test_get_datasource_w_str_param( ) -def test_get_all_datasources(app_context: None, session_with_data: Session) -> None: +def test_get_all_datasources(session_with_data: Session) -> None: from superset.connectors.sqla.models import SqlaTable result = SqlaTable.get_all_datasources(session=session_with_data) diff --git a/tests/unit_tests/db_engine_specs/test_athena.py b/tests/unit_tests/db_engine_specs/test_athena.py index d7c1a3f606fca..a1243ac097b6a 100644 --- a/tests/unit_tests/db_engine_specs/test_athena.py +++ b/tests/unit_tests/db_engine_specs/test_athena.py @@ -18,8 +18,6 @@ import re from datetime import datetime -from flask.ctx import AppContext - from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from tests.unit_tests.fixtures.common import dttm @@ -28,7 +26,7 @@ ) -def test_convert_dttm(app_context: AppContext, dttm: datetime) -> None: +def test_convert_dttm(dttm: datetime) -> None: """ Test that date objects are converted correctly. 
""" @@ -43,7 +41,7 @@ def test_convert_dttm(app_context: AppContext, dttm: datetime) -> None: ) -def test_extract_errors(app_context: AppContext) -> None: +def test_extract_errors() -> None: """ Test that custom error messages are extracted correctly. """ @@ -70,7 +68,7 @@ def test_extract_errors(app_context: AppContext) -> None: ] -def test_get_text_clause_with_colon(app_context: AppContext) -> None: +def test_get_text_clause_with_colon() -> None: """ Make sure text clauses don't escape the colon character """ diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index b112e2cec8ef4..79a83c6b09d52 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -19,11 +19,10 @@ from textwrap import dedent import pytest -from flask.ctx import AppContext from sqlalchemy.types import TypeEngine -def test_get_text_clause_with_colon(app_context: AppContext) -> None: +def test_get_text_clause_with_colon() -> None: """ Make sure text clauses are correctly escaped """ @@ -36,7 +35,7 @@ def test_get_text_clause_with_colon(app_context: AppContext) -> None: assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')" -def test_parse_sql_single_statement(app_context: AppContext) -> None: +def test_parse_sql_single_statement() -> None: """ `parse_sql` should properly strip leading and trailing spaces and semicolons """ @@ -47,7 +46,7 @@ def test_parse_sql_single_statement(app_context: AppContext) -> None: assert queries == ["SELECT foo FROM tbl"] -def test_parse_sql_multi_statement(app_context: AppContext) -> None: +def test_parse_sql_multi_statement() -> None: """ For string with multiple SQL-statements `parse_sql` method should return list where each element represents the single SQL-statement @@ -95,9 +94,7 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None: ), ], ) -def test_cte_query_parsing( - app_context: AppContext, original: TypeEngine, expected: str -) -> None: +def test_cte_query_parsing(original: TypeEngine, expected: str) -> None: from superset.db_engine_specs.base import BaseEngineSpec actual = BaseEngineSpec.get_cte_query(original) diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index a4a6f706ceab9..292ea94a7b52c 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -16,14 +16,13 @@ # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access -from flask.ctx import AppContext from pybigquery.sqlalchemy_bigquery import BigQueryDialect from pytest_mock import MockFixture from sqlalchemy import select from sqlalchemy.sql import sqltypes -def test_get_fields(app_context: AppContext) -> None: +def test_get_fields() -> None: """ Test the custom ``_get_fields`` method. @@ -66,7 +65,7 @@ def test_get_fields(app_context: AppContext) -> None: ) -def test_select_star(mocker: MockFixture, app_context: AppContext) -> None: +def test_select_star(mocker: MockFixture) -> None: """ Test the ``select_star`` method. diff --git a/tests/unit_tests/db_engine_specs/test_drill.py b/tests/unit_tests/db_engine_specs/test_drill.py index a7f0720f29de2..195ad8aca2f59 100644 --- a/tests/unit_tests/db_engine_specs/test_drill.py +++ b/tests/unit_tests/db_engine_specs/test_drill.py @@ -16,11 +16,10 @@ # under the License. 
# pylint: disable=unused-argument, import-outside-toplevel, protected-access -from flask.ctx import AppContext from pytest import raises -def test_odbc_impersonation(app_context: AppContext) -> None: +def test_odbc_impersonation() -> None: """ Test ``get_url_for_impersonation`` method when driver == odbc. @@ -36,7 +35,7 @@ def test_odbc_impersonation(app_context: AppContext) -> None: assert url.query["DelegationUID"] == username -def test_jdbc_impersonation(app_context: AppContext) -> None: +def test_jdbc_impersonation() -> None: """ Test ``get_url_for_impersonation`` method when driver == jdbc. @@ -52,7 +51,7 @@ def test_jdbc_impersonation(app_context: AppContext) -> None: assert url.query["impersonation_target"] == username -def test_sadrill_impersonation(app_context: AppContext) -> None: +def test_sadrill_impersonation() -> None: """ Test ``get_url_for_impersonation`` method when driver == sadrill. @@ -68,7 +67,7 @@ def test_sadrill_impersonation(app_context: AppContext) -> None: assert url.query["impersonation_target"] == username -def test_invalid_impersonation(app_context: AppContext) -> None: +def test_invalid_impersonation() -> None: """ Test ``get_url_for_impersonation`` method when driver == foobar. diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index c2e8346c3c7ac..61c09b63c08ce 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from flask.ctx import AppContext from pytest_mock import MockFixture from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -28,7 +27,6 @@ class ProgrammingError(Exception): def test_validate_parameters_simple( mocker: MockFixture, - app_context: AppContext, ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, @@ -52,7 +50,6 @@ def test_validate_parameters_simple( def test_validate_parameters_catalog( mocker: MockFixture, - app_context: AppContext, ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, @@ -143,7 +140,6 @@ def test_validate_parameters_catalog( def test_validate_parameters_catalog_and_credentials( mocker: MockFixture, - app_context: AppContext, ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, diff --git a/tests/unit_tests/db_engine_specs/test_kusto.py b/tests/unit_tests/db_engine_specs/test_kusto.py index fca6ee5817de1..e556418a89282 100644 --- a/tests/unit_tests/db_engine_specs/test_kusto.py +++ b/tests/unit_tests/db_engine_specs/test_kusto.py @@ -18,7 +18,6 @@ from datetime import datetime import pytest -from flask.ctx import AppContext from tests.unit_tests.fixtures.common import dttm @@ -32,9 +31,7 @@ ("INSERT INTO tbl (foo) VALUES (1)", False), ], ) -def test_sql_is_readonly_query( - app_context: AppContext, sql: str, expected: bool -) -> None: +def test_sql_is_readonly_query(sql: str, expected: bool) -> None: """ Make sure that SQL dialect considers only SELECT statements as read-only """ @@ -56,7 +53,7 @@ def test_sql_is_readonly_query( (".show tables", False), ], ) -def test_kql_is_select_query(app_context: AppContext, kql: str, expected: bool) -> None: +def test_kql_is_select_query(kql: str, expected: bool) -> None: """ Make sure that KQL dialect considers only statements that do not start with "."
(dot) as SELECT statements @@ -83,9 +80,7 @@ def test_kql_is_select_query(app_context: AppContext, kql: str, expected: bool) (".set-or-append table foo <| bar", False), ], ) -def test_kql_is_readonly_query( - app_context: AppContext, kql: str, expected: bool -) -> None: +def test_kql_is_readonly_query(kql: str, expected: bool) -> None: """ Make sure that KQL dialect considers only SELECT statements as read-only """ @@ -99,7 +94,7 @@ def test_kql_is_readonly_query( assert expected == is_readonly -def test_kql_parse_sql(app_context: AppContext) -> None: +def test_kql_parse_sql() -> None: """ parse_sql method should always return a list with a single element which is the original query @@ -121,7 +116,7 @@ def test_kql_parse_sql(app_context: AppContext) -> None: ], ) def test_kql_convert_dttm( - app_context: AppContext, target_type: str, expected_dttm: str, dttm: datetime, @@ -145,7 +139,6 @@ def test_kql_convert_dttm( ], ) def test_sql_convert_dttm( - app_context: AppContext, target_type: str, expected_dttm: str, dttm: datetime, diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index ddade3bfdb38c..0ceee0adf381e 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -19,7 +19,6 @@ from textwrap import dedent import pytest -from flask.ctx import AppContext from sqlalchemy import column, table from sqlalchemy.dialects import mssql from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR @@ -44,7 +43,6 @@ ], ) def test_mssql_column_types( - app_context: AppContext, type_string: str, type_expected: TypeEngine, generic_type_expected: GenericDataType, @@ -61,7 +59,7 @@ def test_mssql_column_types( assert column_spec.generic_type == generic_type_expected -def test_where_clause_n_prefix(app_context: AppContext) -> None: +def test_where_clause_n_prefix() -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec dialect = mssql.dialect() @@ -95,7 +93,7 @@ def test_where_clause_n_prefix(app_context: AppContext) -> None: assert query == query_expected -def test_time_exp_mixd_case_col_1y(app_context: AppContext) -> None: +def test_time_exp_mixd_case_col_1y() -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec col = column("MixedCase") @@ -122,7 +120,7 @@ def test_time_exp_mixd_case_col_1y(app_context: AppContext) -> None: ], ) def test_convert_dttm( - app_context: AppContext, actual: str, expected: str, dttm: datetime, @@ -132,7 +129,7 @@ def test_convert_dttm( assert MssqlEngineSpec.convert_dttm(actual, dttm) == expected -def test_extract_error_message(app_context: AppContext) -> None: +def test_extract_error_message() -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec test_mssql_exception = Exception( @@ -158,7 +155,7 @@ def test_extract_error_message(app_context: AppContext) -> None: assert expected_message == error_message -def test_fetch_data(app_context: AppContext) -> None: +def test_fetch_data() -> None: from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.mssql import MssqlEngineSpec @@ -185,9 +182,7 @@ def test_fetch_data(app_context: AppContext) -> None: (NTEXT(collation="utf8_general_ci"), "NTEXT"), ], ) -def test_column_datatype_to_string( - app_context: AppContext, original: TypeEngine, expected: str -) -> None: +def test_column_datatype_to_string(original: TypeEngine, expected: str) -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec actual = 
MssqlEngineSpec.column_datatype_to_string(original, mssql.dialect()) @@ -239,9 +234,7 @@ def test_column_datatype_to_string( ), ], ) -def test_cte_query_parsing( - app_context: AppContext, original: TypeEngine, expected: str -) -> None: +def test_cte_query_parsing(original: TypeEngine, expected: str) -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec actual = MssqlEngineSpec.get_cte_query(original) @@ -270,16 +263,14 @@ def test_cte_query_parsing( ), ], ) -def test_top_query_parsing( - app_context: AppContext, original: TypeEngine, expected: str, top: int -) -> None: +def test_top_query_parsing(original: TypeEngine, expected: str, top: int) -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec actual = MssqlEngineSpec.apply_top_to_sql(original, top) assert actual == expected -def test_extract_errors(app_context: AppContext) -> None: +def test_extract_errors() -> None: """ Test that custom error messages are extracted correctly. """ diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 228427c9caa76..11ab176ff0b20 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -19,7 +19,6 @@ import pytest import pytz -from flask.ctx import AppContext @pytest.mark.parametrize( @@ -45,7 +44,6 @@ ], ) def test_convert_dttm( - app_context: AppContext, target_type: str, dttm: datetime, result: Optional[str], diff --git a/tests/unit_tests/db_engine_specs/test_snowflake.py b/tests/unit_tests/db_engine_specs/test_snowflake.py index 961b92f626147..2479e071f2847 100644 --- a/tests/unit_tests/db_engine_specs/test_snowflake.py +++ b/tests/unit_tests/db_engine_specs/test_snowflake.py @@ -19,7 +19,6 @@ from unittest import mock import pytest -from flask.ctx import AppContext from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from tests.unit_tests.fixtures.common import dttm @@ -33,15 +32,13 @@ ("TIMESTAMP", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), ], ) -def test_convert_dttm( - app_context: AppContext, actual: str, expected: str, dttm: datetime -) -> None: +def test_convert_dttm(actual: str, expected: str, dttm: datetime) -> None: from superset.db_engine_specs.snowflake import SnowflakeEngineSpec assert SnowflakeEngineSpec.convert_dttm(actual, dttm) == expected -def test_database_connection_test_mutator(app_context: AppContext) -> None: +def test_database_connection_test_mutator() -> None: from superset.db_engine_specs.snowflake import SnowflakeEngineSpec from superset.models.core import Database @@ -54,7 +51,7 @@ def test_database_connection_test_mutator(app_context: AppContext) -> None: } == engine_params -def test_extract_errors(app_context: AppContext) -> None: +def test_extract_errors() -> None: from superset.db_engine_specs.snowflake import SnowflakeEngineSpec msg = "Object dumbBrick does not exist or not authorized." 
diff --git a/tests/unit_tests/db_engine_specs/test_sqlite.py b/tests/unit_tests/db_engine_specs/test_sqlite.py index 1ce574abe39c4..576f4ef9e9f17 100644 --- a/tests/unit_tests/db_engine_specs/test_sqlite.py +++ b/tests/unit_tests/db_engine_specs/test_sqlite.py @@ -19,31 +19,30 @@ from unittest import mock import pytest -from flask.ctx import AppContext from sqlalchemy.engine import create_engine from tests.unit_tests.fixtures.common import dttm -def test_convert_dttm(app_context: AppContext, dttm: datetime) -> None: +def test_convert_dttm(dttm: datetime) -> None: from superset.db_engine_specs.sqlite import SqliteEngineSpec assert SqliteEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05.678900'" -def test_convert_dttm_lower(app_context: AppContext, dttm: datetime) -> None: +def test_convert_dttm_lower(dttm: datetime) -> None: from superset.db_engine_specs.sqlite import SqliteEngineSpec assert SqliteEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05.678900'" -def test_convert_dttm_invalid_type(app_context: AppContext, dttm: datetime) -> None: +def test_convert_dttm_invalid_type(dttm: datetime) -> None: from superset.db_engine_specs.sqlite import SqliteEngineSpec assert SqliteEngineSpec.convert_dttm("other", dttm) is None -def test_get_all_datasource_names_table(app_context: AppContext) -> None: +def test_get_all_datasource_names_table() -> None: from superset.db_engine_specs.sqlite import SqliteEngineSpec database = mock.MagicMock() @@ -62,7 +61,7 @@ def test_get_all_datasource_names_table(app_context: AppContext) -> None: ) -def test_get_all_datasource_names_view(app_context: AppContext) -> None: +def test_get_all_datasource_names_view() -> None: from superset.db_engine_specs.sqlite import SqliteEngineSpec database = mock.MagicMock() @@ -81,7 +80,7 @@ def test_get_all_datasource_names_view(app_context: AppContext) -> None: ) -def test_get_all_datasource_names_invalid_type(app_context: AppContext) -> None: +def test_get_all_datasource_names_invalid_type() -> None: from superset.db_engine_specs.sqlite import SqliteEngineSpec database = mock.MagicMock() @@ -132,9 +131,7 @@ def test_get_all_datasource_names_invalid_type(app_context: AppContext) -> None: ("2022-12-04T05:06:07.89Z", "P3M", "2022-10-01 00:00:00"), ], ) -def test_time_grain_expressions( - dttm: str, grain: str, expected: str, app_context: AppContext -) -> None: +def test_time_grain_expressions(dttm: str, grain: str, expected: str) -> None: from superset.db_engine_specs.sqlite import SqliteEngineSpec engine = create_engine("sqlite://") diff --git a/tests/unit_tests/db_engine_specs/test_teradata.py b/tests/unit_tests/db_engine_specs/test_teradata.py index 5887a9317c7f0..eab03e040d566 100644 --- a/tests/unit_tests/db_engine_specs/test_teradata.py +++ b/tests/unit_tests/db_engine_specs/test_teradata.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access import pytest -from flask.ctx import AppContext @pytest.mark.parametrize( @@ -32,7 +31,6 @@ ], ) def test_apply_top_to_sql_limit( - app_context: AppContext, limit: int, original: str, expected: str, diff --git a/tests/unit_tests/explore/utils_test.py b/tests/unit_tests/explore/utils_test.py index 06bde3c4e1468..b84000a7f0577 100644 --- a/tests/unit_tests/explore/utils_test.py +++ b/tests/unit_tests/explore/utils_test.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from pytest import raises from pytest_mock import MockFixture @@ -51,7 +50,7 @@ ) -def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None: +def test_unsaved_chart_no_dataset_id() -> None: from superset.explore.utils import check_access as check_chart_access with raises(DatasourceNotFoundValidationError): @@ -63,9 +62,7 @@ def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None: ) -def test_unsaved_chart_unknown_dataset_id( - mocker: MockFixture, app_context: AppContext -) -> None: +def test_unsaved_chart_unknown_dataset_id(mocker: MockFixture) -> None: from superset.explore.utils import check_access as check_chart_access with raises(DatasetNotFoundError): @@ -79,9 +76,7 @@ def test_unsaved_chart_unknown_dataset_id( ) -def test_unsaved_chart_unknown_query_id( - mocker: MockFixture, app_context: AppContext -) -> None: +def test_unsaved_chart_unknown_query_id(mocker: MockFixture) -> None: from superset.explore.utils import check_access as check_chart_access with raises(QueryNotFoundValidationError): @@ -95,9 +90,7 @@ def test_unsaved_chart_unknown_query_id( ) -def test_unsaved_chart_unauthorized_dataset( - mocker: MockFixture, app_context: AppContext -) -> None: +def test_unsaved_chart_unauthorized_dataset(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access @@ -113,9 +106,7 @@ def test_unsaved_chart_unauthorized_dataset( ) -def test_unsaved_chart_authorized_dataset( - mocker: MockFixture, app_context: AppContext -) -> None: +def test_unsaved_chart_authorized_dataset(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access @@ -130,9 +121,7 @@ def test_unsaved_chart_authorized_dataset( ) -def test_saved_chart_unknown_chart_id( - mocker: MockFixture, app_context: AppContext -) -> None: +def test_saved_chart_unknown_chart_id(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access @@ -149,9 +138,7 @@ def test_saved_chart_unknown_chart_id( ) -def test_saved_chart_unauthorized_dataset( - mocker: MockFixture, app_context: AppContext -) -> None: +def test_saved_chart_unauthorized_dataset(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access @@ -167,7 +154,7 @@ def test_saved_chart_unauthorized_dataset( ) -def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None: +def test_saved_chart_is_admin(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice @@ -185,7 +172,7 @@ def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> N ) -def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None: +def test_saved_chart_is_owner(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice @@ -204,7 +191,7 @@ def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> N ) -def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> 
None: +def test_saved_chart_has_access(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice @@ -224,7 +211,7 @@ def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> ) -def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None: +def test_saved_chart_no_access(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice @@ -245,7 +232,7 @@ def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> ) -def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> None: +def test_dataset_has_access(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_datasource_access @@ -263,7 +250,7 @@ def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> Non ) -def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None: +def test_query_has_access(mocker: MockFixture) -> None: from superset.explore.utils import check_datasource_access from superset.models.sql_lab import Query @@ -281,7 +268,7 @@ def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None: ) -def test_query_no_access(mocker: MockFixture, client, app_context: AppContext) -> None: +def test_query_no_access(mocker: MockFixture, client) -> None: from superset.connectors.sqla.models import SqlaTable from superset.explore.utils import check_datasource_access from superset.models.core import Database diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index 75c49f0977bf6..13b3ae9e9c948 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -34,7 +34,7 @@ def test_where_in() -> None: assert where_in(["O'Malley's"]) == "('O''Malley''s')" -def test_dataset_macro(mocker: MockFixture, app_context: None) -> None: +def test_dataset_macro(mocker: MockFixture) -> None: """ Test the ``dataset_macro`` macro. """ diff --git a/tests/unit_tests/notifications/email_tests.py b/tests/unit_tests/notifications/email_tests.py index 9bc8b8090f3da..f9827580c6acd 100644 --- a/tests/unit_tests/notifications/email_tests.py +++ b/tests/unit_tests/notifications/email_tests.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. import pandas as pd -from flask.ctx import AppContext -def test_render_description_with_html(app_context: AppContext) -> None: +def test_render_description_with_html() -> None: # `superset.models.helpers`, a dependency of following imports, # requires app context from superset.reports.models import ReportRecipients, ReportRecipientType diff --git a/tests/unit_tests/result_set_test.py b/tests/unit_tests/result_set_test.py index 80d7ced61ecd0..48b9576a4ca79 100644 --- a/tests/unit_tests/result_set_test.py +++ b/tests/unit_tests/result_set_test.py @@ -18,7 +18,7 @@ # pylint: disable=import-outside-toplevel, unused-argument -def test_column_names_as_bytes(app_context: None) -> None: +def test_column_names_as_bytes() -> None: """ Test that we can handle column names as bytes. 
""" diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index c5bfa4a16d600..29f45eab682a0 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -63,7 +63,6 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: def test_execute_sql_statement_with_rls( mocker: MockerFixture, - app_context: None, ) -> None: """ Test for `execute_sql_statement` when an RLS rule is in place. @@ -118,7 +117,6 @@ def test_execute_sql_statement_with_rls( def test_sql_lab_insert_rls( mocker: MockerFixture, session: Session, - app_context: None, ) -> None: """ Integration test for `insert_rls`. diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 98eceebd47136..2f168d205cdaf 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1445,7 +1445,7 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None: assert str(condition) == expected -def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None: +def test_get_rls_for_table(mocker: MockerFixture) -> None: """ Tests for ``get_rls_for_table``. """ diff --git a/tests/unit_tests/tables/test_models.py b/tests/unit_tests/tables/test_models.py index 56ca5ba82fbfc..7705dba6aa09d 100644 --- a/tests/unit_tests/tables/test_models.py +++ b/tests/unit_tests/tables/test_models.py @@ -20,7 +20,7 @@ from sqlalchemy.orm.session import Session -def test_table_model(app_context: None, session: Session) -> None: +def test_table_model(session: Session) -> None: """ Test basic attributes of a ``Table``. """ diff --git a/tests/unit_tests/tasks/test_cron_util.py b/tests/unit_tests/tasks/test_cron_util.py index 9042ccad58534..d0f9ae21705e2 100644 --- a/tests/unit_tests/tasks/test_cron_util.py +++ b/tests/unit_tests/tasks/test_cron_util.py @@ -20,7 +20,6 @@ import pytest import pytz from dateutil import parser -from flask.ctx import AppContext from freezegun import freeze_time from freezegun.api import FakeDatetime # type: ignore @@ -50,7 +49,7 @@ ], ) def test_cron_schedule_window_los_angeles( - app_context: AppContext, current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: List[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Los_Angeles" @@ -87,7 +86,7 @@ def test_cron_schedule_window_los_angeles( ], ) def test_cron_schedule_window_invalid_timezone( - app_context: AppContext, current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: List[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "invalid timezone" @@ -125,7 +124,7 @@ def test_cron_schedule_window_invalid_timezone( ], ) def test_cron_schedule_window_new_york( - app_context: AppContext, current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: List[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/New_York" @@ -162,7 +161,7 @@ def test_cron_schedule_window_new_york( ], ) def test_cron_schedule_window_chicago( - app_context: AppContext, current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: List[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Chicago" @@ -199,7 +198,7 @@ def test_cron_schedule_window_chicago( ], ) def test_cron_schedule_window_chicago_daylight( - app_context: AppContext, current_dttm: str, cron: str, expected: 
List[FakeDatetime] + current_dttm: str, cron: str, expected: List[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Chicago" diff --git a/tests/unit_tests/test_jinja_context.py b/tests/unit_tests/test_jinja_context.py index 7c301c88ea3e5..8704b1d65c211 100644 --- a/tests/unit_tests/test_jinja_context.py +++ b/tests/unit_tests/test_jinja_context.py @@ -18,7 +18,6 @@ from typing import Any import pytest -from flask.ctx import AppContext from sqlalchemy.dialects.postgresql import dialect from superset import app @@ -26,30 +25,30 @@ from superset.jinja_context import ExtraCache, safe_proxy -def test_filter_values_default(app_context: AppContext) -> None: +def test_filter_values_default() -> None: cache = ExtraCache() assert cache.filter_values("name", "foo") == ["foo"] assert cache.removed_filters == [] -def test_filter_values_remove_not_present(app_context: AppContext) -> None: +def test_filter_values_remove_not_present() -> None: cache = ExtraCache() assert cache.filter_values("name", remove_filter=True) == [] assert cache.removed_filters == [] -def test_get_filters_remove_not_present(app_context: AppContext) -> None: +def test_get_filters_remove_not_present() -> None: cache = ExtraCache() assert cache.get_filters("name", remove_filter=True) == [] assert cache.removed_filters == [] -def test_filter_values_no_default(app_context: AppContext) -> None: +def test_filter_values_no_default() -> None: cache = ExtraCache() assert cache.filter_values("name") == [] -def test_filter_values_adhoc_filters(app_context: AppContext) -> None: +def test_filter_values_adhoc_filters() -> None: with app.test_request_context( data={ "form_data": json.dumps( @@ -93,7 +92,7 @@ def test_filter_values_adhoc_filters(app_context: AppContext) -> None: assert cache.applied_filters == ["name"] -def test_get_filters_adhoc_filters(app_context: AppContext) -> None: +def test_get_filters_adhoc_filters() -> None: with app.test_request_context( data={ "form_data": json.dumps( @@ -167,7 +166,7 @@ def test_get_filters_adhoc_filters(app_context: AppContext) -> None: assert cache.applied_filters == ["name"] -def test_filter_values_extra_filters(app_context: AppContext) -> None: +def test_filter_values_extra_filters() -> None: with app.test_request_context( data={ "form_data": json.dumps( @@ -180,25 +179,25 @@ def test_filter_values_extra_filters(app_context: AppContext) -> None: assert cache.applied_filters == ["name"] -def test_url_param_default(app_context: AppContext) -> None: +def test_url_param_default() -> None: with app.test_request_context(): cache = ExtraCache() assert cache.url_param("foo", "bar") == "bar" -def test_url_param_no_default(app_context: AppContext) -> None: +def test_url_param_no_default() -> None: with app.test_request_context(): cache = ExtraCache() assert cache.url_param("foo") is None -def test_url_param_query(app_context: AppContext) -> None: +def test_url_param_query() -> None: with app.test_request_context(query_string={"foo": "bar"}): cache = ExtraCache() assert cache.url_param("foo") == "bar" -def test_url_param_form_data(app_context: AppContext) -> None: +def test_url_param_form_data() -> None: with app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "bar"}})} ): @@ -206,7 +205,7 @@ def test_url_param_form_data(app_context: AppContext) -> None: assert cache.url_param("foo") == "bar" -def test_url_param_escaped_form_data(app_context: AppContext) -> None: +def test_url_param_escaped_form_data() -> None: with 
app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): @@ -214,7 +213,7 @@ def test_url_param_escaped_form_data(app_context: AppContext) -> None: assert cache.url_param("foo") == "O''Brien" -def test_url_param_escaped_default_form_data(app_context: AppContext) -> None: +def test_url_param_escaped_default_form_data() -> None: with app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): @@ -222,7 +221,7 @@ def test_url_param_escaped_default_form_data(app_context: AppContext) -> None: assert cache.url_param("bar", "O'Malley") == "O''Malley" -def test_url_param_unescaped_form_data(app_context: AppContext) -> None: +def test_url_param_unescaped_form_data() -> None: with app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): @@ -230,7 +229,7 @@ def test_url_param_unescaped_form_data(app_context: AppContext) -> None: assert cache.url_param("foo", escape_result=False) == "O'Brien" -def test_url_param_unescaped_default_form_data(app_context: AppContext) -> None: +def test_url_param_unescaped_default_form_data() -> None: with app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): @@ -238,21 +237,21 @@ def test_url_param_unescaped_default_form_data(app_context: AppContext) -> None: assert cache.url_param("bar", "O'Malley", escape_result=False) == "O'Malley" -def test_safe_proxy_primitive(app_context: AppContext) -> None: +def test_safe_proxy_primitive() -> None: def func(input_: Any) -> Any: return input_ assert safe_proxy(func, "foo") == "foo" -def test_safe_proxy_dict(app_context: AppContext) -> None: +def test_safe_proxy_dict() -> None: def func(input_: Any) -> Any: return input_ assert safe_proxy(func, {"foo": "bar"}) == {"foo": "bar"} -def test_safe_proxy_lambda(app_context: AppContext) -> None: +def test_safe_proxy_lambda() -> None: def func(input_: Any) -> Any: return input_ @@ -260,7 +259,7 @@ def func(input_: Any) -> Any: safe_proxy(func, lambda: "bar") -def test_safe_proxy_nested_lambda(app_context: AppContext) -> None: +def test_safe_proxy_nested_lambda() -> None: def func(input_: Any) -> Any: return input_ diff --git a/tests/unit_tests/utils/cache_test.py b/tests/unit_tests/utils/cache_test.py index 7c1354aa3cb39..53650e1d20324 100644 --- a/tests/unit_tests/utils/cache_test.py +++ b/tests/unit_tests/utils/cache_test.py @@ -21,7 +21,7 @@ from pytest_mock import MockerFixture -def test_memoized_func(app_context: None, mocker: MockerFixture) -> None: +def test_memoized_func(mocker: MockerFixture) -> None: """ Test the ``memoized_func`` decorator. """