Git-based wiki inspired by Gollum
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

300 lines
11KB

  1. from __future__ import absolute_import
  2. import json
  3. from datetime import datetime
  4. from sqlalchemy import not_, and_
  5. from realms import db
  6. class Model(db.Model):
  7. """Base SQLAlchemy Model for automatic serialization and
  8. deserialization of columns and nested relationships.
  9. Source: https://gist.github.com/alanhamlett/6604662
  10. Usage::
  11. >>> class User(Model):
  12. >>> id = db.Column(db.Integer(), primary_key=True)
  13. >>> email = db.Column(db.String(), index=True)
  14. >>> name = db.Column(db.String())
  15. >>> password = db.Column(db.String())
  16. >>> posts = db.relationship('Post', backref='user', lazy='dynamic')
  17. >>> ...
  18. >>> default_fields = ['email', 'name']
  19. >>> hidden_fields = ['password']
  20. >>> readonly_fields = ['email', 'password']
  21. >>>
  22. >>> class Post(Model):
  23. >>> id = db.Column(db.Integer(), primary_key=True)
  24. >>> user_id = db.Column(db.String(), db.ForeignKey('user.id'), nullable=False)
  25. >>> title = db.Column(db.String())
  26. >>> ...
  27. >>> default_fields = ['title']
  28. >>> readonly_fields = ['user_id']
  29. >>>
  30. >>> model = User(email='john@localhost')
  31. >>> db.session.add(model)
  32. >>> db.session.commit()
  33. >>>
  34. >>> # update name and create a new post
  35. >>> validated_input = {'name': 'John', 'posts': [{'title':'My First Post'}]}
  36. >>> model.set_columns(**validated_input)
  37. >>> db.session.commit()
  38. >>>
  39. >>> print(model.to_dict(show=['password', 'posts']))
  40. >>> {u'email': u'john@localhost', u'posts': [{u'id': 1, u'title': u'My First Post'}], u'name': u'John', u'id': 1}
  41. """
  42. __abstract__ = True
  43. # Stores changes made to this model's attributes. Can be retrieved
  44. # with model.changes
  45. _changes = {}
  46. def __init__(self, **kwargs):
  47. kwargs['_force'] = True
  48. self._set_columns(**kwargs)
  49. def filter_by(self, **kwargs):
  50. clauses = [key == value
  51. for key, value in kwargs.items()]
  52. return self.filter(and_(*clauses))
  53. def _set_columns(self, **kwargs):
  54. force = kwargs.get('_force')
  55. readonly = []
  56. if hasattr(self, 'readonly_fields'):
  57. readonly = self.readonly_fields
  58. if hasattr(self, 'hidden_fields'):
  59. readonly += self.hidden_fields
  60. readonly += [
  61. 'id',
  62. 'created',
  63. 'updated',
  64. 'modified',
  65. 'created_at',
  66. 'updated_at',
  67. 'modified_at',
  68. ]
  69. changes = {}
  70. columns = self.__table__.columns.keys()
  71. relationships = self.__mapper__.relationships.keys()
  72. for key in columns:
  73. allowed = True if force or key not in readonly else False
  74. exists = True if key in kwargs else False
  75. if allowed and exists:
  76. val = getattr(self, key)
  77. if val != kwargs[key]:
  78. changes[key] = {'old': val, 'new': kwargs[key]}
  79. setattr(self, key, kwargs[key])
  80. for rel in relationships:
  81. allowed = True if force or rel not in readonly else False
  82. exists = True if rel in kwargs else False
  83. if allowed and exists:
  84. is_list = self.__mapper__.relationships[rel].uselist
  85. if is_list:
  86. valid_ids = []
  87. query = getattr(self, rel)
  88. cls = self.__mapper__.relationships[rel].argument()
  89. for item in kwargs[rel]:
  90. if 'id' in item and query.filter_by(id=item['id']).limit(1).count() == 1:
  91. obj = cls.query.filter_by(id=item['id']).first()
  92. col_changes = obj.set_columns(**item)
  93. if col_changes:
  94. col_changes['id'] = str(item['id'])
  95. if rel in changes:
  96. changes[rel].append(col_changes)
  97. else:
  98. changes.update({rel: [col_changes]})
  99. valid_ids.append(str(item['id']))
  100. else:
  101. col = cls()
  102. col_changes = col.set_columns(**item)
  103. query.append(col)
  104. db.session.flush()
  105. if col_changes:
  106. col_changes['id'] = str(col.id)
  107. if rel in changes:
  108. changes[rel].append(col_changes)
  109. else:
  110. changes.update({rel: [col_changes]})
  111. valid_ids.append(str(col.id))
  112. # delete related rows that were not in kwargs[rel]
  113. for item in query.filter(not_(cls.id.in_(valid_ids))).all():
  114. col_changes = {
  115. 'id': str(item.id),
  116. 'deleted': True,
  117. }
  118. if rel in changes:
  119. changes[rel].append(col_changes)
  120. else:
  121. changes.update({rel: [col_changes]})
  122. db.session.delete(item)
  123. else:
  124. val = getattr(self, rel)
  125. if self.__mapper__.relationships[rel].query_class is not None:
  126. if val is not None:
  127. col_changes = val.set_columns(**kwargs[rel])
  128. if col_changes:
  129. changes.update({rel: col_changes})
  130. else:
  131. if val != kwargs[rel]:
  132. setattr(self, rel, kwargs[rel])
  133. changes[rel] = {'old': val, 'new': kwargs[rel]}
  134. return changes
  135. def set_columns(self, **kwargs):
  136. self._changes = self._set_columns(**kwargs)
  137. if 'modified' in self.__table__.columns:
  138. self.modified = datetime.utcnow()
  139. if 'updated' in self.__table__.columns:
  140. self.updated = datetime.utcnow()
  141. if 'modified_at' in self.__table__.columns:
  142. self.modified_at = datetime.utcnow()
  143. if 'updated_at' in self.__table__.columns:
  144. self.updated_at = datetime.utcnow()
  145. return self._changes
  146. def __repr__(self):
  147. if 'id' in self.__table__.columns.keys():
  148. return '%s(%s)' % (self.__class__.__name__, self.id)
  149. data = {}
  150. for key in self.__table__.columns.keys():
  151. val = getattr(self, key)
  152. if type(val) is datetime:
  153. val = val.strftime('%Y-%m-%dT%H:%M:%SZ')
  154. data[key] = val
  155. return json.dumps(data, use_decimal=True)
  156. @property
  157. def changes(self):
  158. return self._changes
  159. def reset_changes(self):
  160. self._changes = {}
  161. def to_dict(self, show=None, hide=None, path=None, show_all=None):
  162. """ Return a dictionary representation of this model.
  163. """
  164. if not show:
  165. show = []
  166. if not hide:
  167. hide = []
  168. hidden = []
  169. if hasattr(self, 'hidden_fields'):
  170. hidden = self.hidden_fields
  171. default = []
  172. if hasattr(self, 'default_fields'):
  173. default = self.default_fields
  174. ret_data = {}
  175. if not path:
  176. path = self.__tablename__.lower()
  177. def prepend_path(item):
  178. item = item.lower()
  179. if item.split('.', 1)[0] == path:
  180. return item
  181. if len(item) == 0:
  182. return item
  183. if item[0] != '.':
  184. item = '.%s' % item
  185. item = '%s%s' % (path, item)
  186. return item
  187. show[:] = [prepend_path(x) for x in show]
  188. hide[:] = [prepend_path(x) for x in hide]
  189. columns = self.__table__.columns.keys()
  190. relationships = self.__mapper__.relationships.keys()
  191. properties = dir(self)
  192. for key in columns:
  193. check = '%s.%s' % (path, key)
  194. if check in hide or key in hidden:
  195. continue
  196. if show_all or key is 'id' or check in show or key in default:
  197. ret_data[key] = getattr(self, key)
  198. for key in relationships:
  199. check = '%s.%s' % (path, key)
  200. if check in hide or key in hidden:
  201. continue
  202. if show_all or check in show or key in default:
  203. hide.append(check)
  204. is_list = self.__mapper__.relationships[key].uselist
  205. if is_list:
  206. ret_data[key] = []
  207. for item in getattr(self, key):
  208. ret_data[key].append(item.to_dict(
  209. show=show,
  210. hide=hide,
  211. path=('%s.%s' % (path, key.lower())),
  212. show_all=show_all,
  213. ))
  214. else:
  215. if self.__mapper__.relationships[key].query_class is not None:
  216. ret_data[key] = getattr(self, key).to_dict(
  217. show=show,
  218. hide=hide,
  219. path=('%s.%s' % (path, key.lower())),
  220. show_all=show_all,
  221. )
  222. else:
  223. ret_data[key] = getattr(self, key)
  224. for key in list(set(properties) - set(columns) - set(relationships)):
  225. if key.startswith('_'):
  226. continue
  227. check = '%s.%s' % (path, key)
  228. if check in hide or key in hidden:
  229. continue
  230. if show_all or check in show or key in default:
  231. val = getattr(self, key)
  232. try:
  233. ret_data[key] = json.loads(json.dumps(val))
  234. except:
  235. pass
  236. return ret_data
  237. @classmethod
  238. def insert_or_update(cls, cond, data):
  239. obj = cls.query.filter_by(**cond).first()
  240. if obj:
  241. obj.set_columns(**data)
  242. else:
  243. data.update(cond)
  244. obj = cls(**data)
  245. db.session.add(obj)
  246. db.session.commit()
  247. def save(self):
  248. if self not in db.session:
  249. db.session.merge(self)
  250. db.session.commit()
  251. def delete(self):
  252. if self not in db.session:
  253. db.session.merge(self)
  254. db.session.delete(self)
  255. db.session.commit()
  256. @classmethod
  257. def query(cls):
  258. return db.session.query(cls)
  259. @classmethod
  260. def get_by_id(cls, id_):
  261. return cls.query().filter_by(id=id_).first()