add more REST routes, split accounts.py (user loader) from main
This commit is contained in:
parent
cc8858b7ac
commit
87d2eb6d0b
5 changed files with 205 additions and 82 deletions
|
|
@ -26,7 +26,7 @@ from suou import twocolon_list, WantsContentType
|
||||||
|
|
||||||
from .colors import color_themes, theme_classes
|
from .colors import color_themes, theme_classes
|
||||||
|
|
||||||
__version__ = '0.5.0-dev34'
|
__version__ = '0.5.0-dev35'
|
||||||
|
|
||||||
APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
|
APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
@ -57,57 +57,10 @@ app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
|
||||||
app.config['QUART_AUTH_DURATION'] = 365 * 24 * 60 * 60
|
app.config['QUART_AUTH_DURATION'] = 365 * 24 * 60 * 60
|
||||||
app.config['SERVER_NAME'] = app_config.server_name
|
app.config['SERVER_NAME'] = app_config.server_name
|
||||||
|
|
||||||
class UserLoader(AuthUser):
|
|
||||||
"""
|
|
||||||
Loads user from the session.
|
|
||||||
|
|
||||||
*WARNING* requires to be awaited before request before usage!
|
|
||||||
|
|
||||||
Actual User object is at .user; other attributes are proxied.
|
|
||||||
"""
|
|
||||||
def __init__(self, auth_id: str | None, action: QA_Action= QA_Action.PASS):
|
|
||||||
self._auth_id = auth_id
|
|
||||||
self._auth_obj = None
|
|
||||||
self._auth_sess = None
|
|
||||||
self.action = action
|
|
||||||
|
|
||||||
@property
|
|
||||||
def auth_id(self) -> str | None:
|
|
||||||
return self._auth_id
|
|
||||||
|
|
||||||
@property
|
|
||||||
async def is_authenticated(self) -> bool:
|
|
||||||
await self._load()
|
|
||||||
return self._auth_id is not None
|
|
||||||
|
|
||||||
async def _load(self):
|
|
||||||
if self._auth_obj is None and self._auth_id is not None:
|
|
||||||
session = self._auth_sess = await db.begin()
|
|
||||||
self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar()
|
|
||||||
if self._auth_obj is None:
|
|
||||||
raise RuntimeError('failed to fetch user')
|
|
||||||
|
|
||||||
def __getattr__(self, key):
|
|
||||||
if self._auth_obj is None:
|
|
||||||
raise RuntimeError('user is not loaded')
|
|
||||||
return getattr(self._auth_obj, key)
|
|
||||||
|
|
||||||
def __bool__(self):
|
|
||||||
return self._auth_obj is not None
|
|
||||||
|
|
||||||
async def _unload(self):
|
|
||||||
# user is not expected to mutate
|
|
||||||
if self._auth_sess:
|
|
||||||
await self._auth_sess.rollback()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def user(self):
|
|
||||||
return self._auth_obj
|
|
||||||
|
|
||||||
id: int
|
|
||||||
|
|
||||||
## DO NOT ADD LOCAL IMPORTS BEFORE THIS LINE
|
## DO NOT ADD LOCAL IMPORTS BEFORE THIS LINE
|
||||||
|
|
||||||
|
from .accounts import UserLoader
|
||||||
from .models import Guild, db, User, Post
|
from .models import Guild, db, User, Post
|
||||||
|
|
||||||
# SASS
|
# SASS
|
||||||
|
|
@ -190,7 +143,7 @@ async def _load_user():
|
||||||
g.no_user = True
|
g.no_user = True
|
||||||
|
|
||||||
@app.after_request
|
@app.after_request
|
||||||
async def _unload_request(resp):
|
async def _unload_user(resp):
|
||||||
try:
|
try:
|
||||||
await current_user._unload()
|
await current_user._unload()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|
|
||||||
81
freak/accounts.py
Normal file
81
freak/accounts.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
from .models import User, db
|
||||||
|
from quart_auth import AuthUser, Action as _Action
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class LoginStatus(enum.Enum):
|
||||||
|
SUCCESS = 0
|
||||||
|
ERROR = 1
|
||||||
|
SUSPENDED = 2
|
||||||
|
PASS_EXPIRED = 3
|
||||||
|
|
||||||
|
def check_login(user: User | None, password: str) -> LoginStatus:
|
||||||
|
try:
|
||||||
|
if user is None:
|
||||||
|
return LoginStatus.ERROR
|
||||||
|
if ('$' not in user.passhash) and user.email:
|
||||||
|
return LoginStatus.PASS_EXPIRED
|
||||||
|
if not user.is_active:
|
||||||
|
return LoginStatus.SUSPENDED
|
||||||
|
if user.check_password(password):
|
||||||
|
return LoginStatus.SUCCESS
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'{e}')
|
||||||
|
return LoginStatus.ERROR
|
||||||
|
|
||||||
|
|
||||||
|
class UserLoader(AuthUser):
|
||||||
|
"""
|
||||||
|
Loads user from the session.
|
||||||
|
|
||||||
|
*WARNING* requires to be awaited before request before usage!
|
||||||
|
|
||||||
|
Actual User object is at .user; other attributes are proxied.
|
||||||
|
"""
|
||||||
|
def __init__(self, auth_id: str | None, action: _Action= _Action.PASS):
|
||||||
|
self._auth_id = auth_id
|
||||||
|
self._auth_obj = None
|
||||||
|
self._auth_sess = None
|
||||||
|
self.action = action
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_id(self) -> str | None:
|
||||||
|
return self._auth_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def is_authenticated(self) -> bool:
|
||||||
|
await self._load()
|
||||||
|
return self._auth_id is not None
|
||||||
|
|
||||||
|
async def _load(self):
|
||||||
|
if self._auth_obj is None and self._auth_id is not None:
|
||||||
|
async with db as session:
|
||||||
|
self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar()
|
||||||
|
if self._auth_obj is None:
|
||||||
|
raise RuntimeError('failed to fetch user')
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
if self._auth_obj is None:
|
||||||
|
raise RuntimeError('user is not loaded')
|
||||||
|
return getattr(self._auth_obj, key)
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return self._auth_obj is not None
|
||||||
|
|
||||||
|
async def _unload(self):
|
||||||
|
# user is not expected to mutate
|
||||||
|
if self._auth_sess:
|
||||||
|
await self._auth_sess.rollback()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self):
|
||||||
|
return self._auth_obj
|
||||||
|
|
||||||
|
id: int
|
||||||
|
|
@ -9,8 +9,8 @@ from operator import or_
|
||||||
import re
|
import re
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
from quart_auth import AuthUser, current_user
|
from quart_auth import current_user
|
||||||
from sqlalchemy import Column, ExceptionContext, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \
|
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \
|
||||||
CheckConstraint, Date, DateTime, Boolean, func, BigInteger, \
|
CheckConstraint, Date, DateTime, Boolean, func, BigInteger, \
|
||||||
SmallInteger, select, update, Table
|
SmallInteger, select, update, Table
|
||||||
from sqlalchemy.orm import Relationship, relationship
|
from sqlalchemy.orm import Relationship, relationship
|
||||||
|
|
@ -19,13 +19,11 @@ from suou import SiqType, Snowflake, Wanted, deprecated, makelist, not_implement
|
||||||
from suou.sqlalchemy import create_session, declarative_base, id_column, parent_children, snowflake_column
|
from suou.sqlalchemy import create_session, declarative_base, id_column, parent_children, snowflake_column
|
||||||
from werkzeug.security import check_password_hash
|
from werkzeug.security import check_password_hash
|
||||||
|
|
||||||
from . import UserLoader, app_config
|
from . import app_config
|
||||||
from .utils import get_remote_addr
|
from .utils import get_remote_addr
|
||||||
|
|
||||||
from suou import timed_cache, age_and_days
|
from suou import timed_cache, age_and_days
|
||||||
|
|
||||||
current_user: UserLoader
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -120,6 +118,9 @@ db = SQLAlchemy(model_class=Base)
|
||||||
CSI = create_session_interactively = partial(create_session, app_config.database_url)
|
CSI = create_session_interactively = partial(create_session, app_config.database_url)
|
||||||
|
|
||||||
|
|
||||||
|
## .accounts requires db
|
||||||
|
#current_user: UserLoader
|
||||||
|
|
||||||
|
|
||||||
## Many-to-many relationship keys for some reasons have to go
|
## Many-to-many relationship keys for some reasons have to go
|
||||||
## BEFORE other table definitions.
|
## BEFORE other table definitions.
|
||||||
|
|
@ -221,18 +222,22 @@ class User(Base):
|
||||||
def age(self):
|
def age(self):
|
||||||
return age_and_days(self.gdpr_birthday)[0]
|
return age_and_days(self.gdpr_birthday)[0]
|
||||||
|
|
||||||
def simple_info(self):
|
def simple_info(self, *, typed = False):
|
||||||
"""
|
"""
|
||||||
Return essential informations for representing a user in the REST
|
Return essential informations for representing a user in the REST
|
||||||
"""
|
"""
|
||||||
## XXX change func name?
|
## XXX change func name?
|
||||||
return dict(
|
gg = dict(
|
||||||
id = Snowflake(self.id).to_b32l(),
|
id = Snowflake(self.id).to_b32l(),
|
||||||
username = self.username,
|
username = self.username,
|
||||||
display_name = self.display_name,
|
display_name = self.display_name,
|
||||||
age = self.age(),
|
age = self.age(),
|
||||||
badges = self.badges()
|
badges = self.badges(),
|
||||||
|
|
||||||
)
|
)
|
||||||
|
if typed:
|
||||||
|
gg['type'] = 'user'
|
||||||
|
return gg
|
||||||
|
|
||||||
@deprecated('updates may be not atomic. DO NOT USE until further notice')
|
@deprecated('updates may be not atomic. DO NOT USE until further notice')
|
||||||
async def reward(self, points=1):
|
async def reward(self, points=1):
|
||||||
|
|
@ -483,6 +488,21 @@ class Guild(Base):
|
||||||
session.execute(update(Member).where(Member.user_id == u.id, Member.guild_id == self.id).values(**values))
|
session.execute(update(Member).where(Member.user_id == u.id, Member.guild_id == self.id).values(**values))
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
def simple_info(self, *, typed=False):
|
||||||
|
"""
|
||||||
|
Return essential informations for representing a guild in the REST
|
||||||
|
"""
|
||||||
|
## XXX change func name?
|
||||||
|
gg = dict(
|
||||||
|
id = Snowflake(self.id).to_b32l(),
|
||||||
|
name = self.name,
|
||||||
|
display_name = self.display_name,
|
||||||
|
badges = []
|
||||||
|
)
|
||||||
|
if typed:
|
||||||
|
gg['type'] = 'guild'
|
||||||
|
return gg
|
||||||
|
|
||||||
|
|
||||||
Topic = deprecated('renamed to Guild')(Guild)
|
Topic = deprecated('renamed to Guild')(Guild)
|
||||||
|
|
||||||
|
|
@ -621,6 +641,8 @@ class Post(Base):
|
||||||
return or_(Post.author_id == user_id, Post.privacy == 0)
|
return or_(Post.author_id == user_id, Post.privacy == 0)
|
||||||
#return or_(Post.author_id == user_id, and_(Post.privacy.in_((0, 1)), ~Post.author.has_blocked_q(user_id)))
|
#return or_(Post.author_id == user_id, and_(Post.privacy.in_((0, 1)), ~Post.author.has_blocked_q(user_id)))
|
||||||
|
|
||||||
|
def is_text_post(self):
|
||||||
|
return self.post_type == POST_TYPE_DEFAULT
|
||||||
|
|
||||||
class Comment(Base):
|
class Comment(Base):
|
||||||
__tablename__ = 'freak_comment'
|
__tablename__ = 'freak_comment'
|
||||||
|
|
@ -716,3 +738,4 @@ class UserStrike(Base):
|
||||||
# PostUpvote table is at the top !!
|
# PostUpvote table is at the top !!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,20 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from flask import abort
|
from flask import abort
|
||||||
|
from pydantic import BaseModel
|
||||||
from quart import Blueprint, redirect, url_for
|
from quart import Blueprint, redirect, url_for
|
||||||
from quart_auth import current_user, login_required
|
from quart_auth import AuthUser, current_user, login_required, login_user, logout_user
|
||||||
from quart_schema import QuartSchema, validate_request, validate_response
|
from quart_schema import QuartSchema, validate_request, validate_response
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from suou import Snowflake, deprecated, not_implemented, want_isodate
|
from suou import Snowflake, deprecated, makelist, not_implemented, want_isodate
|
||||||
|
|
||||||
|
from werkzeug.security import check_password_hash
|
||||||
from suou.quart import add_rest
|
from suou.quart import add_rest
|
||||||
|
|
||||||
from ..models import Post, User, db
|
from freak.accounts import LoginStatus, check_login
|
||||||
|
from freak.algorithms import topic_timeline
|
||||||
|
|
||||||
|
from ..models import Guild, Post, User, db
|
||||||
from .. import UserLoader, app, app_config, __version__ as freak_version
|
from .. import UserLoader, app, app_config, __version__ as freak_version
|
||||||
|
|
||||||
bp = Blueprint('rest', __name__, url_prefix='/v1')
|
bp = Blueprint('rest', __name__, url_prefix='/v1')
|
||||||
|
|
@ -34,7 +39,9 @@ async def get_nurupo():
|
||||||
async def health():
|
async def health():
|
||||||
return dict(
|
return dict(
|
||||||
version=freak_version,
|
version=freak_version,
|
||||||
name = app_config.app_name
|
name = app_config.app_name,
|
||||||
|
post_count = await Post.count(),
|
||||||
|
user_count = await User.active_count()
|
||||||
)
|
)
|
||||||
|
|
||||||
## TODO coverage of REST is still partial, but it's planned
|
## TODO coverage of REST is still partial, but it's planned
|
||||||
|
|
@ -88,8 +95,85 @@ async def get_post(id: int):
|
||||||
id = f'{Snowflake(p.id):l}',
|
id = f'{Snowflake(p.id):l}',
|
||||||
title = p.title,
|
title = p.title,
|
||||||
author = p.author.simple_info(),
|
author = p.author.simple_info(),
|
||||||
to = p.topic_or_user().handle(),
|
to = p.topic_or_user().simple_info(typed=True),
|
||||||
created_at = p.created_at.isoformat('T')
|
created_at = p.created_at.isoformat('T')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if p.is_text_post():
|
||||||
|
pj['content'] = p.text_content
|
||||||
|
|
||||||
return dict(posts={f'{Snowflake(id):l}': pj})
|
return dict(posts={f'{Snowflake(id):l}': pj})
|
||||||
|
|
||||||
|
async def _guild_info(gu: Guild):
|
||||||
|
return dict(
|
||||||
|
id = f'{Snowflake(gu.id):l}',
|
||||||
|
name = gu.name,
|
||||||
|
display_name = gu.display_name,
|
||||||
|
description = gu.description,
|
||||||
|
created_at = want_isodate(gu.created_at),
|
||||||
|
badges = []
|
||||||
|
)
|
||||||
|
|
||||||
|
@bp.get('/guild/@<gname>')
|
||||||
|
async def guild_info_only(gname: str):
|
||||||
|
async with db as session:
|
||||||
|
gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar()
|
||||||
|
|
||||||
|
if gu is None:
|
||||||
|
return dict(error='Not found'), 404
|
||||||
|
gj = await _guild_info(gu)
|
||||||
|
|
||||||
|
return dict(guilds={f'{Snowflake(gu.id):l}': gj})
|
||||||
|
|
||||||
|
|
||||||
|
@bp.get('/guild/@<gname>/feed')
|
||||||
|
async def guild_feed(gname: str):
|
||||||
|
async with db as session:
|
||||||
|
gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar()
|
||||||
|
|
||||||
|
if gu is None:
|
||||||
|
return dict(error='Not found'), 404
|
||||||
|
gj = await _guild_info(gu)
|
||||||
|
|
||||||
|
# TODO add feed
|
||||||
|
feed = []
|
||||||
|
algo = topic_timeline(gname)
|
||||||
|
posts: list[Post] = makelist((await db.paginate(algo)))
|
||||||
|
for p in posts:
|
||||||
|
feed.append(p.feed_info())
|
||||||
|
|
||||||
|
return dict(guilds={f'{Snowflake(gu.id):l}': gj}, feed=feed)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginIn(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
remember: bool
|
||||||
|
|
||||||
|
@bp.post('/login')
|
||||||
|
@validate_request(LoginIn)
|
||||||
|
async def login(data: LoginIn):
|
||||||
|
async with db as session:
|
||||||
|
u = (await session.execute(select(User).where(User.username == data.username))).scalar()
|
||||||
|
match check_login(u, data.password):
|
||||||
|
case LoginStatus.SUCCESS:
|
||||||
|
remember_for = int(data.remember)
|
||||||
|
if remember_for > 0:
|
||||||
|
login_user(UserLoader(u.get_id()), remember=True)
|
||||||
|
else:
|
||||||
|
login_user(UserLoader(u.get_id()))
|
||||||
|
return {'id': f'{Snowflake(u.id):l}'}, 204
|
||||||
|
case LoginStatus.ERROR:
|
||||||
|
abort(404, 'Invalid username or password')
|
||||||
|
case LoginStatus.SUSPENDED:
|
||||||
|
abort(403, 'Your account is suspended')
|
||||||
|
case LoginStatus.PASS_EXPIRED:
|
||||||
|
abort(403, 'You need to reset your password following the procedure.')
|
||||||
|
|
||||||
|
|
||||||
|
@bp.post('/logout')
|
||||||
|
@login_required
|
||||||
|
async def logout():
|
||||||
|
logout_user()
|
||||||
|
return {}, 204
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,25 +24,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
bp = Blueprint('accounts', __name__)
|
bp = Blueprint('accounts', __name__)
|
||||||
|
|
||||||
class LoginStatus(enum.Enum):
|
from ..accounts import LoginStatus, check_login
|
||||||
SUCCESS = 0
|
|
||||||
ERROR = 1
|
|
||||||
SUSPENDED = 2
|
|
||||||
PASS_EXPIRED = 3
|
|
||||||
|
|
||||||
def check_login(user: User | None, password: str) -> LoginStatus:
|
|
||||||
try:
|
|
||||||
if user is None:
|
|
||||||
return LoginStatus.ERROR
|
|
||||||
if ('$' not in user.passhash) and user.email:
|
|
||||||
return LoginStatus.PASS_EXPIRED
|
|
||||||
if not user.is_active:
|
|
||||||
return LoginStatus.SUSPENDED
|
|
||||||
if user.check_password(password):
|
|
||||||
return LoginStatus.SUCCESS
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'{e}')
|
|
||||||
return LoginStatus.ERROR
|
|
||||||
|
|
||||||
|
|
||||||
@bp.get('/login')
|
@bp.get('/login')
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue