286 lines
10 KiB
Python
286 lines
10 KiB
Python
import json
|
|
from realms import db
|
|
from sqlalchemy import not_
|
|
from datetime import datetime
|
|
|
|
|
|
class Model(db.Model):
|
|
"""Base SQLAlchemy Model for automatic serialization and
|
|
deserialization of columns and nested relationships.
|
|
|
|
Source: https://gist.github.com/alanhamlett/6604662
|
|
|
|
Usage::
|
|
|
|
>>> class User(Model):
|
|
>>> id = db.Column(db.Integer(), primary_key=True)
|
|
>>> email = db.Column(db.String(), index=True)
|
|
>>> name = db.Column(db.String())
|
|
>>> password = db.Column(db.String())
|
|
>>> posts = db.relationship('Post', backref='user', lazy='dynamic')
|
|
>>> ...
|
|
>>> default_fields = ['email', 'name']
|
|
>>> hidden_fields = ['password']
|
|
>>> readonly_fields = ['email', 'password']
|
|
>>>
|
|
>>> class Post(Model):
|
|
>>> id = db.Column(db.Integer(), primary_key=True)
|
|
>>> user_id = db.Column(db.String(), db.ForeignKey('user.id'), nullable=False)
|
|
>>> title = db.Column(db.String())
|
|
>>> ...
|
|
>>> default_fields = ['title']
|
|
>>> readonly_fields = ['user_id']
|
|
>>>
|
|
>>> model = User(email='john@localhost')
|
|
>>> db.session.add(model)
|
|
>>> db.session.commit()
|
|
>>>
|
|
>>> # update name and create a new post
|
|
>>> validated_input = {'name': 'John', 'posts': [{'title':'My First Post'}]}
|
|
>>> model.set_columns(**validated_input)
|
|
>>> db.session.commit()
|
|
>>>
|
|
>>> print(model.to_dict(show=['password', 'posts']))
|
|
>>> {u'email': u'john@localhost', u'posts': [{u'id': 1, u'title': u'My First Post'}], u'name': u'John', u'id': 1}
|
|
"""
|
|
__abstract__ = True
|
|
# Stores changes made to this model's attributes. Can be retrieved
|
|
# with model.changes
|
|
_changes = {}
|
|
|
|
def __init__(self, **kwargs):
|
|
kwargs['_force'] = True
|
|
self._set_columns(**kwargs)
|
|
|
|
def _set_columns(self, **kwargs):
|
|
force = kwargs.get('_force')
|
|
|
|
readonly = []
|
|
if hasattr(self, 'readonly_fields'):
|
|
readonly = self.readonly_fields
|
|
if hasattr(self, 'hidden_fields'):
|
|
readonly += self.hidden_fields
|
|
|
|
readonly += [
|
|
'id',
|
|
'created',
|
|
'updated',
|
|
'modified',
|
|
'created_at',
|
|
'updated_at',
|
|
'modified_at',
|
|
]
|
|
|
|
changes = {}
|
|
|
|
columns = self.__table__.columns.keys()
|
|
relationships = self.__mapper__.relationships.keys()
|
|
|
|
for key in columns:
|
|
allowed = True if force or key not in readonly else False
|
|
exists = True if key in kwargs else False
|
|
if allowed and exists:
|
|
val = getattr(self, key)
|
|
if val != kwargs[key]:
|
|
changes[key] = {'old': val, 'new': kwargs[key]}
|
|
setattr(self, key, kwargs[key])
|
|
|
|
for rel in relationships:
|
|
allowed = True if force or rel not in readonly else False
|
|
exists = True if rel in kwargs else False
|
|
if allowed and exists:
|
|
is_list = self.__mapper__.relationships[rel].uselist
|
|
if is_list:
|
|
valid_ids = []
|
|
query = getattr(self, rel)
|
|
cls = self.__mapper__.relationships[rel].argument()
|
|
for item in kwargs[rel]:
|
|
if 'id' in item and query.filter_by(id=item['id']).limit(1).count() == 1:
|
|
obj = cls.query.filter_by(id=item['id']).first()
|
|
col_changes = obj.set_columns(**item)
|
|
if col_changes:
|
|
col_changes['id'] = str(item['id'])
|
|
if rel in changes:
|
|
changes[rel].append(col_changes)
|
|
else:
|
|
changes.update({rel: [col_changes]})
|
|
valid_ids.append(str(item['id']))
|
|
else:
|
|
col = cls()
|
|
col_changes = col.set_columns(**item)
|
|
query.append(col)
|
|
db.session.flush()
|
|
if col_changes:
|
|
col_changes['id'] = str(col.id)
|
|
if rel in changes:
|
|
changes[rel].append(col_changes)
|
|
else:
|
|
changes.update({rel: [col_changes]})
|
|
valid_ids.append(str(col.id))
|
|
|
|
# delete related rows that were not in kwargs[rel]
|
|
for item in query.filter(not_(cls.id.in_(valid_ids))).all():
|
|
col_changes = {
|
|
'id': str(item.id),
|
|
'deleted': True,
|
|
}
|
|
if rel in changes:
|
|
changes[rel].append(col_changes)
|
|
else:
|
|
changes.update({rel: [col_changes]})
|
|
db.session.delete(item)
|
|
|
|
else:
|
|
val = getattr(self, rel)
|
|
if self.__mapper__.relationships[rel].query_class is not None:
|
|
if val is not None:
|
|
col_changes = val.set_columns(**kwargs[rel])
|
|
if col_changes:
|
|
changes.update({rel: col_changes})
|
|
else:
|
|
if val != kwargs[rel]:
|
|
setattr(self, rel, kwargs[rel])
|
|
changes[rel] = {'old': val, 'new': kwargs[rel]}
|
|
|
|
return changes
|
|
|
|
def set_columns(self, **kwargs):
|
|
self._changes = self._set_columns(**kwargs)
|
|
if 'modified' in self.__table__.columns:
|
|
self.modified = datetime.utcnow()
|
|
if 'updated' in self.__table__.columns:
|
|
self.updated = datetime.utcnow()
|
|
if 'modified_at' in self.__table__.columns:
|
|
self.modified_at = datetime.utcnow()
|
|
if 'updated_at' in self.__table__.columns:
|
|
self.updated_at = datetime.utcnow()
|
|
return self._changes
|
|
|
|
def __repr__(self):
|
|
if 'id' in self.__table__.columns.keys():
|
|
return '%s(%s)' % (self.__class__.__name__, self.id)
|
|
data = {}
|
|
for key in self.__table__.columns.keys():
|
|
val = getattr(self, key)
|
|
if type(val) is datetime:
|
|
val = val.strftime('%Y-%m-%dT%H:%M:%SZ')
|
|
data[key] = val
|
|
return json.dumps(data, use_decimal=True)
|
|
|
|
@property
|
|
def changes(self):
|
|
return self._changes
|
|
|
|
def reset_changes(self):
|
|
self._changes = {}
|
|
|
|
def to_dict(self, show=None, hide=None, path=None, show_all=None):
|
|
""" Return a dictionary representation of this model.
|
|
"""
|
|
|
|
if not show:
|
|
show = []
|
|
if not hide:
|
|
hide = []
|
|
hidden = []
|
|
if hasattr(self, 'hidden_fields'):
|
|
hidden = self.hidden_fields
|
|
default = []
|
|
if hasattr(self, 'default_fields'):
|
|
default = self.default_fields
|
|
|
|
ret_data = {}
|
|
|
|
if not path:
|
|
path = self.__tablename__.lower()
|
|
def prepend_path(item):
|
|
item = item.lower()
|
|
if item.split('.', 1)[0] == path:
|
|
return item
|
|
if len(item) == 0:
|
|
return item
|
|
if item[0] != '.':
|
|
item = '.%s' % item
|
|
item = '%s%s' % (path, item)
|
|
return item
|
|
show[:] = [prepend_path(x) for x in show]
|
|
hide[:] = [prepend_path(x) for x in hide]
|
|
|
|
columns = self.__table__.columns.keys()
|
|
relationships = self.__mapper__.relationships.keys()
|
|
properties = dir(self)
|
|
|
|
for key in columns:
|
|
check = '%s.%s' % (path, key)
|
|
if check in hide or key in hidden:
|
|
continue
|
|
if show_all or key is 'id' or check in show or key in default:
|
|
ret_data[key] = getattr(self, key)
|
|
|
|
for key in relationships:
|
|
check = '%s.%s' % (path, key)
|
|
if check in hide or key in hidden:
|
|
continue
|
|
if show_all or check in show or key in default:
|
|
hide.append(check)
|
|
is_list = self.__mapper__.relationships[key].uselist
|
|
if is_list:
|
|
ret_data[key] = []
|
|
for item in getattr(self, key):
|
|
ret_data[key].append(item.to_dict(
|
|
show=show,
|
|
hide=hide,
|
|
path=('%s.%s' % (path, key.lower())),
|
|
show_all=show_all,
|
|
))
|
|
else:
|
|
if self.__mapper__.relationships[key].query_class is not None:
|
|
ret_data[key] = getattr(self, key).to_dict(
|
|
show=show,
|
|
hide=hide,
|
|
path=('%s.%s' % (path, key.lower())),
|
|
show_all=show_all,
|
|
)
|
|
else:
|
|
ret_data[key] = getattr(self, key)
|
|
|
|
for key in list(set(properties) - set(columns) - set(relationships)):
|
|
if key.startswith('_'):
|
|
continue
|
|
check = '%s.%s' % (path, key)
|
|
if check in hide or key in hidden:
|
|
continue
|
|
if show_all or check in show or key in default:
|
|
val = getattr(self, key)
|
|
try:
|
|
ret_data[key] = json.loads(json.dumps(val))
|
|
except:
|
|
pass
|
|
|
|
return ret_data
|
|
|
|
@classmethod
|
|
def insert_or_update(cls, cond, data):
|
|
obj = cls.query.filter_by(**cond).first()
|
|
if obj:
|
|
obj.set_columns(**data)
|
|
else:
|
|
data.update(cond)
|
|
obj = cls(**data)
|
|
db.session.add(obj)
|
|
db.session.commit()
|
|
|
|
def save(self):
|
|
if self not in db.session:
|
|
db.session.merge(self)
|
|
db.session.commit()
|
|
|
|
def delete(self):
|
|
if self not in db.session:
|
|
db.session.merge(self)
|
|
db.session.delete(self)
|
|
db.session.commit()
|
|
|
|
@classmethod
|
|
def get_by_id(cls, id_):
|
|
return cls.query.filter_by(id=id_).first() |