freak/freak/accounts.py

149 lines
4 KiB
Python

import datetime
import logging
import enum
import re
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from suou import age_and_days
from suou.sqlalchemy.asyncio import AsyncSession
from werkzeug.security import generate_password_hash
from .models import REPORT_REASONS, User, db
from quart_auth import AuthUser, Action as _Action
from quart_wtf.utils import validate_csrf
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class LoginStatus(enum.Enum):
    """Possible outcomes of a login attempt, as returned by check_login()."""

    # Credentials accepted.
    SUCCESS = 0
    # Generic failure (unknown user, wrong password, or an internal error).
    ERROR = 1
    # The account exists but is not active.
    SUSPENDED = 2
    # The stored password hash is no longer accepted.
    PASS_EXPIRED = 3
def check_login(user: User | None, password: str) -> LoginStatus:
    """
    Decide the outcome of a login attempt.

    Checks run in this order and the first match wins:

    1. ``user`` is None -> ``LoginStatus.ERROR``
    2. ``user.passhash`` contains no ``'$'`` and the user has an email
       -> ``LoginStatus.PASS_EXPIRED``
    3. ``user.is_active`` is falsy -> ``LoginStatus.SUSPENDED``
    4. ``user.check_password(password)`` is truthy -> ``LoginStatus.SUCCESS``

    Anything else, including any exception raised by the checks above,
    yields ``LoginStatus.ERROR``. This function never raises.
    """
    try:
        if user is None:
            return LoginStatus.ERROR
        # NOTE(review): presumably a hash without '$' is a legacy format that
        # can only be reset by email -- confirm against the User model.
        if ('$' not in user.passhash) and user.email:
            return LoginStatus.PASS_EXPIRED
        if not user.is_active:
            return LoginStatus.SUSPENDED
        if user.check_password(password):
            return LoginStatus.SUCCESS
    except Exception:
        # Deliberately broad: a failure while checking credentials must read
        # as a failed login, not a crash. logger.exception keeps the exception
        # type and traceback, which a bare str(e) would discard.
        logger.exception('login check failed')
    return LoginStatus.ERROR
class RegisterIn(BaseModel):
    # Payload of a registration request, consumed by validate_register().
    # (Kept as comments: a class docstring would end up in the model schema.)

    # Requested username; validate_register() lowercases it before checking.
    username: str
    # Optional display name; empty string when not supplied.
    display_name: str = ""
    # Password and its confirmation; validate_register() requires them to match.
    password: str
    confirm_password: str
    # Optional e-mail address.
    email: str | None = None
    # Date of birth as a string; validate_register() parses it as an ISO date.
    birthday: str
    # Optional invite code (not checked yet, see the TODOs in validate_register()).
    invite_code: str | None = None
class RegisterStatus(enum.Enum):
    """Possible outcomes of a registration attempt."""

    SUCCESS = 0
    # Generic failure.
    ERROR = 1
    USERNAME_TAKEN = 2
    IP_BANNED = 3
    # Username contains characters outside the allowed set.
    USERNAME_INVALID = 4
    # Password is empty or does not match its confirmation.
    PASSWORD_INVALID = 5
    # Birthday could not be parsed or was rejected.
    DATE_INVALID = 6
