diff --git a/CHANGELOG.md b/CHANGELOG.md index 3536f9b..eab0c55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,5 @@ # Changelog -## 0.4.0 - -+ Added `ValueProperty`, abstract superclass for `ConfigProperty`. - -## 0.3.4 - -- Bug fixes in `.flask_restx` regarding error handling -- Fixed a bug in `.configparse` dealing with unset values from multiple sources - ## 0.3.3 - Fixed leftovers in `snowflake` module from unchecked code copying — i.e. `SnowflakeGen.generate_one()` used to require an unused typ= parameter diff --git a/README.md b/README.md index 29ee187..5ae6d56 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Good morning, my brother! Welcome the SUOU (SIS Unified Object Underarmor), a library for the management of the storage of objects into a database. -It provides utilities such as [SIQ](https://yusur.moe/protocols/siq.html), signing and generation of access tokens (on top of [ItsDangerous](https://github.com/pallets/itsdangerous)) and various utilities, including helpers for use in Flask and SQLAlchemy. +It provides utilities such as [SIQ](https://sakux.moe/protocols/siq.html), signing and generation of access tokens (on top of [ItsDangerous](https://github.com/pallets/itsdangerous)) and various utilities, including helpers for use in Flask and SQLAlchemy. **It is not an ORM** nor a replacement of it; it works along existing ORMs (currently only SQLAlchemy is supported lol). diff --git a/pyproject.toml b/pyproject.toml index 58766fa..36540a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,7 @@ readme = "README.md" dependencies = [ "itsdangerous", - "toml", - "pydantic" + "toml" ] # - further devdependencies below - # @@ -37,12 +36,10 @@ sqlalchemy = [ ] flask = [ "Flask>=2.0.0", - "Flask-RestX", - "Quart", - "Quart-Schema" + "Flask-RestX" ] flask_sqlalchemy = [ - "Flask-SqlAlchemy", + "Flask-SqlAlchemy" ] peewee = [ "peewee>=3.0.0, <4.0" diff --git a/src/suou/__init__.py b/src/suou/__init__.py index 4a14073..97743c8 100644 --- a/src/suou/__init__.py +++ b/src/suou/__init__.py @@ -27,7 +27,7 @@ from .itertools import makelist, kwargs_prefix, ltuple, rtuple, additem from .i18n import I18n, JsonI18n, TomlI18n from .snowflake import Snowflake, SnowflakeGen -__version__ = "0.4.0-dev26" +__version__ = "0.3.3" __all__ = ( 'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase', diff --git a/src/suou/classtools.py b/src/suou/classtools.py index 34ad58b..ebe673b 100644 --- a/src/suou/classtools.py +++ b/src/suou/classtools.py @@ -14,17 +14,10 @@ This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ -from __future__ import annotations - -from abc import ABCMeta, abstractmethod -from typing import Any, Callable, Generic, Iterable, Mapping, TypeVar - -from suou.codecs import StringCase +from typing import Any, Callable, Generic, Iterable, TypeVar _T = TypeVar('_T') -MISSING = object() - class Wanted(Generic[_T]): """ Placeholder for parameters wanted by Incomplete(). @@ -105,78 +98,6 @@ class Incomplete(Generic[_T]): clsdict[k] = v.instance() return clsdict - -class ValueSource(Mapping): - """ - Abstract value source. - """ - pass - - -class ValueProperty(Generic[_T]): - _name: str | None - _srcs: dict[str, str] - _val: Any | MISSING - _default: Any | None - _cast: Callable | None - _required: bool - _pub_name: str | bool = False - _not_found = LookupError - - def __init__(self, /, src: str | None = None, *, - default = None, cast: Callable | None = None, - required: bool = False, public: str | bool = False, - **kwargs - ): - self._srcs = dict() - if src: - self._srcs['default'] = src - self._default = default - self._cast = cast - self._required = required - self._pub_name = public - self._val = MISSING - for k, v in kwargs.items(): - if k.endswith('_src'): - self._srcs[k[:-4]] = v - else: - raise TypeError(f'unknown keyword argument {k!r}') - - def __set_name__(self, owner, name: str, *, src_name: str | None = None): - self._name = name - self._srcs.setdefault('default', src_name or name) - nsrcs = dict() - for k, v in self._srcs.items(): - if v.endswith('?'): - nsrcs[k] = v.rstrip('?') + (src_name or name) - self._srcs.update(nsrcs) - if self._pub_name is True: - self._pub_name = name - def __get__(self, obj: Any, owner = None): - if self._val is MISSING: - v = MISSING - for srckey, src in self._srcs.items(): - if (getter := self._getter(obj, srckey)): - v = getter.get(src, v) - if self._required and (not v or v is MISSING): - raise self._not_found(f'required config {self._srcs['default']} not set!') - if v is MISSING: - v = self._default - if callable(self._cast): - v = self._cast(v) if v is not None else self._cast() - self._val = v - return self._val - - @abstractmethod - def _getter(self, obj: Any, name: str = 'default') -> ValueSource: - pass - - @property - def name(self): - return self._name - - @property - def source(self, /): - return self._srcs['default'] - - +__all__ = ( + 'Wanted', 'Incomplete' +) \ No newline at end of file diff --git a/src/suou/codecs.py b/src/suou/codecs.py index 3efe53f..2bee255 100644 --- a/src/suou/codecs.py +++ b/src/suou/codecs.py @@ -227,7 +227,7 @@ def jsonencode(obj: dict, *, skipkeys: bool = True, separators: tuple[str, str] ''' return json.dumps(obj, skipkeys=skipkeys, separators=separators, default=_json_default(default), **kwargs) -jsondecode: Callable[Any, dict] = deprecated('just use json.loads()')(json.loads) +jsondecode = deprecated('just use json.loads()')(json.loads) def ssv_list(s: str, *, sep_chars = ',;') -> list[str]: """ diff --git a/src/suou/configparse.py b/src/suou/configparse.py index 9709b0a..ace075f 100644 --- a/src/suou/configparse.py +++ b/src/suou/configparse.py @@ -15,27 +15,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ from __future__ import annotations - +from abc import abstractmethod from ast import TypeVar from collections.abc import Mapping from configparser import ConfigParser as _ConfigParser import os -from typing import Any, Callable, Iterator, override +from typing import Any, Callable, Iterable, Iterator from collections import OrderedDict -from .classtools import ValueSource, ValueProperty -from .functools import deprecated -from .exceptions import MissingConfigError, MissingConfigWarning - +from .functools import deprecated_alias +MISSING = object() _T = TypeVar('T') +class MissingConfigError(LookupError): + """ + Config variable not found. -class ConfigSource(ValueSource): + Raised when a config property is marked as required, but no property with + that name is found. + """ + pass + + +class MissingConfigWarning(MissingConfigError, Warning): + """ + A required config property is missing, and the application is assuming a default value. + """ + pass + + +class ConfigSource(Mapping): ''' - Abstract config value source. + Abstract config source. ''' __slots__ = () @@ -64,8 +78,6 @@ class ConfigParserConfigSource(ConfigSource): _cfp: _ConfigParser def __init__(self, cfp: _ConfigParser): - if not isinstance(cfp, _ConfigParser): - raise TypeError(f'a ConfigParser object is required (got {cfp.__class__.__name__!r})') self._cfp = cfp def __getitem__(self, key: str, /) -> str: k1, _, k2 = key.partition('.') @@ -105,7 +117,7 @@ class DictConfigSource(ConfigSource): def __len__(self) -> int: return len(self._d) -class ConfigValue(ValueProperty): +class ConfigValue: """ A single config property. @@ -121,43 +133,61 @@ class ConfigValue(ValueProperty): - preserve_case: if True, src is not CAPITALIZED. Useful for parsing from Python dictionaries or ConfigParser's - required: throw an error if empty or not supplied """ - - _preserve_case: bool = False - _prefix: str | None = None - _not_found = MissingConfigError + # XXX disabled for https://stackoverflow.com/questions/45864273/slots-conflicts-with-a-class-variable-in-a-generic-class + #__slots__ = ('_srcs', '_val', '_default', '_cast', '_required', '_preserve_case') + _srcs: dict[str, str] | None + _preserve_case: bool = False + _val: Any | MISSING = MISSING + _default: Any | None + _cast: Callable | None + _required: bool + _pub_name: str | bool = False def __init__(self, /, src: str | None = None, *, default = None, cast: Callable | None = None, required: bool = False, preserve_case: bool = False, prefix: str | None = None, public: str | bool = False, **kwargs): + self._srcs = dict() self._preserve_case = preserve_case - if src and not preserve_case: - src = src.upper() - if not src and prefix: - self._prefix = prefix - if not preserve_case: - src = f'{prefix.upper()}?' + if src: + self._srcs['default'] = src if preserve_case else src.upper() + elif prefix: + self._srcs['default'] = f'{prefix if preserve_case else prefix.upper}?' + self._default = default + self._cast = cast + self._required = required + self._pub_name = public + for k, v in kwargs.items(): + if k.endswith('_src'): + self._srcs[k[:-4]] = v else: - src = f'{prefix}?' - - super().__init__(src, default=default, cast=cast, - required=required, public=public, **kwargs - ) - + raise TypeError(f'unknown keyword argument {k!r}') def __set_name__(self, owner, name: str): - src_name = name if self._preserve_case else name.upper() - - super().__set_name__(owner, name, src_name=src_name) - + if 'default' not in self._srcs: + self._srcs['default'] = name if self._preserve_case else name.upper() + elif self._srcs['default'].endswith('?'): + self._srcs['default'] = self._srcs['default'].rstrip('?') + (name if self._preserve_case else name.upper() ) + + if self._pub_name is True: + self._pub_name = name if self._pub_name and isinstance(owner, ConfigOptions): owner.expose(self._pub_name, name) - + def __get__(self, obj: ConfigOptions, owner=False): + if self._val is MISSING: + for srckey, src in obj._srcs.items(): + v = src.get(self._srcs[srckey], MISSING) + if self._required and not v: + raise MissingConfigError(f'required config {self._src} not set!') + if v is MISSING: + v = self._default + if callable(self._cast): + v = self._cast(v) if v is not None else self._cast() + self._val = v + return self._val - @override - def _getter(self, obj: ConfigOptions, name: str = 'default') -> ConfigSource: - if not isinstance(obj._srcs, Mapping): - raise RuntimeError('attempt to get config value with no source configured') - return obj._srcs.get(name) + @property + def source(self, /): + return self._srcs['default'] class ConfigOptions: @@ -186,7 +216,7 @@ class ConfigOptions: if first: self._srcs.move_to_end(key, False) - add_config_source = deprecated('use add_source() instead')(add_source) + add_config_source = deprecated_alias(add_source) def expose(self, public_name: str, attr_name: str | None = None) -> None: ''' diff --git a/src/suou/configparsev0_3.py b/src/suou/configparsev0_3.py deleted file mode 100644 index 282e248..0000000 --- a/src/suou/configparsev0_3.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -Utilities for parsing config variables. - -BREAKING older, non-generalized version, kept for backwards compability -in case 0.4+ version happens to break. - -WILL BE removed in 0.5.0. - ---- - -Copyright (c) 2025 Sakuragasaki46. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -See LICENSE for the specific language governing permissions and -limitations under the License. - -This software is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -""" - -from __future__ import annotations -from ast import TypeVar -from collections.abc import Mapping -from configparser import ConfigParser as _ConfigParser -import os -from typing import Any, Callable, Iterator -from collections import OrderedDict -import warnings - -from .functools import deprecated -from .exceptions import MissingConfigError, MissingConfigWarning - -warnings.warn('This module will be removed in 0.5.0 and is kept only in case new implementation breaks!\n'\ - 'Do not use unless you know what you are doing.', DeprecationWarning) - -MISSING = object() -_T = TypeVar('T') - - -@deprecated('use configparse') -class ConfigSource(Mapping): - ''' - Abstract config source. - ''' - __slots__ = () - -@deprecated('use configparse') -class EnvConfigSource(ConfigSource): - ''' - Config source from os.environ aka .env - ''' - def __getitem__(self, key: str, /) -> str: - return os.environ[key] - def get(self, key: str, fallback = None, /): - return os.getenv(key, fallback) - def __contains__(self, key: str, /) -> bool: - return key in os.environ - def __iter__(self) -> Iterator[str]: - yield from os.environ - def __len__(self) -> int: - return len(os.environ) - -@deprecated('use configparse') -class ConfigParserConfigSource(ConfigSource): - ''' - Config source from ConfigParser - ''' - __slots__ = ('_cfp', ) - _cfp: _ConfigParser - - def __init__(self, cfp: _ConfigParser): - if not isinstance(cfp, _ConfigParser): - raise TypeError(f'a ConfigParser object is required (got {cfp.__class__.__name__!r})') - self._cfp = cfp - def __getitem__(self, key: str, /) -> str: - k1, _, k2 = key.partition('.') - return self._cfp.get(k1, k2) - def get(self, key: str, fallback = None, /): - k1, _, k2 = key.partition('.') - return self._cfp.get(k1, k2, fallback=fallback) - def __contains__(self, key: str, /) -> bool: - k1, _, k2 = key.partition('.') - return self._cfp.has_option(k1, k2) - def __iter__(self) -> Iterator[str]: - for k1, v1 in self._cfp.items(): - for k2 in v1: - yield f'{k1}.{k2}' - def __len__(self) -> int: - ## XXX might be incorrect but who cares - return sum(len(x) for x in self._cfp) - -@deprecated('use configparse') -class DictConfigSource(ConfigSource): - ''' - Config source from Python mappings. Useful with JSON/TOML config - ''' - __slots__ = ('_d',) - - _d: dict[str, Any] - - def __init__(self, mapping: dict[str, Any]): - self._d = mapping - def __getitem__(self, key: str, /) -> str: - return self._d[key] - def get(self, key: str, fallback: _T = None, /): - return self._d.get(key, fallback) - def __contains__(self, key: str, /) -> bool: - return key in self._d - def __iter__(self) -> Iterator[str]: - yield from self._d - def __len__(self) -> int: - return len(self._d) - -@deprecated('use configparse') -class ConfigValue: - """ - A single config property. - - By default, it is sourced from os.environ — i.e. environment variables, - and property name is upper cased. - - You can specify further sources, if the parent ConfigOptions class - supports them. - - Arguments: - - public: mark value as public, making it available across the app (e.g. in Jinja2 templates). - - prefix: src but for the lazy - - preserve_case: if True, src is not CAPITALIZED. Useful for parsing from Python dictionaries or ConfigParser's - - required: throw an error if empty or not supplied - """ - # XXX disabled per https://stackoverflow.com/questions/45864273/slots-conflicts-with-a-class-variable-in-a-generic-class - #__slots__ = ('_srcs', '_val', '_default', '_cast', '_required', '_preserve_case') - - _srcs: dict[str, str] | None - _preserve_case: bool = False - _val: Any | MISSING = MISSING - _default: Any | None - _cast: Callable | None - _required: bool - _pub_name: str | bool = False - def __init__(self, /, - src: str | None = None, *, default = None, cast: Callable | None = None, - required: bool = False, preserve_case: bool = False, prefix: str | None = None, - public: str | bool = False, **kwargs): - self._srcs = dict() - self._preserve_case = preserve_case - if src: - self._srcs['default'] = src if preserve_case else src.upper() - elif prefix: - self._srcs['default'] = f'{prefix if preserve_case else prefix.upper}?' - self._default = default - self._cast = cast - self._required = required - self._pub_name = public - for k, v in kwargs.items(): - if k.endswith('_src'): - self._srcs[k[:-4]] = v - else: - raise TypeError(f'unknown keyword argument {k!r}') - def __set_name__(self, owner, name: str): - if 'default' not in self._srcs: - self._srcs['default'] = name if self._preserve_case else name.upper() - elif self._srcs['default'].endswith('?'): - self._srcs['default'] = self._srcs['default'].rstrip('?') + (name if self._preserve_case else name.upper() ) - - if self._pub_name is True: - self._pub_name = name - if self._pub_name and isinstance(owner, ConfigOptions): - owner.expose(self._pub_name, name) - def __get__(self, obj: ConfigOptions, owner=False): - if self._val is MISSING: - v = MISSING - for srckey, src in obj._srcs.items(): - if srckey in self._srcs: - v = src.get(self._srcs[srckey], v) - if self._required and (not v or v is MISSING): - raise MissingConfigError(f'required config {self._srcs['default']} not set!') - if v is MISSING: - v = self._default - if callable(self._cast): - v = self._cast(v) if v is not None else self._cast() - self._val = v - return self._val - - @property - def source(self, /): - return self._srcs['default'] - -@deprecated('use configparse') -class ConfigOptions: - """ - Base class for loading config values. - - It is intended to get subclassed; config values must be defined as - ConfigValue() properties. - - Further config sources can be added with .add_source() - """ - - __slots__ = ('_srcs', '_pub') - - _srcs: OrderedDict[str, ConfigSource] - _pub: dict[str, str] - - def __init__(self, /): - self._srcs = OrderedDict( - default = EnvConfigSource() - ) - self._pub = dict() - - def add_source(self, key: str, csrc: ConfigSource, /, *, first: bool = False): - self._srcs[key] = csrc - if first: - self._srcs.move_to_end(key, False) - - add_config_source = deprecated_alias(add_source) - - def expose(self, public_name: str, attr_name: str | None = None) -> None: - ''' - Mark a config value as public. - - Called automatically by ConfigValue.__set_name__(). - ''' - attr_name = attr_name or public_name - self._pub[public_name] = attr_name - - def to_dict(self, /): - d = dict() - for k, v in self._pub.items(): - d[k] = getattr(self, v) - return d - - -__all__ = ( - 'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'EnvConfigSource', 'ConfigParserConfigSource', 'DictConfigSource', 'ConfigSource', 'ConfigValue' -) - - diff --git a/src/suou/exceptions.py b/src/suou/exceptions.py deleted file mode 100644 index bc71037..0000000 --- a/src/suou/exceptions.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Exceptions and throwables for various purposes - ---- - -Copyright (c) 2025 Sakuragasaki46. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -See LICENSE for the specific language governing permissions and -limitations under the License. - -This software is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -""" - - - -class MissingConfigError(LookupError): - """ - Config variable not found. - - Raised when a config property is marked as required, but no property with - that name is found. - """ - pass - - -class MissingConfigWarning(MissingConfigError, Warning): - """ - A required config property is missing, and the application is assuming a default value. - """ - pass \ No newline at end of file diff --git a/src/suou/flask_restx.py b/src/suou/flask_restx.py index cef777e..9d4955a 100644 --- a/src/suou/flask_restx.py +++ b/src/suou/flask_restx.py @@ -16,10 +16,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Any, Mapping import warnings -from flask import Response, current_app, make_response +from flask import current_app, make_response from flask_restx import Api as _Api -from .codecs import jsondecode, jsonencode, want_bytes, want_str +from .codecs import jsonencode def output_json(data, code, headers=None): @@ -54,21 +54,10 @@ class Api(_Api): Notably, all JSON is whitespace-free and .message is remapped to .error """ def handle_error(self, e): - ### XXX in order for errors to get handled the correct way, import - ### suou.flask_restx.Api() NOT flask_restx.Api() !!!! res = super().handle_error(e) if isinstance(res, Mapping) and 'message' in res: res['error'] = res['message'] del res['message'] - elif isinstance(res, Response): - try: - body = want_str(res.response[0]) - bodj = jsondecode(body) - if 'message' in bodj: - bodj['error'] = bodj.pop('message') - res.response = [want_bytes(jsonencode(bodj))] - except (IndexError, KeyError): - pass return res def __init__(self, *a, **ka): super().__init__(*a, **ka) diff --git a/src/suou/flask_sqlalchemy.py b/src/suou/flask_sqlalchemy.py index 5af6a8c..8ef0d12 100644 --- a/src/suou/flask_sqlalchemy.py +++ b/src/suou/flask_sqlalchemy.py @@ -22,7 +22,7 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import DeclarativeBase, Session from .codecs import want_bytes -from .sqlalchemy import AuthSrc, require_auth_base +from .sqlalchemy import require_auth_base class FlaskAuthSrc(AuthSrc): ''' @@ -35,15 +35,14 @@ class FlaskAuthSrc(AuthSrc): def get_session(self) -> Session: return self.db.session def get_token(self): - if request.authorization: - return request.authorization.token + return request.authorization.token def get_signature(self) -> bytes: sig = request.headers.get('authorization-signature', None) return want_bytes(sig) if sig else None def invalid_exc(self, msg: str = 'validation failed') -> Never: abort(400, msg) def required_exc(self): - abort(401, 'Login required') + abort(401) def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Callable]: """ @@ -52,9 +51,6 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Ca This looks for a token in the Authorization header, validates it, loads the appropriate object, and injects it as the user= parameter. - NOTE: the actual decorator to be used on routes is **auth_required()**, - NOT require_auth() which is the **constructor** for it. - cls is a SQLAlchemy table. db is a flask_sqlalchemy.SQLAlchemy() binding. @@ -66,15 +62,8 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Ca @auth_required(validators=[lambda x: x.is_administrator]) def super_secret_stuff(user): pass - - NOTE: require_auth() DOES NOT work with flask_restx. """ - def auth_required(**kwargs): - return require_auth_base(cls=cls, src=FlaskAuthSrc(db), **kwargs) - - auth_required.__doc__ = require_auth_base.__doc__ - - return auth_required + return partial(require_auth_base, cls=cls, src=FlaskAuthSrc(db)) __all__ = ('require_auth', ) diff --git a/src/suou/forms.py b/src/suou/forms.py deleted file mode 100644 index 8f5318f..0000000 --- a/src/suou/forms.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -Form validation, done right. - -Why this? Why not, let's say, WTForms or Marshmallow? Well, I have my reasons. - -TODO -""" - diff --git a/src/suou/iding.py b/src/suou/iding.py index 2fe2364..6cfcd5e 100644 --- a/src/suou/iding.py +++ b/src/suou/iding.py @@ -41,7 +41,7 @@ from typing import Iterable, override import warnings from .functools import deprecated -from .codecs import b32lencode, b64encode, cb32decode, cb32encode, want_str +from .codecs import b32lencode, b64encode, cb32encode class SiqType(enum.Enum): @@ -225,7 +225,6 @@ class Siq(int): """ Representation of a SIQ as an integer. """ - def to_bytes(self, length: int = 14, byteorder = 'big', *, signed: bool = False) -> bytes: return super().to_bytes(length, byteorder, signed=signed) @classmethod @@ -237,7 +236,7 @@ class Siq(int): def to_base64(self, length: int = 15, *, strip: bool = True) -> str: return b64encode(self.to_bytes(length), strip=strip) def to_cb32(self) -> str: - return cb32encode(self.to_bytes(15, 'big')).lstrip('0') + return cb32encode(self.to_bytes(15, 'big')) to_crockford = to_cb32 def to_hex(self) -> str: return f'{self:x}' @@ -292,10 +291,6 @@ class Siq(int): raise ValueError('checksum mismatch') return cls(int.from_bytes(b, 'big')) - @classmethod - def from_cb32(cls, val: str | bytes): - return cls.from_bytes(cb32decode(want_str(val).zfill(24))) - def to_mastodon(self, /, domain: str | None = None): return f'@{self:u}{"@" if domain else ""}{domain}' def to_matrix(self, /, domain: str): diff --git a/src/suou/sqlalchemy.py b/src/suou/sqlalchemy.py index 249b104..2019424 100644 --- a/src/suou/sqlalchemy.py +++ b/src/suou/sqlalchemy.py @@ -253,7 +253,8 @@ class AuthSrc(metaclass=ABCMeta): def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', - required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): + required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None, + invalid_exc: Callable | None = None, required_exc: Callable | None = None): ''' Inject the current user into a view, given the Authorization: Bearer header. @@ -274,11 +275,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | except Exception: return None - def _default_invalid(msg: str = 'Validation failed'): + def _default_invalid(msg: str): raise ValueError(msg) - invalid_exc = src.invalid_exc or _default_invalid - required_exc = src.required_exc or (lambda: _default_invalid('Login required')) + invalid_exc = invalid_exc or _default_invalid + required_exc = required_exc or (lambda: _default_invalid()) def decorator(func: Callable): @wraps(func)