diff --git a/realms/config/__init__.py b/realms/config/__init__.py index 6895e1c..2f9fd58 100644 --- a/realms/config/__init__.py +++ b/realms/config/__init__.py @@ -193,4 +193,4 @@ if ENV != "DEV": ASSETS_DEBUG = False SQLALCHEMY_ECHO = False -MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap'] +MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap', 'auth.oauth'] diff --git a/realms/modules/auth/models.py b/realms/modules/auth/models.py index a8fa888..ba9ccd7 100644 --- a/realms/modules/auth/models.py +++ b/realms/modules/auth/models.py @@ -67,18 +67,6 @@ class BaseUser(UserMixin): def load_user(*args, **kwargs): raise NotImplementedError - @staticmethod - def create(*args, **kwargs): - pass - - @staticmethod - def get_by_username(username): - pass - - @staticmethod - def get_by_email(email): - pass - @staticmethod def signer(salt): return URLSafeSerializer(current_app.config['SECRET_KEY'] + salt) diff --git a/realms/modules/auth/oauth/models.py b/realms/modules/auth/oauth/models.py index 1777123..252756e 100644 --- a/realms/modules/auth/oauth/models.py +++ b/realms/modules/auth/oauth/models.py @@ -1,5 +1,6 @@ -from flask import render_template +from flask_login import login_user from flask_oauthlib.client import OAuth + from realms import config from ..models import BaseUser @@ -10,23 +11,66 @@ users = {} providers = { 'twitter': { 'oauth': dict( - base_url='https://api.twitter.com/1/', + base_url='https://api.twitter.com/1.1/', request_token_url='https://api.twitter.com/oauth/request_token', access_token_url='https://api.twitter.com/oauth/access_token', - authorize_url='https://api.twitter.com/oauth/authenticate') + authorize_url='https://api.twitter.com/oauth/authenticate', + access_token_method='GET'), + 'button': ' Twitter' } } class User(BaseUser): + type = 'oauth' + provider = None + + def __init__(self, provider, username, token): + self.provider = provider + self.username = username + self.id = username + self.token = token + + @property + def auth_token_id(self): + return self.token + + @staticmethod + def load_user(*args, **kwargs): + return User.get_by_id(args[0]) + + @staticmethod + def get_by_id(user_id): + return users.get(user_id) + + @staticmethod + def auth(username, provider, token): + user = User(provider, username, User.hash_password(token)) + users[user.id] = user + if user: + login_user(user, remember=True) + return True + else: + return False @classmethod def get_app(cls, provider): - return oauth.remote_app(provider, - consumer_key=config.OAUTH.get(provider, {}).get('key'), - consumer_secret=config.OAUTH.get(provider, {}).get('secret'), - **providers[provider]['oauth']) + if oauth.remote_apps.get(provider): + return oauth.remote_apps.get(provider) + return oauth.remote_app( + provider, + consumer_key=config.OAUTH.get(provider, {}).get('key'), + consumer_secret=config.OAUTH.get(provider, {}).get( + 'secret'), + **providers[provider]['oauth']) + + def get_id(self): + return unicode("%s/%s/%s" % (self.type, self.provider, self.id)) @staticmethod def login_form(): - pass + buttons = '' + for k, v in providers.items(): + buttons += v.get('button') + + return buttons diff --git a/realms/modules/auth/oauth/views.py b/realms/modules/auth/oauth/views.py index bb6990b..b51c60c 100644 --- a/realms/modules/auth/oauth/views.py +++ b/realms/modules/auth/oauth/views.py @@ -10,17 +10,23 @@ def oauth_failed(next_url): @blueprint.route("/login/oauth/") -def oauth_login(provider): - return User.get_app(provider).authorize(callback=url_for('oauth_callback', provider=provider, - next=request.args.get('next') or request.referrer or None)) +def login(provider): + return User.get_app(provider).authorize(callback=url_for('auth.oauth.callback', provider=provider)) @blueprint.route('/login/oauth//callback') -def oauth_callback(provider): +def callback(provider): next_url = request.args.get('next') or url_for('index') - resp = User.get_app(provider).authorized_response() - if resp is None: - return oauth_failed(next_url) + try: + resp = User.get_app(provider).authorized_response() + if resp is None: + flash('You denied the request to sign in.', 'error') + flash('Reason: ' + request.args['error_reason'] + + ' ' + request.args['error_description'], 'error') + return redirect(next_url) + except Exception as e: + flash('Access denied: %s' % e.message) + return redirect(next_url) session[provider + '_token'] = ( resp['oauth_token'],