diff --git a/freak/models.py b/freak/models.py index 392f6ad..0cbc61a 100644 --- a/freak/models.py +++ b/freak/models.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from collections import namedtuple import datetime from functools import partial @@ -663,6 +664,17 @@ class Post(Base): created_at = self.created_at ) + async def feed_info_counts(self): + pj = self.feed_info() + if self.is_text_post(): + pj['content'] = self.text_content[:181] + (pj['comment_count'], pj['votes'], pj['my_vote']) = await asyncio.gather( + self.comment_count(), + self.upvotes(), + self.upvoted_by(current_user.user) + ) + return pj + class Comment(Base): __tablename__ = 'freak_comment' __table_args__ = ( diff --git a/freak/rest/__init__.py b/freak/rest/__init__.py index 1f66eb5..5656d2f 100644 --- a/freak/rest/__init__.py +++ b/freak/rest/__init__.py @@ -7,7 +7,7 @@ from quart import abort, Blueprint, redirect, request, url_for from pydantic import BaseModel from quart_auth import AuthUser, current_user, login_required, login_user, logout_user from quart_schema import QuartSchema, validate_request, validate_response -from sqlalchemy import select +from sqlalchemy import delete, insert, select from suou import Snowflake, deprecated, makelist, not_implemented, want_isodate from werkzeug.security import check_password_hash @@ -17,7 +17,7 @@ from freak.accounts import LoginStatus, check_login from freak.algorithms import public_timeline, top_guilds_query, topic_timeline, user_timeline from freak.search import SearchQuery -from ..models import Guild, Post, User, db +from ..models import Guild, Post, PostUpvote, User, db from .. import UserLoader, app, app_config, __version__ as freak_version, csrf bp = Blueprint('rest', __name__, url_prefix='/v1') @@ -116,7 +116,7 @@ async def user_feed_get(id: int): algo = user_timeline(u) posts = await db.paginate(algo) async for p in posts: - feed.append(p.feed_info()) + feed.append(await p.feed_info_counts()) return dict(users={f'{Snowflake(id):l}': uj}, feed=feed) @@ -155,8 +155,45 @@ async def get_post(id: int): if p.is_text_post(): pj['content'] = p.text_content + pj['comment_count'] = await p.comment_count() + pj['votes'] = await p.upvotes() + pj['my_vote'] = await p.upvoted_by(current_user.user) + return dict(posts={f'{Snowflake(id):l}': pj}) +class VoteIn(BaseModel): + vote: int + +@bp.post('/post//upvote') +@validate_request(VoteIn) +async def upvote_post(id: int, data: VoteIn): + async with db as session: + p: Post | None = (await session.execute(select(Post).where(Post.id == id))).scalar() + + if p is None: + return { 'status': 404, 'error': 'Post not found' }, 404 + + cur_score = await p.upvoted_by(current_user.user) + + match (data.vote, cur_score): + case (1, 0) | (1, -1): + await session.execute(delete(PostUpvote).where(PostUpvote.c.post_id == p.id, PostUpvote.c.voter_id == current_user.id, PostUpvote.c.is_downvote == True)) + await session.execute(insert(PostUpvote).values(post_id = p.id, voter_id = current_user.id, is_downvote = False)) + case (0, _): + await session.execute(delete(PostUpvote).where(PostUpvote.c.post_id == p.id, PostUpvote.c.voter_id == current_user.id)) + case (-1, 1) | (-1, 0): + await session.execute(delete(PostUpvote).where(PostUpvote.c.post_id == p.id, PostUpvote.c.voter_id == current_user.id, PostUpvote.c.is_downvote == False)) + await session.execute(insert(PostUpvote).values(post_id = p.id, voter_id = current_user.id, is_downvote = True)) + case (1, 1) | (1, -1): + pass + case _: + await session.rollback() + return { 'status': 400, 'error': 'Invalid score' }, 400 + + await session.commit() + return { 'votes': await p.upvotes() } + + ## GUILDS ## async def _guild_info(gu: Guild): @@ -204,7 +241,7 @@ async def guild_feed(gname: str): algo = topic_timeline(gname) posts = await db.paginate(algo) async for p in posts: - feed.append(p.feed_info()) + feed.append(await p.feed_info_counts()) return dict(guilds={f'{Snowflake(gu.id):l}': gj}, feed=feed) @@ -253,7 +290,7 @@ async def home_feed(): posts = await db.paginate(public_timeline()) feed = [] async for post in posts: - feed.append(post.feed_info()) + feed.append(await post.feed_info_counts()) return dict(feed=feed)