Using uuids to merge objects
mistercrunch committed Jul 8, 2019
1 parent 90a2588 commit 824aef5
Showing 2 changed files with 29 additions and 74 deletions.
5 changes: 5 additions & 0 deletions superset/connectors/connector_registry.py
@@ -40,6 +40,11 @@ def get_datasource(cls, datasource_type, datasource_id, session):
             .first()
         )
 
+    @classmethod
+    def get_datasource_by_uuid(cls, session, source_type, uuid):
+        source_class = ConnectorRegistry.sources[source_type]
+        return session.query(source_class).filter_by(uuid=uuid).one()
+
     @classmethod
     def get_all_datasources(cls, session):
         datasources = []
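A short usage sketch of the new registry method (not part of the commit): the session setup and the uuid value below are hypothetical, and ConnectorRegistry.sources must already be populated, e.g. with a "table" key for the SQLA connector. Because the method uses .one(), a missing or duplicated uuid raises rather than returning None.

# Hypothetical usage of get_datasource_by_uuid; "table" is assumed to be
# a registered key in ConnectorRegistry.sources.
from superset import db
from superset.connectors.connector_registry import ConnectorRegistry

session = db.session()
datasource = ConnectorRegistry.get_datasource_by_uuid(
    session=session,
    source_type="table",
    uuid="3f1f7d44-5b6c-4f59-9a2e-2c0f5a7c9b10",  # hypothetical uuid value
)
# .one() raises sqlalchemy.orm.exc.NoResultFound when nothing matches,
# and MultipleResultsFound if the uuid is somehow duplicated.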
98 changes: 24 additions & 74 deletions superset/models/helpers.py
@@ -26,9 +26,8 @@
 from flask_appbuilder.models.mixins import AuditMixin
 import humanize
 import sqlalchemy as sa
-from sqlalchemy import and_, or_, UniqueConstraint
+from sqlalchemy import UniqueConstraint
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm.exc import MultipleResultsFound
 import yaml
 
 from superset.utils.core import QueryStatus
@@ -136,106 +135,57 @@ def formatter(c):
         return schema
 
     @classmethod
-    def import_from_dict(cls, session, dict_rep, parent=None, recursive=True, sync=[]):
+    def import_from_dict(
+        cls, session, dict_rep, parent=None, recursive=True, sync=None
+    ):
         """Import obj from a dictionary"""
+        sync = sync or []
         parent_refs = cls._parent_foreign_key_mappings()
         export_fields = set(cls.export_fields) | set(parent_refs.keys())
         new_children = {
             c: dict_rep.get(c) for c in cls.export_children if c in dict_rep
         }
-        unique_constrains = cls._unique_constrains()
 
-        filters = []  # Using these filters to check if obj already exists
-
-        # Remove fields that should not get imported
         for k in list(dict_rep):
-            if k not in export_fields:
+            # Remove fields that should not get imported
+            if k not in export_fields and k != "uuid":
                 del dict_rep[k]
             # Serialize json fields that are stored as text in the db
            if k in cls.export_fields_json:
                 dict_rep[k] = json.dumps(dict_rep[k])
 
-        if not parent:
-            if cls.export_parent:
-                for p in parent_refs.keys():
-                    if p not in dict_rep:
-                        raise RuntimeError(
-                            "{0}: Missing field {1}".format(cls.__name__, p)
-                        )
-        else:
+        if parent:
             # Set foreign keys to parent obj
             for k, v in parent_refs.items():
                 dict_rep[k] = getattr(parent, v)
-
-            # Add filter for parent obj
-            filters.extend([getattr(cls, k) == dict_rep.get(k) for k in parent_refs.keys()])
-
-        # Add filter for unique constraints
-        ucs = [
-            and_(
-                *[
-                    getattr(cls, k) == dict_rep.get(k)
-                    for k in cs
-                    if dict_rep.get(k) is not None
-                ]
-            )
-            for cs in unique_constrains
-        ]
-        filters.append(or_(*ucs))
+        elif cls.export_parent:
+            for p in parent_refs.keys():
+                if p not in dict_rep:
+                    raise RuntimeError(f"{cls.__name__}: Missing field {p}")
 
-        # Check if object already exists in DB, break if more than one is found
-        try:
-            obj_query = session.query(cls).filter(and_(*filters))
-            obj = obj_query.one_or_none()
-        except MultipleResultsFound as e:
-            logging.error(
-                "Error importing %s \n %s \n %s",
-                cls.__name__,
-                str(obj_query),
-                yaml.safe_dump(dict_rep),
-            )
-            raise e
+        obj = session.query(cls).filter_by(uuid=dict_rep.get("uuid")).one_or_none()
 
         if not obj:
-            is_new_obj = True
             # Create new DB object
-            logging.info("Creating new %s %s", cls.__tablename__, str(obj))
             obj = cls(**dict_rep)
+            logging.info("Importing new %s %s", obj.__tablename__, str(obj))
             if cls.export_parent and parent:
                 setattr(obj, cls.export_parent, parent)
             session.add(obj)
         else:
-            is_new_obj = False
             logging.info("Updating %s %s", obj.__tablename__, str(obj))
             # Update columns
             for k, v in dict_rep.items():
                 setattr(obj, k, v)
 
         # Recursively create children
         if recursive:
-            for c in cls.export_children:
-                child_class = cls.__mapper__.relationships[c].argument.class_
-                added = []
-                for c_obj in new_children.get(c, []):
-                    added.append(
-                        child_class.import_from_dict(
-                            session=session, dict_rep=c_obj, parent=obj, sync=sync
-                        )
-                    )
-                # If children should get synced, delete the ones that did not
-                # get updated.
-                if c in sync and not is_new_obj:
-                    back_refs = child_class._parent_foreign_key_mappings()
-                    delete_filters = [
-                        getattr(child_class, k) == getattr(obj, back_refs.get(k))
-                        for k in back_refs.keys()
-                    ]
-                    to_delete = set(
-                        session.query(child_class).filter(and_(*delete_filters))
-                    ).difference(set(added))
-                    for o in to_delete:
-                        logging.info("Deleting %s %s", c, str(obj))
-                        session.delete(o)
-
+            import_args = dict(session=session, parent=obj, sync=sync)
+            for rel in cls.export_children:
+                child_class = cls.__mapper__.relationships[rel].argument.class_
+                children = dict_rep.get(rel, [])
+                children_orm = [
+                    child_class.import_from_dict(child, **import_args)
+                    for child in children
+                ]
+                setattr(obj, rel, children_orm)
         return obj
 
     def export_to_dict(
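Taken together, import_from_dict is now an upsert keyed on uuid: the object is looked up with filter_by(uuid=...), created when absent, and updated in place otherwise, while re-imported children are assigned directly to the relationship via setattr(obj, rel, children_orm) instead of being diffed and deleted by hand. Below is a self-contained sketch of that merge behaviour against an in-memory SQLite database (not part of the commit): the Widget model, engine, and session are illustrative only, and the export_* attributes are set explicitly so the sketch does not lean on mixin defaults.

# Toy demonstration of merge-by-uuid; only ImportMixin and the
# import_from_dict contract come from this codebase.
import uuid as uuid_lib

import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from superset.models.helpers import ImportMixin

Base = declarative_base()


class Widget(Base, ImportMixin):  # hypothetical model, for illustration only
    __tablename__ = "widgets"
    id = sa.Column(sa.Integer, primary_key=True)
    uuid = sa.Column(sa.String(36), unique=True)
    name = sa.Column(sa.String(250))

    # Set explicitly rather than relying on ImportMixin class defaults
    export_fields = ["name"]
    export_fields_json = []
    export_children = []
    export_parent = None


engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

# import_from_dict mutates its input dict, hence the dict() copies
rep = {"uuid": str(uuid_lib.uuid4()), "name": "v1"}
Widget.import_from_dict(session, dict(rep))  # no row with this uuid -> insert
session.commit()

rep["name"] = "v2"
Widget.import_from_dict(session, dict(rep))  # same uuid found -> update in place
session.commit()

assert session.query(Widget).count() == 1  # merged, not duplicated
assert session.query(Widget).one().name == "v2"

With the explicit delete loop gone, removing stale children on re-import is delegated to the relationship's cascade configuration (e.g. delete-orphan) rather than explicit session.delete calls, which is worth keeping in mind for models whose children need to stay in sync.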
