Git-based wiki inspired by Gollum
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

295 lines
11KB

import json
from datetime import datetime

from sqlalchemy import and_, not_

from realms import db
  5. class Model(db.Model):
  6. """Base SQLAlchemy Model for automatic serialization and
  7. deserialization of columns and nested relationships.
  8. Source: https://gist.github.com/alanhamlett/6604662
  9. Usage::
  10. >>> class User(Model):
  11. >>> id = db.Column(db.Integer(), primary_key=True)
  12. >>> email = db.Column(db.String(), index=True)
  13. >>> name = db.Column(db.String())
  14. >>> password = db.Column(db.String())
  15. >>> posts = db.relationship('Post', backref='user', lazy='dynamic')
  16. >>> ...
  17. >>> default_fields = ['email', 'name']
  18. >>> hidden_fields = ['password']
  19. >>> readonly_fields = ['email', 'password']
  20. >>>
  21. >>> class Post(Model):
  22. >>> id = db.Column(db.Integer(), primary_key=True)
  23. >>> user_id = db.Column(db.String(), db.ForeignKey('user.id'), nullable=False)
  24. >>> title = db.Column(db.String())
  25. >>> ...
  26. >>> default_fields = ['title']
  27. >>> readonly_fields = ['user_id']
  28. >>>
  29. >>> model = User(email='john@localhost')
  30. >>> db.session.add(model)
  31. >>> db.session.commit()
  32. >>>
  33. >>> # update name and create a new post
  34. >>> validated_input = {'name': 'John', 'posts': [{'title':'My First Post'}]}
  35. >>> model.set_columns(**validated_input)
  36. >>> db.session.commit()
  37. >>>
  38. >>> print(model.to_dict(show=['password', 'posts']))
  39. >>> {u'email': u'john@localhost', u'posts': [{u'id': 1, u'title': u'My First Post'}], u'name': u'John', u'id': 1}
  40. """
  41. __abstract__ = True
  42. # Stores changes made to this model's attributes. Can be retrieved
  43. # with model.changes
  44. _changes = {}
  45. def __init__(self, **kwargs):
  46. kwargs['_force'] = True
  47. self._set_columns(**kwargs)
  48. def filter_by(self, **kwargs):
  49. clauses = [key == value
  50. for key, value in kwargs.items()]
  51. return self.filter(and_(*clauses))
  52. def _set_columns(self, **kwargs):
  53. force = kwargs.get('_force')
  54. readonly = []
  55. if hasattr(self, 'readonly_fields'):
  56. readonly = self.readonly_fields
  57. if hasattr(self, 'hidden_fields'):
  58. readonly += self.hidden_fields
  59. readonly += [
  60. 'id',
  61. 'created',
  62. 'updated',
  63. 'modified',
  64. 'created_at',
  65. 'updated_at',
  66. 'modified_at',
  67. ]
  68. changes = {}
  69. columns = self.__table__.columns.keys()
  70. relationships = self.__mapper__.relationships.keys()
  71. for key in columns:
  72. allowed = True if force or key not in readonly else False
  73. exists = True if key in kwargs else False
  74. if allowed and exists:
  75. val = getattr(self, key)
  76. if val != kwargs[key]:
  77. changes[key] = {'old': val, 'new': kwargs[key]}
  78. setattr(self, key, kwargs[key])
  79. for rel in relationships:
  80. allowed = True if force or rel not in readonly else False
  81. exists = True if rel in kwargs else False
  82. if allowed and exists:
  83. is_list = self.__mapper__.relationships[rel].uselist
  84. if is_list:
  85. valid_ids = []
  86. query = getattr(self, rel)
  87. cls = self.__mapper__.relationships[rel].argument()
  88. for item in kwargs[rel]:
  89. if 'id' in item and query.filter_by(id=item['id']).limit(1).count() == 1:
  90. obj = cls.query.filter_by(id=item['id']).first()
  91. col_changes = obj.set_columns(**item)
  92. if col_changes:
  93. col_changes['id'] = str(item['id'])
  94. if rel in changes:
  95. changes[rel].append(col_changes)
  96. else:
  97. changes.update({rel: [col_changes]})
  98. valid_ids.append(str(item['id']))
  99. else:
  100. col = cls()
  101. col_changes = col.set_columns(**item)
  102. query.append(col)
  103. db.session.flush()
  104. if col_changes:
  105. col_changes['id'] = str(col.id)
  106. if rel in changes:
  107. changes[rel].append(col_changes)
  108. else:
  109. changes.update({rel: [col_changes]})
  110. valid_ids.append(str(col.id))
  111. # delete related rows that were not in kwargs[rel]
  112. for item in query.filter(not_(cls.id.in_(valid_ids))).all():
  113. col_changes = {
  114. 'id': str(item.id),
  115. 'deleted': True,
  116. }
  117. if rel in changes:
  118. changes[rel].append(col_changes)
  119. else:
  120. changes.update({rel: [col_changes]})
  121. db.session.delete(item)
  122. else:
  123. val = getattr(self, rel)
  124. if self.__mapper__.relationships[rel].query_class is not None:
  125. if val is not None:
  126. col_changes = val.set_columns(**kwargs[rel])
  127. if col_changes:
  128. changes.update({rel: col_changes})
  129. else:
  130. if val != kwargs[rel]:
  131. setattr(self, rel, kwargs[rel])
  132. changes[rel] = {'old': val, 'new': kwargs[rel]}
  133. return changes
  134. def set_columns(self, **kwargs):
  135. self._changes = self._set_columns(**kwargs)
  136. if 'modified' in self.__table__.columns:
  137. self.modified = datetime.utcnow()
  138. if 'updated' in self.__table__.columns:
  139. self.updated = datetime.utcnow()
  140. if 'modified_at' in self.__table__.columns:
  141. self.modified_at = datetime.utcnow()
  142. if 'updated_at' in self.__table__.columns:
  143. self.updated_at = datetime.utcnow()
  144. return self._changes
  145. def __repr__(self):
  146. if 'id' in self.__table__.columns.keys():
  147. return '%s(%s)' % (self.__class__.__name__, self.id)
  148. data = {}
  149. for key in self.__table__.columns.keys():
  150. val = getattr(self, key)
  151. if type(val) is datetime:
  152. val = val.strftime('%Y-%m-%dT%H:%M:%SZ')
  153. data[key] = val
  154. return json.dumps(data, use_decimal=True)
  155. @property
  156. def changes(self):
  157. return self._changes
  158. def reset_changes(self):
  159. self._changes = {}
  160. def to_dict(self, show=None, hide=None, path=None, show_all=None):
  161. """ Return a dictionary representation of this model.
  162. """
  163. if not show:
  164. show = []
  165. if not hide:
  166. hide = []
  167. hidden = []
  168. if hasattr(self, 'hidden_fields'):
  169. hidden = self.hidden_fields
  170. default = []
  171. if hasattr(self, 'default_fields'):
  172. default = self.default_fields
  173. ret_data = {}
  174. if not path:
  175. path = self.__tablename__.lower()
  176. def prepend_path(item):
  177. item = item.lower()
  178. if item.split('.', 1)[0] == path:
  179. return item
  180. if len(item) == 0:
  181. return item
  182. if item[0] != '.':
  183. item = '.%s' % item
  184. item = '%s%s' % (path, item)
  185. return item
  186. show[:] = [prepend_path(x) for x in show]
  187. hide[:] = [prepend_path(x) for x in hide]
  188. columns = self.__table__.columns.keys()
  189. relationships = self.__mapper__.relationships.keys()
  190. properties = dir(self)
  191. for key in columns:
  192. check = '%s.%s' % (path, key)
  193. if check in hide or key in hidden:
  194. continue
  195. if show_all or key is 'id' or check in show or key in default:
  196. ret_data[key] = getattr(self, key)
  197. for key in relationships:
  198. check = '%s.%s' % (path, key)
  199. if check in hide or key in hidden:
  200. continue
  201. if show_all or check in show or key in default:
  202. hide.append(check)
  203. is_list = self.__mapper__.relationships[key].uselist
  204. if is_list:
  205. ret_data[key] = []
  206. for item in getattr(self, key):
  207. ret_data[key].append(item.to_dict(
  208. show=show,
  209. hide=hide,
  210. path=('%s.%s' % (path, key.lower())),
  211. show_all=show_all,
  212. ))
  213. else:
  214. if self.__mapper__.relationships[key].query_class is not None:
  215. ret_data[key] = getattr(self, key).to_dict(
  216. show=show,
  217. hide=hide,
  218. path=('%s.%s' % (path, key.lower())),
  219. show_all=show_all,
  220. )
  221. else:
  222. ret_data[key] = getattr(self, key)
  223. for key in list(set(properties) - set(columns) - set(relationships)):
  224. if key.startswith('_'):
  225. continue
  226. check = '%s.%s' % (path, key)
  227. if check in hide or key in hidden:
  228. continue
  229. if show_all or check in show or key in default:
  230. val = getattr(self, key)
  231. try:
  232. ret_data[key] = json.loads(json.dumps(val))
  233. except:
  234. pass
  235. return ret_data
  236. @classmethod
  237. def insert_or_update(cls, cond, data):
  238. obj = cls.query.filter_by(**cond).first()
  239. if obj:
  240. obj.set_columns(**data)
  241. else:
  242. data.update(cond)
  243. obj = cls(**data)
  244. db.session.add(obj)
  245. db.session.commit()
  246. def save(self):
  247. if self not in db.session:
  248. db.session.merge(self)
  249. db.session.commit()
  250. def delete(self):
  251. if self not in db.session:
  252. db.session.merge(self)
  253. db.session.delete(self)
  254. db.session.commit()
  255. @classmethod
  256. def query(cls):
  257. return db.session.query(cls)
  258. @classmethod
  259. def get_by_id(cls, id_):
  260. return cls.query().filter_by(id=id_).first()