add register handlers
This commit is contained in:
parent
77316d783e
commit
6f02df98dd
7 changed files with 130 additions and 24 deletions
|
|
@ -26,7 +26,7 @@ from suou import twocolon_list, WantsContentType
|
||||||
|
|
||||||
from .colors import color_themes, theme_classes
|
from .colors import color_themes, theme_classes
|
||||||
|
|
||||||
__version__ = '0.5.0-dev68'
|
__version__ = '0.5.0-dev76'
|
||||||
|
|
||||||
APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
|
APP_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,19 @@
|
||||||
|
|
||||||
|
|
||||||
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import enum
|
import enum
|
||||||
|
import re
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
from suou import age_and_days
|
||||||
from suou.sqlalchemy.asyncio import AsyncSession
|
from suou.sqlalchemy.asyncio import AsyncSession
|
||||||
from .models import User, db
|
from werkzeug.security import generate_password_hash
|
||||||
|
from .models import REPORT_REASONS, User, db
|
||||||
from quart_auth import AuthUser, Action as _Action
|
from quart_auth import AuthUser, Action as _Action
|
||||||
|
from quart_wtf.utils import validate_csrf
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -31,6 +37,60 @@ def check_login(user: User | None, password: str) -> LoginStatus:
|
||||||
logger.error(f'{e}')
|
logger.error(f'{e}')
|
||||||
return LoginStatus.ERROR
|
return LoginStatus.ERROR
|
||||||
|
|
||||||
|
class RegisterIn(BaseModel):
|
||||||
|
username: str
|
||||||
|
display_name: str = ""
|
||||||
|
password: str
|
||||||
|
confirm_password: str
|
||||||
|
email: str | None = None
|
||||||
|
birthday: str
|
||||||
|
invite_code: str | None = None
|
||||||
|
|
||||||
|
class RegisterStatus(enum.Enum):
|
||||||
|
SUCCESS = 0
|
||||||
|
ERROR = 1
|
||||||
|
USERNAME_TAKEN = 2
|
||||||
|
IP_BANNED = 3
|
||||||
|
USERNAME_INVALID = 4
|
||||||
|
PASSWORD_INVALID = 5
|
||||||
|
DATE_INVALID = 6
|
||||||
|
|
||||||
|
async def validate_register(data: RegisterIn) -> RegisterStatus | dict:
|
||||||
|
f = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
birthday = datetime.date.fromisoformat(data.birthday)
|
||||||
|
birthday_age = age_and_days(birthday)
|
||||||
|
|
||||||
|
if birthday_age == (0, 0):
|
||||||
|
return RegisterStatus.DATE_INVALID
|
||||||
|
if birthday_age < (14,):
|
||||||
|
f['banned_at'] = datetime.datetime.now()
|
||||||
|
f['banned_reason'] = REPORT_REASONS['underage']
|
||||||
|
except ValueError:
|
||||||
|
return RegisterStatus.DATE_INVALID
|
||||||
|
|
||||||
|
f['username'] = data.username.lower()
|
||||||
|
if not re.fullmatch('[a-z0-9_-]+', f['username']):
|
||||||
|
return RegisterStatus.USERNAME_INVALID
|
||||||
|
f['display_name'] = data.display_name
|
||||||
|
|
||||||
|
if not data.password or data.password != data.confirm_password:
|
||||||
|
return RegisterStatus.PASSWORD_INVALID
|
||||||
|
f['passhash'] = generate_password_hash(data.password)
|
||||||
|
|
||||||
|
f['email'] = data.email
|
||||||
|
|
||||||
|
async with db as session:
|
||||||
|
# TODO check ip ban
|
||||||
|
# TODO implement IpBan table
|
||||||
|
|
||||||
|
# TODO check invite code [will be implemented in 0.6]
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
class UserLoader(AuthUser):
|
class UserLoader(AuthUser):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -10,16 +10,19 @@ from __future__ import annotations
|
||||||
import re
|
import re
|
||||||
from quart import Blueprint, abort, flash, redirect, request
|
from quart import Blueprint, abort, flash, redirect, request
|
||||||
from sqlalchemy import delete, insert, select
|
from sqlalchemy import delete, insert, select
|
||||||
|
|
||||||
from freak import UserLoader
|
|
||||||
from freak.utils import get_request_form
|
|
||||||
from .models import Guild, Member, UserBlock, db, User, Post, PostUpvote, username_is_legal
|
|
||||||
from quart_auth import current_user, login_required
|
from quart_auth import current_user, login_required
|
||||||
|
from suou import deprecated
|
||||||
|
|
||||||
|
from . import UserLoader
|
||||||
|
from .utils import get_request_form
|
||||||
|
from .models import Guild, Member, UserBlock, db, User, Post, PostUpvote, username_is_legal
|
||||||
|
|
||||||
|
|
||||||
current_user: UserLoader
|
current_user: UserLoader
|
||||||
|
|
||||||
bp = Blueprint('ajax', __name__)
|
bp = Blueprint('ajax', __name__)
|
||||||
|
|
||||||
|
@deprecated("please use /v1/username/@<username>")
|
||||||
@bp.route('/username_availability/<username>')
|
@bp.route('/username_availability/<username>')
|
||||||
@bp.route('/ajax/username_availability/<username>')
|
@bp.route('/ajax/username_availability/<username>')
|
||||||
async def username_availability(username: str):
|
async def username_availability(username: str):
|
||||||
|
|
|
||||||
|
|
@ -384,6 +384,8 @@ class User(Base):
|
||||||
user = (await session.execute(user_q)).scalar()
|
user = (await session.execute(user_q)).scalar()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
# TODO add table UserInvite [planned for 0.6]
|
||||||
|
|
||||||
# UserBlock table is at the top !!
|
# UserBlock table is at the top !!
|
||||||
|
|
||||||
## END User
|
## END User
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import datetime
|
import datetime
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Iterable, TypeVar
|
from typing import Iterable, TypeVar
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -13,10 +14,10 @@ from quart_auth import current_user, login_required, login_user, logout_user
|
||||||
from quart_schema import validate_request
|
from quart_schema import validate_request
|
||||||
from quart_wtf.csrf import generate_csrf
|
from quart_wtf.csrf import generate_csrf
|
||||||
from sqlalchemy import delete, insert, select, __version__ as sa_version
|
from sqlalchemy import delete, insert, select, __version__ as sa_version
|
||||||
from suou import Snowflake, deprecated, makelist, not_implemented, want_isodate
|
from suou import Snowflake, age_and_days, deprecated, makelist, not_implemented, want_isodate
|
||||||
|
|
||||||
from suou.classtools import MISSING, MissingType
|
from suou.classtools import MISSING, MissingType
|
||||||
from werkzeug.security import check_password_hash
|
from werkzeug.security import check_password_hash, generate_password_hash
|
||||||
from suou.quart import add_rest
|
from suou.quart import add_rest
|
||||||
|
|
||||||
# quart does not define __version__
|
# quart does not define __version__
|
||||||
|
|
@ -26,7 +27,7 @@ from freak.accounts import LoginStatus, check_login
|
||||||
from freak.algorithms import private_timeline, public_timeline, top_guilds_query, topic_timeline, user_timeline
|
from freak.algorithms import private_timeline, public_timeline, top_guilds_query, topic_timeline, user_timeline
|
||||||
from freak.search import SearchQuery
|
from freak.search import SearchQuery
|
||||||
|
|
||||||
from ..models import Comment, Guild, Post, PostUpvote, User, db
|
from ..models import REPORT_REASONS, Comment, Guild, Post, PostUpvote, User, db, username_is_legal
|
||||||
from .. import UserLoader, app, app_config, __version__ as freak_version, csrf
|
from .. import UserLoader, app, app_config, __version__ as freak_version, csrf
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -35,7 +36,7 @@ _T = TypeVar('_T')
|
||||||
bp = Blueprint('rest', __name__, url_prefix='/v1')
|
bp = Blueprint('rest', __name__, url_prefix='/v1')
|
||||||
rest = add_rest(app, '/v1', '/ajax')
|
rest = add_rest(app, '/v1', '/ajax')
|
||||||
|
|
||||||
## XXX potential security hole, but needed for REST to work
|
## XXX potential security hole, but somewhat needed for REST to work
|
||||||
csrf.exempt(bp)
|
csrf.exempt(bp)
|
||||||
|
|
||||||
current_user: UserLoader
|
current_user: UserLoader
|
||||||
|
|
@ -57,7 +58,8 @@ async def health():
|
||||||
post_count = await Post.count(),
|
post_count = await Post.count(),
|
||||||
user_count = await User.active_count(),
|
user_count = await User.active_count(),
|
||||||
me = Snowflake(current_user.id).to_b32l() if current_user else None,
|
me = Snowflake(current_user.id).to_b32l() if current_user else None,
|
||||||
color_theme = current_user.color_theme if current_user else 0
|
color_theme = current_user.color_theme if current_user else 0,
|
||||||
|
invite_only = False # TODO implement invites!
|
||||||
)
|
)
|
||||||
|
|
||||||
return hi
|
return hi
|
||||||
|
|
@ -375,6 +377,50 @@ async def logout():
|
||||||
logout_user()
|
logout_user()
|
||||||
return '', 204
|
return '', 204
|
||||||
|
|
||||||
|
from ..accounts import RegisterIn, RegisterStatus, validate_register
|
||||||
|
|
||||||
|
|
||||||
|
@bp.post('/register')
|
||||||
|
@validate_request(RegisterIn)
|
||||||
|
async def register(data: RegisterIn):
|
||||||
|
# validate register form
|
||||||
|
match await validate_register(data):
|
||||||
|
case RegisterStatus.DATE_INVALID:
|
||||||
|
abort(400, "Invalid date format")
|
||||||
|
case RegisterStatus.USERNAME_INVALID:
|
||||||
|
abort(400, "Username can contain only letters, digits, underscores and dashes.")
|
||||||
|
case RegisterStatus.PASSWORD_INVALID:
|
||||||
|
abort(400, "Passwords do not match")
|
||||||
|
case RegisterStatus.USERNAME_TAKEN:
|
||||||
|
abort(409, "A user with this username already exists.")
|
||||||
|
case RegisterStatus.IP_BANNED:
|
||||||
|
abort(403, "Your IP address is banned.")
|
||||||
|
case user_data:
|
||||||
|
if not isinstance(user_data, dict):
|
||||||
|
abort(500)
|
||||||
|
async with db as session:
|
||||||
|
new_user_id: int = (await session.execute(insert(User).values(**user_data).returning(User.id))).scalar()
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
return dict(id=Snowflake(new_user_id).to_b32l()), 200
|
||||||
|
|
||||||
|
|
||||||
|
@bp.get('/username/@<username>')
|
||||||
|
async def username_availability(username):
|
||||||
|
is_valid = username_is_legal(username)
|
||||||
|
|
||||||
|
if is_valid:
|
||||||
|
async with db as session:
|
||||||
|
user = (await session.execute(select(User).where(User.username == username))).scalar()
|
||||||
|
|
||||||
|
is_available = user is None or user == current_user.user
|
||||||
|
else:
|
||||||
|
is_available = False
|
||||||
|
|
||||||
|
return {
|
||||||
|
"is_available": is_available,
|
||||||
|
"is_valid": is_valid
|
||||||
|
}
|
||||||
|
|
||||||
## HOME ##
|
## HOME ##
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -269,9 +269,10 @@ header.header {
|
||||||
border-radius: 0;
|
border-radius: 0;
|
||||||
border: 0;
|
border: 0;
|
||||||
border-bottom: 2px solid var(--border);
|
border-bottom: 2px solid var(--border);
|
||||||
background-color: inherit;
|
background-color: inherit;}
|
||||||
focus-background-color: var(--bg-sharp);
|
header.header .mini-search-bar [type="search"]:focus {
|
||||||
focus-border-color: var(--accent); }
|
background-color: var(--bg-sharp);
|
||||||
|
border-color: var(--accent); }
|
||||||
header.header .mini-search-bar [type="submit"] {
|
header.header .mini-search-bar [type="submit"] {
|
||||||
height: 0;
|
height: 0;
|
||||||
width: 0;
|
width: 0;
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,7 @@ def _check_ip_bans(ip) -> bool:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@deprecated('register accounts from the API instead')
|
||||||
async def validate_register_form() -> dict:
|
async def validate_register_form() -> dict:
|
||||||
form = await get_request_form()
|
form = await get_request_form()
|
||||||
f = dict()
|
f = dict()
|
||||||
|
|
@ -106,10 +107,10 @@ async def validate_register_form() -> dict:
|
||||||
|
|
||||||
f['email'] = form['email'] or None
|
f['email'] = form['email'] or None
|
||||||
|
|
||||||
is_ip_banned: bool = await _check_ip_bans()
|
# is_ip_banned: bool = await _check_ip_bans()
|
||||||
|
|
||||||
if is_ip_banned:
|
# if is_ip_banned:
|
||||||
raise ValueError('Your IP address is banned.')
|
# raise ValueError('Your IP address is banned.')
|
||||||
|
|
||||||
if _currently_logged_in() and not form.get('confirm_another'):
|
if _currently_logged_in() and not form.get('confirm_another'):
|
||||||
raise ValueError('You are already logged in. Please confirm you want to create another account by checking the option.')
|
raise ValueError('You are already logged in. Please confirm you want to create another account by checking the option.')
|
||||||
|
|
@ -119,13 +120,6 @@ async def validate_register_form() -> dict:
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
class RegisterStatus(enum.Enum):
|
|
||||||
SUCCESS = 0
|
|
||||||
ERROR = 1
|
|
||||||
USERNAME_TAKEN = 2
|
|
||||||
IP_BANNED = 3
|
|
||||||
|
|
||||||
|
|
||||||
@bp.post('/register')
|
@bp.post('/register')
|
||||||
async def register_post():
|
async def register_post():
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue