From 87d2eb6d0ba017fbb6a44599123a4a60b30063b8 Mon Sep 17 00:00:00 2001 From: Yusur Princeps Date: Sat, 6 Sep 2025 16:19:49 +0200 Subject: [PATCH] add more REST routes, split accounts.py (user loader) from main --- freak/__init__.py | 53 ++-------------------- freak/accounts.py | 81 +++++++++++++++++++++++++++++++++ freak/models.py | 39 ++++++++++++---- freak/rest/__init__.py | 94 ++++++++++++++++++++++++++++++++++++--- freak/website/accounts.py | 20 +-------- 5 files changed, 205 insertions(+), 82 deletions(-) create mode 100644 freak/accounts.py diff --git a/freak/__init__.py b/freak/__init__.py index 287fae0..8d8cc2c 100644 --- a/freak/__init__.py +++ b/freak/__init__.py @@ -26,7 +26,7 @@ from suou import twocolon_list, WantsContentType from .colors import color_themes, theme_classes -__version__ = '0.5.0-dev34' +__version__ = '0.5.0-dev35' APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__)) @@ -57,57 +57,10 @@ app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False app.config['QUART_AUTH_DURATION'] = 365 * 24 * 60 * 60 app.config['SERVER_NAME'] = app_config.server_name -class UserLoader(AuthUser): - """ - Loads user from the session. - - *WARNING* requires to be awaited before request before usage! - - Actual User object is at .user; other attributes are proxied. - """ - def __init__(self, auth_id: str | None, action: QA_Action= QA_Action.PASS): - self._auth_id = auth_id - self._auth_obj = None - self._auth_sess = None - self.action = action - - @property - def auth_id(self) -> str | None: - return self._auth_id - - @property - async def is_authenticated(self) -> bool: - await self._load() - return self._auth_id is not None - - async def _load(self): - if self._auth_obj is None and self._auth_id is not None: - session = self._auth_sess = await db.begin() - self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar() - if self._auth_obj is None: - raise RuntimeError('failed to fetch user') - - def __getattr__(self, key): - if self._auth_obj is None: - raise RuntimeError('user is not loaded') - return getattr(self._auth_obj, key) - - def __bool__(self): - return self._auth_obj is not None - - async def _unload(self): - # user is not expected to mutate - if self._auth_sess: - await self._auth_sess.rollback() - - @property - def user(self): - return self._auth_obj - - id: int ## DO NOT ADD LOCAL IMPORTS BEFORE THIS LINE +from .accounts import UserLoader from .models import Guild, db, User, Post # SASS @@ -190,7 +143,7 @@ async def _load_user(): g.no_user = True @app.after_request -async def _unload_request(resp): +async def _unload_user(resp): try: await current_user._unload() except RuntimeError as e: diff --git a/freak/accounts.py b/freak/accounts.py new file mode 100644 index 0000000..b42108b --- /dev/null +++ b/freak/accounts.py @@ -0,0 +1,81 @@ + + +import logging +import enum + +from sqlalchemy import select +from sqlalchemy.orm import selectinload +from .models import User, db +from quart_auth import AuthUser, Action as _Action + +logger = logging.getLogger(__name__) + +class LoginStatus(enum.Enum): + SUCCESS = 0 + ERROR = 1 + SUSPENDED = 2 + PASS_EXPIRED = 3 + +def check_login(user: User | None, password: str) -> LoginStatus: + try: + if user is None: + return LoginStatus.ERROR + if ('$' not in user.passhash) and user.email: + return LoginStatus.PASS_EXPIRED + if not user.is_active: + return LoginStatus.SUSPENDED + if user.check_password(password): + return LoginStatus.SUCCESS + except Exception as e: + logger.error(f'{e}') + return LoginStatus.ERROR + + +class UserLoader(AuthUser): + """ + Loads user from the session. + + *WARNING* requires to be awaited before request before usage! + + Actual User object is at .user; other attributes are proxied. + """ + def __init__(self, auth_id: str | None, action: _Action= _Action.PASS): + self._auth_id = auth_id + self._auth_obj = None + self._auth_sess = None + self.action = action + + @property + def auth_id(self) -> str | None: + return self._auth_id + + @property + async def is_authenticated(self) -> bool: + await self._load() + return self._auth_id is not None + + async def _load(self): + if self._auth_obj is None and self._auth_id is not None: + async with db as session: + self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar() + if self._auth_obj is None: + raise RuntimeError('failed to fetch user') + + def __getattr__(self, key): + if self._auth_obj is None: + raise RuntimeError('user is not loaded') + return getattr(self._auth_obj, key) + + def __bool__(self): + return self._auth_obj is not None + + async def _unload(self): + # user is not expected to mutate + if self._auth_sess: + await self._auth_sess.rollback() + + @property + def user(self): + return self._auth_obj + + id: int diff --git a/freak/models.py b/freak/models.py index 86e77db..f255e62 100644 --- a/freak/models.py +++ b/freak/models.py @@ -9,8 +9,8 @@ from operator import or_ import re from threading import Lock from typing import Any, Callable -from quart_auth import AuthUser, current_user -from sqlalchemy import Column, ExceptionContext, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \ +from quart_auth import current_user +from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, and_, insert, text, \ CheckConstraint, Date, DateTime, Boolean, func, BigInteger, \ SmallInteger, select, update, Table from sqlalchemy.orm import Relationship, relationship @@ -19,13 +19,11 @@ from suou import SiqType, Snowflake, Wanted, deprecated, makelist, not_implement from suou.sqlalchemy import create_session, declarative_base, id_column, parent_children, snowflake_column from werkzeug.security import check_password_hash -from . import UserLoader, app_config +from . import app_config from .utils import get_remote_addr from suou import timed_cache, age_and_days -current_user: UserLoader - import logging logger = logging.getLogger(__name__) @@ -120,6 +118,9 @@ db = SQLAlchemy(model_class=Base) CSI = create_session_interactively = partial(create_session, app_config.database_url) +## .accounts requires db +#current_user: UserLoader + ## Many-to-many relationship keys for some reasons have to go ## BEFORE other table definitions. @@ -221,18 +222,22 @@ class User(Base): def age(self): return age_and_days(self.gdpr_birthday)[0] - def simple_info(self): + def simple_info(self, *, typed = False): """ Return essential informations for representing a user in the REST """ ## XXX change func name? - return dict( + gg = dict( id = Snowflake(self.id).to_b32l(), username = self.username, display_name = self.display_name, age = self.age(), - badges = self.badges() + badges = self.badges(), + ) + if typed: + gg['type'] = 'user' + return gg @deprecated('updates may be not atomic. DO NOT USE until further notice') async def reward(self, points=1): @@ -482,6 +487,21 @@ class Guild(Base): async with db as session: session.execute(update(Member).where(Member.user_id == u.id, Member.guild_id == self.id).values(**values)) return m + + def simple_info(self, *, typed=False): + """ + Return essential informations for representing a guild in the REST + """ + ## XXX change func name? + gg = dict( + id = Snowflake(self.id).to_b32l(), + name = self.name, + display_name = self.display_name, + badges = [] + ) + if typed: + gg['type'] = 'guild' + return gg Topic = deprecated('renamed to Guild')(Guild) @@ -621,6 +641,8 @@ class Post(Base): return or_(Post.author_id == user_id, Post.privacy == 0) #return or_(Post.author_id == user_id, and_(Post.privacy.in_((0, 1)), ~Post.author.has_blocked_q(user_id))) + def is_text_post(self): + return self.post_type == POST_TYPE_DEFAULT class Comment(Base): __tablename__ = 'freak_comment' @@ -716,3 +738,4 @@ class UserStrike(Base): # PostUpvote table is at the top !! + diff --git a/freak/rest/__init__.py b/freak/rest/__init__.py index 91080ec..f1ce627 100644 --- a/freak/rest/__init__.py +++ b/freak/rest/__init__.py @@ -2,15 +2,20 @@ from __future__ import annotations from flask import abort +from pydantic import BaseModel from quart import Blueprint, redirect, url_for -from quart_auth import current_user, login_required +from quart_auth import AuthUser, current_user, login_required, login_user, logout_user from quart_schema import QuartSchema, validate_request, validate_response from sqlalchemy import select -from suou import Snowflake, deprecated, not_implemented, want_isodate +from suou import Snowflake, deprecated, makelist, not_implemented, want_isodate +from werkzeug.security import check_password_hash from suou.quart import add_rest -from ..models import Post, User, db +from freak.accounts import LoginStatus, check_login +from freak.algorithms import topic_timeline + +from ..models import Guild, Post, User, db from .. import UserLoader, app, app_config, __version__ as freak_version bp = Blueprint('rest', __name__, url_prefix='/v1') @@ -34,7 +39,9 @@ async def get_nurupo(): async def health(): return dict( version=freak_version, - name = app_config.app_name + name = app_config.app_name, + post_count = await Post.count(), + user_count = await User.active_count() ) ## TODO coverage of REST is still partial, but it's planned @@ -88,8 +95,85 @@ async def get_post(id: int): id = f'{Snowflake(p.id):l}', title = p.title, author = p.author.simple_info(), - to = p.topic_or_user().handle(), + to = p.topic_or_user().simple_info(typed=True), created_at = p.created_at.isoformat('T') ) + if p.is_text_post(): + pj['content'] = p.text_content + return dict(posts={f'{Snowflake(id):l}': pj}) + +async def _guild_info(gu: Guild): + return dict( + id = f'{Snowflake(gu.id):l}', + name = gu.name, + display_name = gu.display_name, + description = gu.description, + created_at = want_isodate(gu.created_at), + badges = [] + ) + +@bp.get('/guild/@') +async def guild_info_only(gname: str): + async with db as session: + gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar() + + if gu is None: + return dict(error='Not found'), 404 + gj = await _guild_info(gu) + + return dict(guilds={f'{Snowflake(gu.id):l}': gj}) + + +@bp.get('/guild/@/feed') +async def guild_feed(gname: str): + async with db as session: + gu: Guild | None = (await session.execute(select(Guild).where(Guild.name == gname))).scalar() + + if gu is None: + return dict(error='Not found'), 404 + gj = await _guild_info(gu) + + # TODO add feed + feed = [] + algo = topic_timeline(gname) + posts: list[Post] = makelist((await db.paginate(algo))) + for p in posts: + feed.append(p.feed_info()) + + return dict(guilds={f'{Snowflake(gu.id):l}': gj}, feed=feed) + + +class LoginIn(BaseModel): + username: str + password: str + remember: bool + +@bp.post('/login') +@validate_request(LoginIn) +async def login(data: LoginIn): + async with db as session: + u = (await session.execute(select(User).where(User.username == data.username))).scalar() + match check_login(u, data.password): + case LoginStatus.SUCCESS: + remember_for = int(data.remember) + if remember_for > 0: + login_user(UserLoader(u.get_id()), remember=True) + else: + login_user(UserLoader(u.get_id())) + return {'id': f'{Snowflake(u.id):l}'}, 204 + case LoginStatus.ERROR: + abort(404, 'Invalid username or password') + case LoginStatus.SUSPENDED: + abort(403, 'Your account is suspended') + case LoginStatus.PASS_EXPIRED: + abort(403, 'You need to reset your password following the procedure.') + + +@bp.post('/logout') +@login_required +async def logout(): + logout_user() + return {}, 204 + diff --git a/freak/website/accounts.py b/freak/website/accounts.py index fb452b4..db14eba 100644 --- a/freak/website/accounts.py +++ b/freak/website/accounts.py @@ -24,25 +24,7 @@ logger = logging.getLogger(__name__) bp = Blueprint('accounts', __name__) -class LoginStatus(enum.Enum): - SUCCESS = 0 - ERROR = 1 - SUSPENDED = 2 - PASS_EXPIRED = 3 - -def check_login(user: User | None, password: str) -> LoginStatus: - try: - if user is None: - return LoginStatus.ERROR - if ('$' not in user.passhash) and user.email: - return LoginStatus.PASS_EXPIRED - if not user.is_active: - return LoginStatus.SUSPENDED - if user.check_password(password): - return LoginStatus.SUCCESS - except Exception as e: - logger.error(f'{e}') - return LoginStatus.ERROR +from ..accounts import LoginStatus, check_login @bp.get('/login')