async def validate_register(data: RegisterIn) -> RegisterStatus | dict:
    """
    Validate a registration payload.

    Returns a ``RegisterStatus`` member describing the first failed check,
    or, when every check passes, a dict of column values for the new user
    (``username``, ``display_name``, ``passhash``, ``email``, plus
    ``banned_at`` / ``banned_reason`` when the birthday is below the age
    threshold).

    Checks, in order: birthday, username, password.
    """
    fields = {}

    # --- birthday -----------------------------------------------------------
    try:
        born = datetime.date.fromisoformat(data.birthday)
        # The result is compared against tuples below.
        # NOTE(review): presumably (years, days) -- confirm in suou.
        age = age_and_days(born)
        if age == (0, 0):
            return RegisterStatus.DATE_INVALID
        if age < (14,):
            # Under 14: the account data is still returned, but already banned.
            fields.update(
                banned_at=datetime.datetime.now(),
                banned_reason=REPORT_REASONS['underage'],
            )
    except ValueError:
        return RegisterStatus.DATE_INVALID

    # --- username -----------------------------------------------------------
    username = data.username.lower()
    fields['username'] = username
    if re.fullmatch('[a-z0-9_-]+', username) is None:
        return RegisterStatus.USERNAME_INVALID

    fields['display_name'] = data.display_name

    # --- password -----------------------------------------------------------
    # Must be non-empty and equal to its confirmation.
    if not (data.password and data.password == data.confirm_password):
        return RegisterStatus.PASSWORD_INVALID
    fields['passhash'] = generate_password_hash(data.password)

    fields['email'] = data.email

    async with db as session:
        # TODO check ip ban
        # TODO implement IpBan table
        # TODO check invite code [will be implemented in 0.6]
        pass

    return fields
class UserLoader(AuthUser):
    """
    Loads the user from the session.

    *WARNING* the user must be loaded (by awaiting ``_load()``, which
    ``await loader.is_authenticated`` does) before any proxied attribute is
    used; until then attribute access raises ``RuntimeError``.

    The actual User object is at ``.user``; other attributes are proxied to it
    through ``__getattr__``.
    """

    def __init__(self, auth_id: str | None, action: _Action = _Action.PASS):
        self._auth_id = auth_id
        # The loaded User row, or None until _load() has fetched it.
        self._auth_obj = None
        # NOTE(review): nothing visible in this class ever assigns a session
        # here, so .session stays None and _unload() is a no-op -- confirm.
        self._auth_sess: AsyncSession | None = None
        self.action = action

    @property
    def auth_id(self) -> str | None:
        """The raw auth id this loader was created with."""
        return self._auth_id

    @property
    async def is_authenticated(self) -> bool:
        """
        Awaitable property: load the user if needed, then report whether an
        auth id is present.

        Raises RuntimeError (from ``_load``) when an auth id is present but no
        matching user row is found.
        """
        await self._load()
        return self._auth_id is not None

    async def _load(self):
        """
        Fetch the User whose id equals ``int(auth_id)``, once.

        Does nothing when the user is already loaded or there is no auth id.
        Raises RuntimeError when the query returns no user.
        """
        if self._auth_obj is None and self._auth_id is not None:
            async with db as session:
                self._auth_obj = (await session.execute(select(User).where(User.id == int(self._auth_id)))).scalar()
                if self._auth_obj is None:
                    raise RuntimeError('failed to fetch user')

    def __getattr__(self, key):
        """
        Proxy unknown attributes to the loaded User.

        Raises RuntimeError when no user is loaded. Dunder lookups on an
        unloaded instance raise AttributeError instead, so protocol probes
        (``hasattr``, ``copy``, ``pickle``) see "attribute missing" rather
        than an error.
        """
        # __getattr__ only runs after normal lookup failed. Read the slot via
        # __dict__: on an instance whose __init__ has not run, a plain
        # `self._auth_obj` would re-enter __getattr__ and recurse without end.
        loaded = self.__dict__.get('_auth_obj')
        if loaded is None:
            if key.startswith('__') and key.endswith('__'):
                raise AttributeError(key)
            raise RuntimeError('user is not loaded')
        return getattr(loaded, key)

    def __bool__(self):
        """True once a User object has been loaded."""
        return self._auth_obj is not None

    @property
    def session(self):
        """The stored AsyncSession, or None."""
        return self._auth_sess

    async def _unload(self):
        """Roll back the stored session, if there is one."""
        # user is not expected to mutate
        if self._auth_sess:
            await self._auth_sess.rollback()

    @property
    def user(self):
        """The loaded User object, or None when not loaded."""
        return self._auth_obj

    # Annotations only (no values assigned): these names are resolved at
    # runtime through __getattr__ on the loaded User.
    id: int
    username: str
    display_name: str
    color_theme: int