add register handlers
This commit is contained in:
parent
77316d783e
commit
6f02df98dd
7 changed files with 130 additions and 24 deletions
|
|
@ -26,7 +26,7 @@ from suou import twocolon_list, WantsContentType
|
|||
|
||||
from .colors import color_themes, theme_classes
|
||||
|
||||
__version__ = '0.5.0-dev68'
|
||||
__version__ = '0.5.0-dev76'
|
||||
|
||||
APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import enum
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from suou import age_and_days
|
||||
from suou.sqlalchemy.asyncio import AsyncSession
|
||||
from .models import User, db
|
||||
from werkzeug.security import generate_password_hash
|
||||
from .models import REPORT_REASONS, User, db
|
||||
from quart_auth import AuthUser, Action as _Action
|
||||
from quart_wtf.utils import validate_csrf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,6 +37,60 @@ def check_login(user: User | None, password: str) -> LoginStatus:
|
|||
logger.error(f'{e}')
|
||||
return LoginStatus.ERROR
|
||||
|
||||
class RegisterIn(BaseModel):
|
||||
username: str
|
||||
display_name: str = ""
|
||||
password: str
|
||||
confirm_password: str
|
||||
email: str | None = None
|
||||
birthday: str
|
||||
invite_code: str | None = None
|
||||
|
||||
class RegisterStatus(enum.Enum):
|
||||
SUCCESS = 0
|
||||
ERROR = 1
|
||||
USERNAME_TAKEN = 2
|
||||
IP_BANNED = 3
|
||||
USERNAME_INVALID = 4
|
||||
PASSWORD_INVALID = 5
|
||||
DATE_INVALID = 6
|
||||
|
||||
async def validate_register(data: RegisterIn) -> RegisterStatus | dict:
|
||||
f = {}
|
||||
|
||||
try:
|
||||
birthday = datetime.date.fromisoformat(data.birthday)
|
||||
birthday_age = age_and_days(birthday)
|
||||
|
||||
if birthday_age == (0, 0):
|
||||
return RegisterStatus.DATE_INVALID
|
||||
if birthday_age < (14,):
|
||||
f['banned_at'] = datetime.datetime.now()
|
||||
f['banned_reason'] = REPORT_REASONS['underage']
|
||||
except ValueError:
|
||||
return RegisterStatus.DATE_INVALID
|
||||
|
||||
f['username'] = data.username.lower()
|
||||
if not re.fullmatch('[a-z0-9_-]+', f['username']):
|
||||
return RegisterStatus.USERNAME_INVALID
|
||||
f['display_name'] = data.display_name
|
||||
|
||||
if not data.password or data.password != data.confirm_password:
|
||||
return RegisterStatus.PASSWORD_INVALID
|
||||
f['passhash'] = generate_password_hash(data.password)
|
||||
|
||||
f['email'] = data.email
|
||||
|
||||
async with db as session:
|
||||
# TODO check ip ban
|
||||
# TODO implement IpBan table
|
||||
|
||||
# TODO check invite code [will be implemented in 0.6]
|
||||
|
||||
pass
|
||||
|
||||
return f
|
||||
|
||||
|
||||
class UserLoader(AuthUser):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,16 +10,19 @@ from __future__ import annotations
|
|||
import re
|
||||
from quart import Blueprint, abort, flash, redirect, request
|
||||
from sqlalchemy import delete, insert, select
|
||||
|
||||
from freak import UserLoader
|
||||
from freak.utils import get_request_form
|
||||
from .models import Guild, Member, UserBlock, db, User, Post, PostUpvote, username_is_legal
|
||||
from quart_auth import current_user, login_required
|
||||
from suou import deprecated
|
||||
|
||||
from . import UserLoader
|
||||
from .utils import get_request_form
|
||||
from .models import Guild, Member, UserBlock, db, User, Post, PostUpvote, username_is_legal
|
||||
|
||||
|
||||
current_user: UserLoader
|
||||
|
||||
bp = Blueprint('ajax', __name__)
|
||||
|
||||
@deprecated("please use /v1/username/@<username>")
|
||||
@bp.route('/username_availability/<username>')
|
||||
@bp.route('/ajax/username_availability/<username>')
|
||||
async def username_availability(username: str):
|
||||
|
|
|
|||
|
|
@ -384,6 +384,8 @@ class User(Base):
|
|||
user = (await session.execute(user_q)).scalar()
|
||||
return user
|
||||
|
||||
# TODO add table UserInvite [planned for 0.6]
|
||||
|
||||
# UserBlock table is at the top !!
|
||||
|
||||
## END User
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
import datetime
|
||||
import re
|
||||
import sys
|
||||
from typing import Iterable, TypeVar
|
||||
import logging
|
||||
|
|
@ -13,10 +14,10 @@ from quart_auth import current_user, login_required, login_user, logout_user
|
|||
from quart_schema import validate_request
|
||||
from quart_wtf.csrf import generate_csrf
|
||||
from sqlalchemy import delete, insert, select, __version__ as sa_version
|
||||
from suou import Snowflake, deprecated, makelist, not_implemented, want_isodate
|
||||
from suou import Snowflake, age_and_days, deprecated, makelist, not_implemented, want_isodate
|
||||
|
||||
from suou.classtools import MISSING, MissingType
|
||||
from werkzeug.security import check_password_hash
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
from suou.quart import add_rest
|
||||
|
||||
# quart does not define __version__
|
||||
|
|
@ -26,7 +27,7 @@ from freak.accounts import LoginStatus, check_login
|
|||
from freak.algorithms import private_timeline, public_timeline, top_guilds_query, topic_timeline, user_timeline
|
||||
from freak.search import SearchQuery
|
||||
|
||||
from ..models import Comment, Guild, Post, PostUpvote, User, db
|
||||
from ..models import REPORT_REASONS, Comment, Guild, Post, PostUpvote, User, db, username_is_legal
|
||||
from .. import UserLoader, app, app_config, __version__ as freak_version, csrf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -35,7 +36,7 @@ _T = TypeVar('_T')
|
|||
bp = Blueprint('rest', __name__, url_prefix='/v1')
|
||||
rest = add_rest(app, '/v1', '/ajax')
|
||||
|
||||
## XXX potential security hole, but needed for REST to work
|
||||
## XXX potential security hole, but somewhat needed for REST to work
|
||||
csrf.exempt(bp)
|
||||
|
||||
current_user: UserLoader
|
||||
|
|
@ -57,7 +58,8 @@ async def health():
|
|||
post_count = await Post.count(),
|
||||
user_count = await User.active_count(),
|
||||
me = Snowflake(current_user.id).to_b32l() if current_user else None,
|
||||
color_theme = current_user.color_theme if current_user else 0
|
||||
color_theme = current_user.color_theme if current_user else 0,
|
||||
invite_only = False # TODO implement invites!
|
||||
)
|
||||
|
||||
return hi
|
||||
|
|
@ -375,6 +377,50 @@ async def logout():
|
|||
logout_user()
|
||||
return '', 204
|
||||
|
||||
from ..accounts import RegisterIn, RegisterStatus, validate_register
|
||||
|
||||
|
||||
@bp.post('/register')
|
||||
@validate_request(RegisterIn)
|
||||
async def register(data: RegisterIn):
|
||||
# validate register form
|
||||
match await validate_register(data):
|
||||
case RegisterStatus.DATE_INVALID:
|
||||
abort(400, "Invalid date format")
|
||||
case RegisterStatus.USERNAME_INVALID:
|
||||
abort(400, "Username can contain only letters, digits, underscores and dashes.")
|
||||
case RegisterStatus.PASSWORD_INVALID:
|
||||
abort(400, "Passwords do not match")
|
||||
case RegisterStatus.USERNAME_TAKEN:
|
||||
abort(409, "A user with this username already exists.")
|
||||
case RegisterStatus.IP_BANNED:
|
||||
abort(403, "Your IP address is banned.")
|
||||
case user_data:
|
||||
if not isinstance(user_data, dict):
|
||||
abort(500)
|
||||
async with db as session:
|
||||
new_user_id: int = (await session.execute(insert(User).values(**user_data).returning(User.id))).scalar()
|
||||
|
||||
session.commit()
|
||||
return dict(id=Snowflake(new_user_id).to_b32l()), 200
|
||||
|
||||
|
||||
@bp.get('/username/@<username>')
|
||||
async def username_availability(username):
|
||||
is_valid = username_is_legal(username)
|
||||
|
||||
if is_valid:
|
||||
async with db as session:
|
||||
user = (await session.execute(select(User).where(User.username == username))).scalar()
|
||||
|
||||
is_available = user is None or user == current_user.user
|
||||
else:
|
||||
is_available = False
|
||||
|
||||
return {
|
||||
"is_available": is_available,
|
||||
"is_valid": is_valid
|
||||
}
|
||||
|
||||
## HOME ##
|
||||
|
||||
|
|
|
|||
|
|
@ -269,9 +269,10 @@ header.header {
|
|||
border-radius: 0;
|
||||
border: 0;
|
||||
border-bottom: 2px solid var(--border);
|
||||
background-color: inherit;
|
||||
focus-background-color: var(--bg-sharp);
|
||||
focus-border-color: var(--accent); }
|
||||
background-color: inherit;}
|
||||
header.header .mini-search-bar [type="search"]:focus {
|
||||
background-color: var(--bg-sharp);
|
||||
border-color: var(--accent); }
|
||||
header.header .mini-search-bar [type="submit"] {
|
||||
height: 0;
|
||||
width: 0;
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ def _check_ip_bans(ip) -> bool:
|
|||
return True
|
||||
return False
|
||||
|
||||
@deprecated('register accounts from the API instead')
|
||||
async def validate_register_form() -> dict:
|
||||
form = await get_request_form()
|
||||
f = dict()
|
||||
|
|
@ -106,10 +107,10 @@ async def validate_register_form() -> dict:
|
|||
|
||||
f['email'] = form['email'] or None
|
||||
|
||||
is_ip_banned: bool = await _check_ip_bans()
|
||||
# is_ip_banned: bool = await _check_ip_bans()
|
||||
|
||||
if is_ip_banned:
|
||||
raise ValueError('Your IP address is banned.')
|
||||
# if is_ip_banned:
|
||||
# raise ValueError('Your IP address is banned.')
|
||||
|
||||
if _currently_logged_in() and not form.get('confirm_another'):
|
||||
raise ValueError('You are already logged in. Please confirm you want to create another account by checking the option.')
|
||||
|
|
@ -117,13 +118,6 @@ async def validate_register_form() -> dict:
|
|||
raise ValueError('You must accept Terms in order to create an account.')
|
||||
|
||||
return f
|
||||
|
||||
|
||||
class RegisterStatus(enum.Enum):
|
||||
SUCCESS = 0
|
||||
ERROR = 1
|
||||
USERNAME_TAKEN = 2
|
||||
IP_BANNED = 3
|
||||
|
||||
|
||||
@bp.post('/register')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue