From eb8371757dcd87b1d3675963726352d16f274323 Mon Sep 17 00:00:00 2001 From: Yusur Princeps Date: Thu, 4 Sep 2025 01:25:25 +0200 Subject: [PATCH] add ArgConfigSource(), 3 helpers to .sqlalchemy, add .waiter --- CHANGELOG.md | 5 +- src/suou/configparse.py | 27 ++++++++++- src/suou/flask_sqlalchemy.py | 5 +- src/suou/sqlalchemy/__init__.py | 6 ++- src/suou/sqlalchemy/orm.py | 81 +++++++++++++++++++++++++++++++++ src/suou/waiter.py | 57 +++++++++++++++++++++++ 6 files changed, 177 insertions(+), 4 deletions(-) create mode 100644 src/suou/waiter.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fffb0ea..a31ce90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,10 @@ ## 0.6.0 -... ++ `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports. ++ Add `.waiter` module. For now, non-functional. ++ Add those new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`. Also removed dead batteries. ++ Add `ArgConfigSource` to `.configparse` ## 0.5.3 diff --git a/src/suou/configparse.py b/src/suou/configparse.py index 9709b0a..8687cb4 100644 --- a/src/suou/configparse.py +++ b/src/suou/configparse.py @@ -23,6 +23,8 @@ import os from typing import Any, Callable, Iterator, override from collections import OrderedDict +from argparse import Namespace + from .classtools import ValueSource, ValueProperty from .functools import deprecated from .exceptions import MissingConfigError, MissingConfigWarning @@ -105,6 +107,28 @@ class DictConfigSource(ConfigSource): def __len__(self) -> int: return len(self._d) +class ArgConfigSource(ValueSource): + """ + It assumes arguments have already been parsed + + NEW 0.6""" + _ns: Namespace + def __init__(self, ns: Namespace): + super().__init__() + self._ns = ns + def __getitem__(self, key): + return getattr(self._ns, key) + def get(self, key, value): + return getattr(self._ns, key, value) + def __contains__(self, key: str, /) -> bool: + return hasattr(self._ns, key) + @deprecated('Here for Mapping() implementation. Untested and unused') + def __iter__(self) -> Iterator[str]: + yield from self._ns._get_args() + @deprecated('Here for Mapping() implementation. Untested and unused') + def __len__(self) -> int: + return len(self._ns._get_args()) + class ConfigValue(ValueProperty): """ A single config property. @@ -205,7 +229,8 @@ class ConfigOptions: __all__ = ( - 'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'EnvConfigSource', 'ConfigParserConfigSource', 'DictConfigSource', 'ConfigSource', 'ConfigValue' + 'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'EnvConfigSource', 'ConfigParserConfigSource', 'DictConfigSource', 'ConfigSource', 'ConfigValue', + 'ArgConfigSource' ) diff --git a/src/suou/flask_sqlalchemy.py b/src/suou/flask_sqlalchemy.py index 0704460..94afc6f 100644 --- a/src/suou/flask_sqlalchemy.py +++ b/src/suou/flask_sqlalchemy.py @@ -20,10 +20,12 @@ from typing import Any, Callable, Never from flask import abort, request from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import DeclarativeBase, Session +from .functools import deprecated from .codecs import want_bytes from .sqlalchemy import AuthSrc, require_auth_base +@deprecated('inherits from deprecated and unused class') class FlaskAuthSrc(AuthSrc): ''' @@ -45,6 +47,7 @@ class FlaskAuthSrc(AuthSrc): def required_exc(self): abort(401, 'Login required') +@deprecated('not intuitive to use') def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Callable]: """ Make an auth_required() decorator for Flask views. @@ -77,4 +80,4 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Ca return auth_required # Optional dependency: do not import into __init__.py -__all__ = ('require_auth', ) +__all__ = () diff --git a/src/suou/sqlalchemy/__init__.py b/src/suou/sqlalchemy/__init__.py index 81f61f8..2603d0b 100644 --- a/src/suou/sqlalchemy/__init__.py +++ b/src/suou/sqlalchemy/__init__.py @@ -157,13 +157,17 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | from .asyncio import SQLAlchemy, AsyncSelectPagination, async_query -from .orm import id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, author_pair, age_pair, bound_fk, unbound_fk, want_column +from .orm import ( + id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, + author_pair, age_pair, bound_fk, unbound_fk, want_column, a_relationship, BitSelector, secret_column +) # Optional dependency: do not import into __init__.py __all__ = ( 'IdType', 'id_column', 'snowflake_column', 'entity_base', 'declarative_base', 'token_signer', 'match_column', 'match_constraint', 'bool_column', 'parent_children', 'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column', + 'a_relationship', 'BitSelector', 'secret_column', # .asyncio 'SQLAlchemy', 'AsyncSelectPagination', 'async_query' ) \ No newline at end of file diff --git a/src/suou/sqlalchemy/orm.py b/src/suou/sqlalchemy/orm.py index 6d75a4f..063bcef 100644 --- a/src/suou/sqlalchemy/orm.py +++ b/src/suou/sqlalchemy/orm.py @@ -19,6 +19,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from binascii import Incomplete +import os from typing import Any, Callable import warnings from sqlalchemy import BigInteger, Boolean, CheckConstraint, Column, Date, ForeignKey, LargeBinary, MetaData, SmallInteger, String, text @@ -167,6 +168,16 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]: return (date_col, acc_col) +def secret_column(length: int = 64, max_length: int | None = None, gen: Callable[[int], bytes] = os.urandom, nullable=False, **kwargs): + """ + Column filled in by default with random bits (64 by default). Useful for secrets. + + NEW 0.6.0 + """ + max_length = max_length or length + return Column(LargeBinary(max_length), default=lambda: gen(length), nullable=nullable, **kwargs) + + def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]: """ @@ -191,6 +202,17 @@ def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Inco return parent, child +def a_relationship(primary = None, /, j=None, *, lazy='selectin', **kwargs): + """ + Shorthand for relationship() that sets lazy='selectin' automatically. + + NEW 0.6.0 + """ + if j: + kwargs['primaryjoin'] = j + return relationship(primary, lazy=lazy, **kwargs) # pyright: ignore[reportArgumentType] + + def unbound_fk(target: str | Column | InstrumentedAttribute, typ: _T | None = None, **kwargs) -> Column[_T | IdType]: """ Shorthand for creating a "unbound" foreign key column from a column name, the referenced column. @@ -232,3 +254,62 @@ def bound_fk(target: str | Column | InstrumentedAttribute, typ: _T = None, **kwa return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs) + +class _BitComparator(Comparator): + """ + Comparator object for BitSelector() + + NEW 0.6.0 + """ + _column: Column + _flag: int + def __init__(self, col, flag): + self._column = col + self._flag = flag + def _bulk_update_tuples(self, value): + return [ (self._column, self._upd_exp(value)) ] + def operate(self, op, other, **kwargs): + return op(self._sel_exp(), self._flag if other else 0, **kwargs) + def __clause_element__(self): + return self._column + def __str__(self): + return self._column + def _sel_exp(self): + return self._column.op('&')(self._flag) + def _upd_exp(self, value): + return self._column.op('|')(self._flag) if value else self._column.op('&')(~self._flag) + +class BitSelector: + """ + "Virtual" column representing a single bit in an integer column (usually a BigInteger). + + Mimicks peewee's 'BitField()' behavior, with SQLAlchemy. + + NEW 0.6.0 + """ + _column: Column + _flag: int + _name: str + def __init__(self, column, flag: int): + if bin(flag := int(flag))[2:].rstrip('0') != '1': + warnings.warn('using non-powers of 2 as flags may cause errors or undefined behavior', FutureWarning) + self._column = column + self._flag = flag + def __set_name__(self, name, owner=None): + self._name = name + def __get__(self, obj, objtype=None): + if obj: + return getattr(obj, self._column.name) & self._flag > 0 + else: + return _BitComparator(self._column, self._flag) + def __set__(self, obj, val): + if obj: + orig = getattr(obj, self._column.name) + if val: + orig |= self._flag + else: + orig &= ~(self._flag) + setattr(obj, self._column.name, orig) + else: + raise NotImplementedError + diff --git a/src/suou/waiter.py b/src/suou/waiter.py new file mode 100644 index 0000000..74d2d0e --- /dev/null +++ b/src/suou/waiter.py @@ -0,0 +1,57 @@ +""" +Content serving API over HTTP, based on Starlette. + +NEW 0.6.0 + +--- + +Copyright (c) 2025 Sakuragasaki46. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +See LICENSE for the specific language governing permissions and +limitations under the License. + +This software is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +""" + +import warnings +from starlette.applications import Starlette +from starlette.responses import JSONResponse, PlainTextResponse, Response +from starlette.routing import Route + +class Waiter(): + def __init__(self): + self.routes: list[Route] = [] + self.production = False + + def _build_app(self) -> Starlette: + return Starlette( + debug = not self.production, + routes= self.routes + ) + + ## TODO get, post, etc. + +def ok(content = None, **ka): + if content is None: + return Response(status_code=204, **ka) + elif isinstance(content, dict): + return JSONResponse(content, **ka) + elif isinstance(content, str): + return PlainTextResponse(content, **ka) + return content + +def ko(status: int, /, content = None, **ka): + if status < 400 or status > 599: + warnings.warn(f'HTTP {status} is not an error status', UserWarning) + if content is None: + return Response(status_code=status, **ka) + elif isinstance(content, dict): + return JSONResponse(content, status_code=status, **ka) + elif isinstance(content, str): + return PlainTextResponse(content, status_code=status, **ka) + return content + +__all__ = ('ko', 'ok', 'Waiter') \ No newline at end of file