diff --git a/realms/config/__init__.py b/realms/config/__init__.py
index 6895e1c..2f9fd58 100644
--- a/realms/config/__init__.py
+++ b/realms/config/__init__.py
@@ -193,4 +193,4 @@ if ENV != "DEV":
ASSETS_DEBUG = False
SQLALCHEMY_ECHO = False
-MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap']
+MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap', 'auth.oauth']
diff --git a/realms/modules/auth/models.py b/realms/modules/auth/models.py
index a8fa888..ba9ccd7 100644
--- a/realms/modules/auth/models.py
+++ b/realms/modules/auth/models.py
@@ -67,18 +67,6 @@ class BaseUser(UserMixin):
def load_user(*args, **kwargs):
raise NotImplementedError
- @staticmethod
- def create(*args, **kwargs):
- pass
-
- @staticmethod
- def get_by_username(username):
- pass
-
- @staticmethod
- def get_by_email(email):
- pass
-
@staticmethod
def signer(salt):
return URLSafeSerializer(current_app.config['SECRET_KEY'] + salt)
diff --git a/realms/modules/auth/oauth/models.py b/realms/modules/auth/oauth/models.py
index 1777123..252756e 100644
--- a/realms/modules/auth/oauth/models.py
+++ b/realms/modules/auth/oauth/models.py
@@ -1,5 +1,6 @@
-from flask import render_template
+from flask_login import login_user
from flask_oauthlib.client import OAuth
+
from realms import config
from ..models import BaseUser
@@ -10,23 +11,66 @@ users = {}
providers = {
'twitter': {
'oauth': dict(
- base_url='https://api.twitter.com/1/',
+ base_url='https://api.twitter.com/1.1/',
request_token_url='https://api.twitter.com/oauth/request_token',
access_token_url='https://api.twitter.com/oauth/access_token',
- authorize_url='https://api.twitter.com/oauth/authenticate')
+ authorize_url='https://api.twitter.com/oauth/authenticate',
+ access_token_method='GET'),
+ 'button': ' Twitter'
}
}
class User(BaseUser):
+ type = 'oauth'
+ provider = None
+
+ def __init__(self, provider, username, token):
+ self.provider = provider
+ self.username = username
+ self.id = username
+ self.token = token
+
+ @property
+ def auth_token_id(self):
+ return self.token
+
+ @staticmethod
+ def load_user(*args, **kwargs):
+ return User.get_by_id(args[0])
+
+ @staticmethod
+ def get_by_id(user_id):
+ return users.get(user_id)
+
+ @staticmethod
+ def auth(username, provider, token):
+ user = User(provider, username, User.hash_password(token))
+ users[user.id] = user
+ if user:
+ login_user(user, remember=True)
+ return True
+ else:
+ return False
@classmethod
def get_app(cls, provider):
- return oauth.remote_app(provider,
- consumer_key=config.OAUTH.get(provider, {}).get('key'),
- consumer_secret=config.OAUTH.get(provider, {}).get('secret'),
- **providers[provider]['oauth'])
+ if oauth.remote_apps.get(provider):
+ return oauth.remote_apps.get(provider)
+ return oauth.remote_app(
+ provider,
+ consumer_key=config.OAUTH.get(provider, {}).get('key'),
+ consumer_secret=config.OAUTH.get(provider, {}).get(
+ 'secret'),
+ **providers[provider]['oauth'])
+
+ def get_id(self):
+ return unicode("%s/%s/%s" % (self.type, self.provider, self.id))
@staticmethod
def login_form():
- pass
+ buttons = ''
+ for k, v in providers.items():
+ buttons += v.get('button')
+
+ return buttons
diff --git a/realms/modules/auth/oauth/views.py b/realms/modules/auth/oauth/views.py
index bb6990b..b51c60c 100644
--- a/realms/modules/auth/oauth/views.py
+++ b/realms/modules/auth/oauth/views.py
@@ -10,17 +10,23 @@ def oauth_failed(next_url):
@blueprint.route("/login/oauth/")
-def oauth_login(provider):
- return User.get_app(provider).authorize(callback=url_for('oauth_callback', provider=provider,
- next=request.args.get('next') or request.referrer or None))
+def login(provider):
+ return User.get_app(provider).authorize(callback=url_for('auth.oauth.callback', provider=provider))
@blueprint.route('/login/oauth//callback')
-def oauth_callback(provider):
+def callback(provider):
next_url = request.args.get('next') or url_for('index')
- resp = User.get_app(provider).authorized_response()
- if resp is None:
- return oauth_failed(next_url)
+ try:
+ resp = User.get_app(provider).authorized_response()
+ if resp is None:
+ flash('You denied the request to sign in.', 'error')
+ flash('Reason: ' + request.args['error_reason'] +
+ ' ' + request.args['error_description'], 'error')
+ return redirect(next_url)
+ except Exception as e:
+ flash('Access denied: %s' % e.message)
+ return redirect(next_url)
session[provider + '_token'] = (
resp['oauth_token'],