add more REST routes, split accounts.py (user loader) from main

This commit is contained in:
Yusur 2025-09-06 16:19:49 +02:00
parent cc8858b7ac
commit 87d2eb6d0b
5 changed files with 205 additions and 82 deletions

View file

@ -26,7 +26,7 @@ from suou import twocolon_list, WantsContentType
from .colors import color_themes, theme_classes
__version__ = '0.5.0-dev34'
__version__ = '0.5.0-dev35'
APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
@ -57,57 +57,10 @@ app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
app.config['QUART_AUTH_DURATION'] = 365 * 24 * 60 * 60
app.config['SERVER_NAME'] = app_config.server_name
class UserLoader(AuthUser):
"""
Loads user from the session.
*WARNING* requires to be awaited before request before usage!
Actual User object is at .user; other attributes are proxied.
"""
def __init__(self, auth_id: str | None, action: QA_Action= QA_Action.PASS):
self._auth_id = auth_id
self._auth_obj = None
self._auth_sess = None
self.action = action
@property
def auth_id(self) -> str | None:
return self._auth_id
@property
async def is_authenticated(self) -> bool:
await self._load()
return self._auth_id is not None
async def _load(self):
if self._auth_obj is None and self._auth_id is not None:
session = self._auth_sess = await db.begin()
self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar()
if self._auth_obj is None:
raise RuntimeError('failed to fetch user')
def __getattr__(self, key):
if self._auth_obj is None:
raise RuntimeError('user is not loaded')
return getattr(self._auth_obj, key)
def __bool__(self):
return self._auth_obj is not None
async def _unload(self):
# user is not expected to mutate
if self._auth_sess:
await self._auth_sess.rollback()
@property
def user(self):
return self._auth_obj
id: int
## DO NOT ADD LOCAL IMPORTS BEFORE THIS LINE
from .accounts import UserLoader
from .models import Guild, db, User, Post
# SASS
@ -190,7 +143,7 @@ async def _load_user():
g.no_user = True
@app.after_request
async def _unload_request(resp):
async def _unload_user(resp):
try:
await current_user._unload()
except RuntimeError as e:

81
freak/accounts.py Normal file
View file

@ -0,0 +1,81 @@
import logging
import enum
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from .models import User, db
from quart_auth import AuthUser, Action as _Action
logger = logging.getLogger(__name__)
class LoginStatus(enum.Enum):
SUCCESS = 0
ERROR = 1
SUSPENDED = 2
PASS_EXPIRED = 3
def check_login(user: User | None, password: str) -> LoginStatus:
try:
if user is None:
return LoginStatus.ERROR
if ('$' not in user.passhash) and user.email:
return LoginStatus.PASS_EXPIRED
if not user.is_active:
return LoginStatus.SUSPENDED
if user.check_password(password):
return LoginStatus.SUCCESS
except Exception as e:
logger.error(f'{e}')
return LoginStatus.ERROR
class UserLoader(AuthUser):
"""
Loads user from the session.
*WARNING* requires to be awaited before request before usage!
Actual User object is at .user; other attributes are proxied.
"""
def __init__(self, auth_id: str | None, action: _Action= _Action.PASS):
self._auth_id = auth_id
self._auth_obj = None
self._auth_sess = None
self.action = action
@property
def auth_id(self) -> str | None:
return self._auth_id
@property
async def is_authenticated(self) -> bool:
await self._load()
return self._auth_id is not None
async def _load(self):
if self._auth_obj is None and self._auth_id is not None:
async with db as session:
self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar()
if self._auth_obj is None:
raise RuntimeError('failed to fetch user')
def __getattr__(self, key):
if self._auth_obj is None:
raise RuntimeError('user is not loaded')
return getattr(self._auth_obj, key)
def __bool__(self):
return self._auth_obj is not None
async def _unload(self):
# user is not expected to mutate
if self._auth_sess:
await self._auth_sess.rollback()
@property
def user(self):
return self._auth_obj
id: int

View file

@ -9,8 +9,8 @@ from operator import or_
import re
from threading import Lock
from typing import Any, Callable
from quart_auth import AuthUser, current_user
from sqlalchemy import Column, ExceptionContext, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \
from quart_auth import current_user
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \
CheckConstraint, Date, DateTime, Boolean, func, BigInteger, \
SmallInteger, select, update, Table
from sqlalchemy.orm import Relationship, relationship
@ -19,13 +19,11 @@ from suou import SiqType, Snowflake, Wanted, deprecated, makelist, not_implement
from suou.sqlalchemy import create_session, declarative_base, id_column, parent_children, snowflake_column
from werkzeug.security import check_password_hash
from . import UserLoader, app_config
from . import app_config
from .utils import get_remote_addr
from suou import timed_cache, age_and_days
current_user: UserLoader
import logging
logger = logging.getLogger(__name__)
@ -120,6 +118,9 @@ db = SQLAlchemy(model_class=Base)
CSI = create_session_interactively = partial(create_session, app_config.database_url)
## .accounts requires db
#current_user: UserLoader
## Many-to-many relationship keys for some reasons have to go
## BEFORE other table definitions.
@ -221,18 +222,22 @@ class User(Base):
def age(self):
return age_and_days(self.gdpr_birthday)[0]
def simple_info(self):
def simple_info(self, *, typed = False):
"""
Return essential informations for representing a user in the REST
"""
## XXX change func name?
return dict(
gg = dict(
id = Snowflake(self.id).to_b32l(),
username = self.username,
display_name = self.display_name,
age = self.age(),
badges = self.badges()
badges = self.badges(),
)
if typed:
gg['type'] = 'user'
return gg
@deprecated('updates may be not atomic. DO NOT USE until further notice')
async def reward(self, points=1):
@ -482,6 +487,21 @@ class Guild(Base):
async with db as session:
session.execute(update(Member).where(Member.user_id == u.id, Member.guild_id == self.id).values(**values))
return m
def simple_info(self, *, typed=False):
"""
Return essential informations for representing a guild in the REST
"""
## XXX change func name?
gg = dict(
id = Snowflake(self.id).to_b32l(),
name = self.name,
display_name = self.display_name,
badges = []
)
if typed:
gg['type'] = 'guild'
return gg
Topic = deprecated('renamed to Guild')(Guild)
@ -621,6 +641,8 @@ class Post(Base):
return or_(Post.author_id == user_id, Post.privacy == 0)
#return or_(Post.author_id == user_id, and_(Post.privacy.in_((0, 1)), ~Post.author.has_blocked_q(user_id)))
def is_text_post(self):
return self.post_type == POST_TYPE_DEFAULT
class Comment(Base):
__tablename__ = 'freak_comment'
@ -716,3 +738,4 @@ class UserStrike(Base):
# PostUpvote table is at the top !!

View file

@ -2,15 +2,20 @@
from __future__ import annotations
from flask import abort
from pydantic import BaseModel
from quart import Blueprint, redirect, url_for
from quart_auth import current_user, login_required
from quart_auth import AuthUser, current_user, login_required, login_user, logout_user
from quart_schema import QuartSchema, validate_request, validate_response
from sqlalchemy import select
from suou import Snowflake, deprecated, not_implemented, want_isodate
from suou import Snowflake, deprecated, makelist, not_implemented, want_isodate
from werkzeug.security import check_password_hash
from suou.quart import add_rest
from ..models import Post, User, db
from freak.accounts import LoginStatus, check_login
from freak.algorithms import topic_timeline
from ..models import Guild, Post, User, db
from .. import UserLoader, app, app_config, __version__ as freak_version
bp = Blueprint('rest', __name__, url_prefix='/v1')
@ -34,7 +39,9 @@ async def get_nurupo():
async def health():
return dict(
version=freak_version,
name = app_config.app_name
name = app_config.app_name,
post_count = await Post.count(),
user_count = await User.active_count()
)
## TODO coverage of REST is still partial, but it's planned
@ -88,8 +95,85 @@ async def get_post(id: int):
id = f'{Snowflake(p.id):l}',
title = p.title,
author = p.author.simple_info(),
to = p.topic_or_user().handle(),
to = p.topic_or_user().simple_info(typed=True),
created_at = p.created_at.isoformat('T')
)
if p.is_text_post():
pj['content'] = p.text_content
return dict(posts={f'{Snowflake(id):l}': pj})
async def _guild_info(gu: Guild):
return dict(
id = f'{Snowflake(gu.id):l}',
name = gu.name,
display_name = gu.display_name,
description = gu.description,
created_at = want_isodate(gu.created_at),
badges = []
)
@bp.get('/guild/@<gname>')
async def guild_info_only(gname: str):
async with db as session:
gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar()
if gu is None:
return dict(error='Not found'), 404
gj = await _guild_info(gu)
return dict(guilds={f'{Snowflake(gu.id):l}': gj})
@bp.get('/guild/@<gname>/feed')
async def guild_feed(gname: str):
async with db as session:
gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar()
if gu is None:
return dict(error='Not found'), 404
gj = await _guild_info(gu)
# TODO add feed
feed = []
algo = topic_timeline(gname)
posts: list[Post] = makelist((await db.paginate(algo)))
for p in posts:
feed.append(p.feed_info())
return dict(guilds={f'{Snowflake(gu.id):l}': gj}, feed=feed)
class LoginIn(BaseModel):
username: str
password: str
remember: bool
@bp.post('/login')
@validate_request(LoginIn)
async def login(data: LoginIn):
async with db as session:
u = (await session.execute(select(User).where(User.username == data.username))).scalar()
match check_login(u, data.password):
case LoginStatus.SUCCESS:
remember_for = int(data.remember)
if remember_for > 0:
login_user(UserLoader(u.get_id()), remember=True)
else:
login_user(UserLoader(u.get_id()))
return {'id': f'{Snowflake(u.id):l}'}, 204
case LoginStatus.ERROR:
abort(404, 'Invalid username or password')
case LoginStatus.SUSPENDED:
abort(403, 'Your account is suspended')
case LoginStatus.PASS_EXPIRED:
abort(403, 'You need to reset your password following the procedure.')
@bp.post('/logout')
@login_required
async def logout():
logout_user()
return {}, 204

View file

@ -24,25 +24,7 @@ logger = logging.getLogger(__name__)
bp = Blueprint('accounts', __name__)
class LoginStatus(enum.Enum):
SUCCESS = 0
ERROR = 1
SUSPENDED = 2
PASS_EXPIRED = 3
def check_login(user: User | None, password: str) -> LoginStatus:
try:
if user is None:
return LoginStatus.ERROR
if ('$' not in user.passhash) and user.email:
return LoginStatus.PASS_EXPIRED
if not user.is_active:
return LoginStatus.SUSPENDED
if user.check_password(password):
return LoginStatus.SUCCESS
except Exception as e:
logger.error(f'{e}')
return LoginStatus.ERROR
from ..accounts import LoginStatus, check_login
@bp.get('/login')