add more REST routes, split accounts.py (user loader) from main
This commit is contained in:
parent
cc8858b7ac
commit
87d2eb6d0b
5 changed files with 205 additions and 82 deletions
|
|
@ -26,7 +26,7 @@ from suou import twocolon_list, WantsContentType
|
|||
|
||||
from .colors import color_themes, theme_classes
|
||||
|
||||
__version__ = '0.5.0-dev34'
|
||||
__version__ = '0.5.0-dev35'
|
||||
|
||||
APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
|
|
@ -57,57 +57,10 @@ app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
|
|||
app.config['QUART_AUTH_DURATION'] = 365 * 24 * 60 * 60
|
||||
app.config['SERVER_NAME'] = app_config.server_name
|
||||
|
||||
class UserLoader(AuthUser):
|
||||
"""
|
||||
Loads user from the session.
|
||||
|
||||
*WARNING* requires to be awaited before request before usage!
|
||||
|
||||
Actual User object is at .user; other attributes are proxied.
|
||||
"""
|
||||
def __init__(self, auth_id: str | None, action: QA_Action= QA_Action.PASS):
|
||||
self._auth_id = auth_id
|
||||
self._auth_obj = None
|
||||
self._auth_sess = None
|
||||
self.action = action
|
||||
|
||||
@property
|
||||
def auth_id(self) -> str | None:
|
||||
return self._auth_id
|
||||
|
||||
@property
|
||||
async def is_authenticated(self) -> bool:
|
||||
await self._load()
|
||||
return self._auth_id is not None
|
||||
|
||||
async def _load(self):
|
||||
if self._auth_obj is None and self._auth_id is not None:
|
||||
session = self._auth_sess = await db.begin()
|
||||
self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar()
|
||||
if self._auth_obj is None:
|
||||
raise RuntimeError('failed to fetch user')
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._auth_obj is None:
|
||||
raise RuntimeError('user is not loaded')
|
||||
return getattr(self._auth_obj, key)
|
||||
|
||||
def __bool__(self):
|
||||
return self._auth_obj is not None
|
||||
|
||||
async def _unload(self):
|
||||
# user is not expected to mutate
|
||||
if self._auth_sess:
|
||||
await self._auth_sess.rollback()
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return self._auth_obj
|
||||
|
||||
id: int
|
||||
|
||||
## DO NOT ADD LOCAL IMPORTS BEFORE THIS LINE
|
||||
|
||||
from .accounts import UserLoader
|
||||
from .models import Guild, db, User, Post
|
||||
|
||||
# SASS
|
||||
|
|
@ -190,7 +143,7 @@ async def _load_user():
|
|||
g.no_user = True
|
||||
|
||||
@app.after_request
|
||||
async def _unload_request(resp):
|
||||
async def _unload_user(resp):
|
||||
try:
|
||||
await current_user._unload()
|
||||
except RuntimeError as e:
|
||||
|
|
|
|||
81
freak/accounts.py
Normal file
81
freak/accounts.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
|
||||
|
||||
import logging
|
||||
import enum
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from .models import User, db
|
||||
from quart_auth import AuthUser, Action as _Action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoginStatus(enum.Enum):
|
||||
SUCCESS = 0
|
||||
ERROR = 1
|
||||
SUSPENDED = 2
|
||||
PASS_EXPIRED = 3
|
||||
|
||||
def check_login(user: User | None, password: str) -> LoginStatus:
|
||||
try:
|
||||
if user is None:
|
||||
return LoginStatus.ERROR
|
||||
if ('$' not in user.passhash) and user.email:
|
||||
return LoginStatus.PASS_EXPIRED
|
||||
if not user.is_active:
|
||||
return LoginStatus.SUSPENDED
|
||||
if user.check_password(password):
|
||||
return LoginStatus.SUCCESS
|
||||
except Exception as e:
|
||||
logger.error(f'{e}')
|
||||
return LoginStatus.ERROR
|
||||
|
||||
|
||||
class UserLoader(AuthUser):
|
||||
"""
|
||||
Loads user from the session.
|
||||
|
||||
*WARNING* requires to be awaited before request before usage!
|
||||
|
||||
Actual User object is at .user; other attributes are proxied.
|
||||
"""
|
||||
def __init__(self, auth_id: str | None, action: _Action= _Action.PASS):
|
||||
self._auth_id = auth_id
|
||||
self._auth_obj = None
|
||||
self._auth_sess = None
|
||||
self.action = action
|
||||
|
||||
@property
|
||||
def auth_id(self) -> str | None:
|
||||
return self._auth_id
|
||||
|
||||
@property
|
||||
async def is_authenticated(self) -> bool:
|
||||
await self._load()
|
||||
return self._auth_id is not None
|
||||
|
||||
async def _load(self):
|
||||
if self._auth_obj is None and self._auth_id is not None:
|
||||
async with db as session:
|
||||
self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar()
|
||||
if self._auth_obj is None:
|
||||
raise RuntimeError('failed to fetch user')
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._auth_obj is None:
|
||||
raise RuntimeError('user is not loaded')
|
||||
return getattr(self._auth_obj, key)
|
||||
|
||||
def __bool__(self):
|
||||
return self._auth_obj is not None
|
||||
|
||||
async def _unload(self):
|
||||
# user is not expected to mutate
|
||||
if self._auth_sess:
|
||||
await self._auth_sess.rollback()
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
return self._auth_obj
|
||||
|
||||
id: int
|
||||
|
|
@ -9,8 +9,8 @@ from operator import or_
|
|||
import re
|
||||
from threading import Lock
|
||||
from typing import Any, Callable
|
||||
from quart_auth import AuthUser, current_user
|
||||
from sqlalchemy import Column, ExceptionContext, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \
|
||||
from quart_auth import current_user
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \
|
||||
CheckConstraint, Date, DateTime, Boolean, func, BigInteger, \
|
||||
SmallInteger, select, update, Table
|
||||
from sqlalchemy.orm import Relationship, relationship
|
||||
|
|
@ -19,13 +19,11 @@ from suou import SiqType, Snowflake, Wanted, deprecated, makelist, not_implement
|
|||
from suou.sqlalchemy import create_session, declarative_base, id_column, parent_children, snowflake_column
|
||||
from werkzeug.security import check_password_hash
|
||||
|
||||
from . import UserLoader, app_config
|
||||
from . import app_config
|
||||
from .utils import get_remote_addr
|
||||
|
||||
from suou import timed_cache, age_and_days
|
||||
|
||||
current_user: UserLoader
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -120,6 +118,9 @@ db = SQLAlchemy(model_class=Base)
|
|||
CSI = create_session_interactively = partial(create_session, app_config.database_url)
|
||||
|
||||
|
||||
## .accounts requires db
|
||||
#current_user: UserLoader
|
||||
|
||||
|
||||
## Many-to-many relationship keys for some reasons have to go
|
||||
## BEFORE other table definitions.
|
||||
|
|
@ -221,18 +222,22 @@ class User(Base):
|
|||
def age(self):
|
||||
return age_and_days(self.gdpr_birthday)[0]
|
||||
|
||||
def simple_info(self):
|
||||
def simple_info(self, *, typed = False):
|
||||
"""
|
||||
Return essential informations for representing a user in the REST
|
||||
"""
|
||||
## XXX change func name?
|
||||
return dict(
|
||||
gg = dict(
|
||||
id = Snowflake(self.id).to_b32l(),
|
||||
username = self.username,
|
||||
display_name = self.display_name,
|
||||
age = self.age(),
|
||||
badges = self.badges()
|
||||
badges = self.badges(),
|
||||
|
||||
)
|
||||
if typed:
|
||||
gg['type'] = 'user'
|
||||
return gg
|
||||
|
||||
@deprecated('updates may be not atomic. DO NOT USE until further notice')
|
||||
async def reward(self, points=1):
|
||||
|
|
@ -483,6 +488,21 @@ class Guild(Base):
|
|||
session.execute(update(Member).where(Member.user_id == u.id, Member.guild_id == self.id).values(**values))
|
||||
return m
|
||||
|
||||
def simple_info(self, *, typed=False):
|
||||
"""
|
||||
Return essential informations for representing a guild in the REST
|
||||
"""
|
||||
## XXX change func name?
|
||||
gg = dict(
|
||||
id = Snowflake(self.id).to_b32l(),
|
||||
name = self.name,
|
||||
display_name = self.display_name,
|
||||
badges = []
|
||||
)
|
||||
if typed:
|
||||
gg['type'] = 'guild'
|
||||
return gg
|
||||
|
||||
|
||||
Topic = deprecated('renamed to Guild')(Guild)
|
||||
|
||||
|
|
@ -621,6 +641,8 @@ class Post(Base):
|
|||
return or_(Post.author_id == user_id, Post.privacy == 0)
|
||||
#return or_(Post.author_id == user_id, and_(Post.privacy.in_((0, 1)), ~Post.author.has_blocked_q(user_id)))
|
||||
|
||||
def is_text_post(self):
|
||||
return self.post_type == POST_TYPE_DEFAULT
|
||||
|
||||
class Comment(Base):
|
||||
__tablename__ = 'freak_comment'
|
||||
|
|
@ -716,3 +738,4 @@ class UserStrike(Base):
|
|||
# PostUpvote table is at the top !!
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask import abort
|
||||
from pydantic import BaseModel
|
||||
from quart import Blueprint, redirect, url_for
|
||||
from quart_auth import current_user, login_required
|
||||
from quart_auth import AuthUser, current_user, login_required, login_user, logout_user
|
||||
from quart_schema import QuartSchema, validate_request, validate_response
|
||||
from sqlalchemy import select
|
||||
from suou import Snowflake, deprecated, not_implemented, want_isodate
|
||||
from suou import Snowflake, deprecated, makelist, not_implemented, want_isodate
|
||||
|
||||
from werkzeug.security import check_password_hash
|
||||
from suou.quart import add_rest
|
||||
|
||||
from ..models import Post, User, db
|
||||
from freak.accounts import LoginStatus, check_login
|
||||
from freak.algorithms import topic_timeline
|
||||
|
||||
from ..models import Guild, Post, User, db
|
||||
from .. import UserLoader, app, app_config, __version__ as freak_version
|
||||
|
||||
bp = Blueprint('rest', __name__, url_prefix='/v1')
|
||||
|
|
@ -34,7 +39,9 @@ async def get_nurupo():
|
|||
async def health():
|
||||
return dict(
|
||||
version=freak_version,
|
||||
name = app_config.app_name
|
||||
name = app_config.app_name,
|
||||
post_count = await Post.count(),
|
||||
user_count = await User.active_count()
|
||||
)
|
||||
|
||||
## TODO coverage of REST is still partial, but it's planned
|
||||
|
|
@ -88,8 +95,85 @@ async def get_post(id: int):
|
|||
id = f'{Snowflake(p.id):l}',
|
||||
title = p.title,
|
||||
author = p.author.simple_info(),
|
||||
to = p.topic_or_user().handle(),
|
||||
to = p.topic_or_user().simple_info(typed=True),
|
||||
created_at = p.created_at.isoformat('T')
|
||||
)
|
||||
|
||||
if p.is_text_post():
|
||||
pj['content'] = p.text_content
|
||||
|
||||
return dict(posts={f'{Snowflake(id):l}': pj})
|
||||
|
||||
async def _guild_info(gu: Guild):
|
||||
return dict(
|
||||
id = f'{Snowflake(gu.id):l}',
|
||||
name = gu.name,
|
||||
display_name = gu.display_name,
|
||||
description = gu.description,
|
||||
created_at = want_isodate(gu.created_at),
|
||||
badges = []
|
||||
)
|
||||
|
||||
@bp.get('/guild/@<gname>')
|
||||
async def guild_info_only(gname: str):
|
||||
async with db as session:
|
||||
gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar()
|
||||
|
||||
if gu is None:
|
||||
return dict(error='Not found'), 404
|
||||
gj = await _guild_info(gu)
|
||||
|
||||
return dict(guilds={f'{Snowflake(gu.id):l}': gj})
|
||||
|
||||
|
||||
@bp.get('/guild/@<gname>/feed')
|
||||
async def guild_feed(gname: str):
|
||||
async with db as session:
|
||||
gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar()
|
||||
|
||||
if gu is None:
|
||||
return dict(error='Not found'), 404
|
||||
gj = await _guild_info(gu)
|
||||
|
||||
# TODO add feed
|
||||
feed = []
|
||||
algo = topic_timeline(gname)
|
||||
posts: list[Post] = makelist((await db.paginate(algo)))
|
||||
for p in posts:
|
||||
feed.append(p.feed_info())
|
||||
|
||||
return dict(guilds={f'{Snowflake(gu.id):l}': gj}, feed=feed)
|
||||
|
||||
|
||||
class LoginIn(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
remember: bool
|
||||
|
||||
@bp.post('/login')
|
||||
@validate_request(LoginIn)
|
||||
async def login(data: LoginIn):
|
||||
async with db as session:
|
||||
u = (await session.execute(select(User).where(User.username == data.username))).scalar()
|
||||
match check_login(u, data.password):
|
||||
case LoginStatus.SUCCESS:
|
||||
remember_for = int(data.remember)
|
||||
if remember_for > 0:
|
||||
login_user(UserLoader(u.get_id()), remember=True)
|
||||
else:
|
||||
login_user(UserLoader(u.get_id()))
|
||||
return {'id': f'{Snowflake(u.id):l}'}, 204
|
||||
case LoginStatus.ERROR:
|
||||
abort(404, 'Invalid username or password')
|
||||
case LoginStatus.SUSPENDED:
|
||||
abort(403, 'Your account is suspended')
|
||||
case LoginStatus.PASS_EXPIRED:
|
||||
abort(403, 'You need to reset your password following the procedure.')
|
||||
|
||||
|
||||
@bp.post('/logout')
|
||||
@login_required
|
||||
async def logout():
|
||||
logout_user()
|
||||
return {}, 204
|
||||
|
||||
|
|
|
|||
|
|
@ -24,25 +24,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
bp = Blueprint('accounts', __name__)
|
||||
|
||||
class LoginStatus(enum.Enum):
|
||||
SUCCESS = 0
|
||||
ERROR = 1
|
||||
SUSPENDED = 2
|
||||
PASS_EXPIRED = 3
|
||||
|
||||
def check_login(user: User | None, password: str) -> LoginStatus:
|
||||
try:
|
||||
if user is None:
|
||||
return LoginStatus.ERROR
|
||||
if ('$' not in user.passhash) and user.email:
|
||||
return LoginStatus.PASS_EXPIRED
|
||||
if not user.is_active:
|
||||
return LoginStatus.SUSPENDED
|
||||
if user.check_password(password):
|
||||
return LoginStatus.SUCCESS
|
||||
except Exception as e:
|
||||
logger.error(f'{e}')
|
||||
return LoginStatus.ERROR
|
||||
from ..accounts import LoginStatus, check_login
|
||||
|
||||
|
||||
@bp.get('/login')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue