oauth work
This commit is contained in:
		
							parent
							
								
									e9709b6c8f
								
							
						
					
					
						commit
						a0124baf1d
					
				
					 4 changed files with 66 additions and 28 deletions
				
			
		|  | @ -193,4 +193,4 @@ if ENV != "DEV": | ||||||
|     ASSETS_DEBUG = False |     ASSETS_DEBUG = False | ||||||
|     SQLALCHEMY_ECHO = False |     SQLALCHEMY_ECHO = False | ||||||
| 
 | 
 | ||||||
| MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap'] | MODULES = ['wiki', 'search', 'auth', 'auth.local', 'auth.oauth', 'auth.ldap', 'auth.oauth'] | ||||||
|  |  | ||||||
|  | @ -67,18 +67,6 @@ class BaseUser(UserMixin): | ||||||
|     def load_user(*args, **kwargs): |     def load_user(*args, **kwargs): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|     @staticmethod |  | ||||||
|     def create(*args, **kwargs): |  | ||||||
|         pass |  | ||||||
| 
 |  | ||||||
|     @staticmethod |  | ||||||
|     def get_by_username(username): |  | ||||||
|         pass |  | ||||||
| 
 |  | ||||||
|     @staticmethod |  | ||||||
|     def get_by_email(email): |  | ||||||
|         pass |  | ||||||
| 
 |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def signer(salt): |     def signer(salt): | ||||||
|         return URLSafeSerializer(current_app.config['SECRET_KEY'] + salt) |         return URLSafeSerializer(current_app.config['SECRET_KEY'] + salt) | ||||||
|  |  | ||||||
|  | @ -1,5 +1,6 @@ | ||||||
| from flask import render_template | from flask_login import login_user | ||||||
| from flask_oauthlib.client import OAuth | from flask_oauthlib.client import OAuth | ||||||
|  | 
 | ||||||
| from realms import config | from realms import config | ||||||
| from ..models import BaseUser | from ..models import BaseUser | ||||||
| 
 | 
 | ||||||
|  | @ -10,23 +11,66 @@ users = {} | ||||||
| providers = { | providers = { | ||||||
|     'twitter': { |     'twitter': { | ||||||
|         'oauth': dict( |         'oauth': dict( | ||||||
|             base_url='https://api.twitter.com/1/', |             base_url='https://api.twitter.com/1.1/', | ||||||
|             request_token_url='https://api.twitter.com/oauth/request_token', |             request_token_url='https://api.twitter.com/oauth/request_token', | ||||||
|             access_token_url='https://api.twitter.com/oauth/access_token', |             access_token_url='https://api.twitter.com/oauth/access_token', | ||||||
|             authorize_url='https://api.twitter.com/oauth/authenticate') |             authorize_url='https://api.twitter.com/oauth/authenticate', | ||||||
|  |             access_token_method='GET'), | ||||||
|  |         'button': '<a href="/login/oauth/twitter" class="btn btn-default"><i class="fa fa-twitter"></i> Twitter</a>' | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class User(BaseUser): | class User(BaseUser): | ||||||
|  |     type = 'oauth' | ||||||
|  |     provider = None | ||||||
|  | 
 | ||||||
|  |     def __init__(self, provider, username, token): | ||||||
|  |         self.provider = provider | ||||||
|  |         self.username = username | ||||||
|  |         self.id = username | ||||||
|  |         self.token = token | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def auth_token_id(self): | ||||||
|  |         return self.token | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def load_user(*args, **kwargs): | ||||||
|  |         return User.get_by_id(args[0]) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def get_by_id(user_id): | ||||||
|  |         return users.get(user_id) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def auth(username, provider, token): | ||||||
|  |         user = User(provider, username, User.hash_password(token)) | ||||||
|  |         users[user.id] = user | ||||||
|  |         if user: | ||||||
|  |             login_user(user, remember=True) | ||||||
|  |             return True | ||||||
|  |         else: | ||||||
|  |             return False | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_app(cls, provider): |     def get_app(cls, provider): | ||||||
|         return oauth.remote_app(provider, |         if oauth.remote_apps.get(provider): | ||||||
|  |             return oauth.remote_apps.get(provider) | ||||||
|  |         return oauth.remote_app( | ||||||
|  |                 provider, | ||||||
|                 consumer_key=config.OAUTH.get(provider, {}).get('key'), |                 consumer_key=config.OAUTH.get(provider, {}).get('key'), | ||||||
|                                 consumer_secret=config.OAUTH.get(provider, {}).get('secret'), |                 consumer_secret=config.OAUTH.get(provider, {}).get( | ||||||
|  |                     'secret'), | ||||||
|                 **providers[provider]['oauth']) |                 **providers[provider]['oauth']) | ||||||
| 
 | 
 | ||||||
|  |     def get_id(self): | ||||||
|  |         return unicode("%s/%s/%s" % (self.type, self.provider, self.id)) | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def login_form(): |     def login_form(): | ||||||
|         pass |         buttons = '' | ||||||
|  |         for k, v in providers.items(): | ||||||
|  |             buttons += v.get('button') | ||||||
|  | 
 | ||||||
|  |         return buttons | ||||||
|  |  | ||||||
|  | @ -10,17 +10,23 @@ def oauth_failed(next_url): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @blueprint.route("/login/oauth/<provider>") | @blueprint.route("/login/oauth/<provider>") | ||||||
| def oauth_login(provider): | def login(provider): | ||||||
|     return User.get_app(provider).authorize(callback=url_for('oauth_callback', provider=provider, |     return User.get_app(provider).authorize(callback=url_for('auth.oauth.callback', provider=provider)) | ||||||
|         next=request.args.get('next') or request.referrer or None)) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @blueprint.route('/login/oauth/<provider>/callback') | @blueprint.route('/login/oauth/<provider>/callback') | ||||||
| def oauth_callback(provider): | def callback(provider): | ||||||
|     next_url = request.args.get('next') or url_for('index') |     next_url = request.args.get('next') or url_for('index') | ||||||
|  |     try: | ||||||
|         resp = User.get_app(provider).authorized_response() |         resp = User.get_app(provider).authorized_response() | ||||||
|         if resp is None: |         if resp is None: | ||||||
|         return oauth_failed(next_url) |             flash('You denied the request to sign in.', 'error') | ||||||
|  |             flash('Reason: ' + request.args['error_reason'] + | ||||||
|  |                   ' ' + request.args['error_description'], 'error') | ||||||
|  |             return redirect(next_url) | ||||||
|  |     except Exception as e: | ||||||
|  |         flash('Access denied: %s' % e.message) | ||||||
|  |         return redirect(next_url) | ||||||
| 
 | 
 | ||||||
|     session[provider + '_token'] = ( |     session[provider + '_token'] = ( | ||||||
|         resp['oauth_token'], |         resp['oauth_token'], | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue