diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index fb02ec8ca8201..21c0f4e42f823 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -45,7 +45,7 @@ def run(self) -> str: value = { "chartId": self.chart_id, "datasourceId": datasource_id, - "datasourceType": datasource_type, + "datasourceType": datasource_type.value, "datasource": self.datasource, "state": self.state, } diff --git a/superset/migrations/versions/2023-05-01_12-03_9c2a5681ddfd_convert_key_value_entries_to_json.py b/superset/migrations/versions/2023-05-01_12-03_9c2a5681ddfd_convert_key_value_entries_to_json.py index 57931ff821e7a..e840919de4fc6 100644 --- a/superset/migrations/versions/2023-05-01_12-03_9c2a5681ddfd_convert_key_value_entries_to_json.py +++ b/superset/migrations/versions/2023-05-01_12-03_9c2a5681ddfd_convert_key_value_entries_to_json.py @@ -45,7 +45,10 @@ class RestrictedUnpickler(pickle.Unpickler): def find_class(self, module, name): - raise pickle.UnpicklingError(f"Unpickling of {module}.{name} is forbidden") + if not (module == "superset.utils.core" and name == "DatasourceType"): + raise pickle.UnpicklingError(f"Unpickling of {module}.{name} is forbidden") + + return super().find_class(module, name) class KeyValueEntry(Base): @@ -58,14 +61,28 @@ class KeyValueEntry(Base): def upgrade(): bind = op.get_bind() session: Session = db.Session(bind=bind) + truncated_count = 0 for entry in paginated_update( session.query(KeyValueEntry).filter( KeyValueEntry.resource.in_(RESOURCES_TO_MIGRATE) ) ): - value = RestrictedUnpickler(io.BytesIO(entry.value)).load() or {} + try: + value = RestrictedUnpickler(io.BytesIO(entry.value)).load() or {} + except pickle.UnpicklingError as ex: + if str(ex) == "pickle data was truncated": + # make truncated values that were created prior to #20385 an empty + # dict so that downgrading will work properly. + truncated_count += 1 + value = {} + else: + raise + entry.value = bytes(json.dumps(value), encoding="utf-8") + if truncated_count: + print(f"Replaced {truncated_count} corrupted values with an empty value") + def downgrade(): bind = op.get_bind()