oauth work
This commit is contained in:
parent
e9709b6c8f
commit
a0124baf1d
|
@ -193,4 +193,4 @@ if ENV != "DEV":
|
||||||
ASSETS_DEBUG = False
|
ASSETS_DEBUG = False
|
||||||
SQLALCHEMY_ECHO = False
|
SQLALCHEMY_ECHO = False
|
||||||
|
|
||||||
MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap']
|
MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap', 'auth.oauth']
|
||||||
|
|
|
@ -67,18 +67,6 @@ class BaseUser(UserMixin):
|
||||||
def load_user(*args, **kwargs):
|
def load_user(*args, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_by_username(username):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_by_email(email):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def signer(salt):
|
def signer(salt):
|
||||||
return URLSafeSerializer(current_app.config['SECRET_KEY'] + salt)
|
return URLSafeSerializer(current_app.config['SECRET_KEY'] + salt)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from flask import render_template
|
from flask_login import login_user
|
||||||
from flask_oauthlib.client import OAuth
|
from flask_oauthlib.client import OAuth
|
||||||
|
|
||||||
from realms import config
|
from realms import config
|
||||||
from ..models import BaseUser
|
from ..models import BaseUser
|
||||||
|
|
||||||
|
@ -10,23 +11,66 @@ users = {}
|
||||||
providers = {
|
providers = {
|
||||||
'twitter': {
|
'twitter': {
|
||||||
'oauth': dict(
|
'oauth': dict(
|
||||||
base_url='https://api.twitter.com/1/',
|
base_url='https://api.twitter.com/1.1/',
|
||||||
request_token_url='https://api.twitter.com/oauth/request_token',
|
request_token_url='https://api.twitter.com/oauth/request_token',
|
||||||
access_token_url='https://api.twitter.com/oauth/access_token',
|
access_token_url='https://api.twitter.com/oauth/access_token',
|
||||||
authorize_url='https://api.twitter.com/oauth/authenticate')
|
authorize_url='https://api.twitter.com/oauth/authenticate',
|
||||||
|
access_token_method='GET'),
|
||||||
|
'button': '<a href="/login/oauth/twitter" class="btn btn-default"><i class="fa fa-twitter"></i> Twitter</a>'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class User(BaseUser):
|
class User(BaseUser):
|
||||||
|
type = 'oauth'
|
||||||
|
provider = None
|
||||||
|
|
||||||
|
def __init__(self, provider, username, token):
|
||||||
|
self.provider = provider
|
||||||
|
self.username = username
|
||||||
|
self.id = username
|
||||||
|
self.token = token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_token_id(self):
|
||||||
|
return self.token
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_user(*args, **kwargs):
|
||||||
|
return User.get_by_id(args[0])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_by_id(user_id):
|
||||||
|
return users.get(user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def auth(username, provider, token):
|
||||||
|
user = User(provider, username, User.hash_password(token))
|
||||||
|
users[user.id] = user
|
||||||
|
if user:
|
||||||
|
login_user(user, remember=True)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_app(cls, provider):
|
def get_app(cls, provider):
|
||||||
return oauth.remote_app(provider,
|
if oauth.remote_apps.get(provider):
|
||||||
|
return oauth.remote_apps.get(provider)
|
||||||
|
return oauth.remote_app(
|
||||||
|
provider,
|
||||||
consumer_key=config.OAUTH.get(provider, {}).get('key'),
|
consumer_key=config.OAUTH.get(provider, {}).get('key'),
|
||||||
consumer_secret=config.OAUTH.get(provider, {}).get('secret'),
|
consumer_secret=config.OAUTH.get(provider, {}).get(
|
||||||
|
'secret'),
|
||||||
**providers[provider]['oauth'])
|
**providers[provider]['oauth'])
|
||||||
|
|
||||||
|
def get_id(self):
|
||||||
|
return unicode("%s/%s/%s" % (self.type, self.provider, self.id))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def login_form():
|
def login_form():
|
||||||
pass
|
buttons = ''
|
||||||
|
for k, v in providers.items():
|
||||||
|
buttons += v.get('button')
|
||||||
|
|
||||||
|
return buttons
|
||||||
|
|
|
@ -10,17 +10,23 @@ def oauth_failed(next_url):
|
||||||
|
|
||||||
|
|
||||||
@blueprint.route("/login/oauth/<provider>")
|
@blueprint.route("/login/oauth/<provider>")
|
||||||
def oauth_login(provider):
|
def login(provider):
|
||||||
return User.get_app(provider).authorize(callback=url_for('oauth_callback', provider=provider,
|
return User.get_app(provider).authorize(callback=url_for('auth.oauth.callback', provider=provider))
|
||||||
next=request.args.get('next') or request.referrer or None))
|
|
||||||
|
|
||||||
|
|
||||||
@blueprint.route('/login/oauth/<provider>/callback')
|
@blueprint.route('/login/oauth/<provider>/callback')
|
||||||
def oauth_callback(provider):
|
def callback(provider):
|
||||||
next_url = request.args.get('next') or url_for('index')
|
next_url = request.args.get('next') or url_for('index')
|
||||||
|
try:
|
||||||
resp = User.get_app(provider).authorized_response()
|
resp = User.get_app(provider).authorized_response()
|
||||||
if resp is None:
|
if resp is None:
|
||||||
return oauth_failed(next_url)
|
flash('You denied the request to sign in.', 'error')
|
||||||
|
flash('Reason: ' + request.args['error_reason'] +
|
||||||
|
' ' + request.args['error_description'], 'error')
|
||||||
|
return redirect(next_url)
|
||||||
|
except Exception as e:
|
||||||
|
flash('Access denied: %s' % e.message)
|
||||||
|
return redirect(next_url)
|
||||||
|
|
||||||
session[provider + '_token'] = (
|
session[provider + '_token'] = (
|
||||||
resp['oauth_token'],
|
resp['oauth_token'],
|
||||||
|
|
Loading…
Reference in a new issue