diff --git a/CHANGELOG.md b/CHANGELOG.md index eab0c55..3536f9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## 0.4.0 + ++ Added `ValueProperty`, abstract superclass for `ConfigProperty`. + +## 0.3.4 + +- Bug fixes in `.flask_restx` regarding error handling +- Fixed a bug in `.configparse` dealing with unset values from multiple sources + ## 0.3.3 - Fixed leftovers in `snowflake` module from unchecked code copying — i.e. `SnowflakeGen.generate_one()` used to require an unused typ= parameter diff --git a/README.md b/README.md index 5ae6d56..29ee187 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Good morning, my brother! Welcome the SUOU (SIS Unified Object Underarmor), a library for the management of the storage of objects into a database. -It provides utilities such as [SIQ](https://sakux.moe/protocols/siq.html), signing and generation of access tokens (on top of [ItsDangerous](https://github.com/pallets/itsdangerous)) and various utilities, including helpers for use in Flask and SQLAlchemy. +It provides utilities such as [SIQ](https://yusur.moe/protocols/siq.html), signing and generation of access tokens (on top of [ItsDangerous](https://github.com/pallets/itsdangerous)) and various utilities, including helpers for use in Flask and SQLAlchemy. **It is not an ORM** nor a replacement of it; it works along existing ORMs (currently only SQLAlchemy is supported lol). diff --git a/pyproject.toml b/pyproject.toml index 36540a6..58766fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,8 @@ readme = "README.md" dependencies = [ "itsdangerous", - "toml" + "toml", + "pydantic" ] # - further devdependencies below - # @@ -36,10 +37,12 @@ sqlalchemy = [ ] flask = [ "Flask>=2.0.0", - "Flask-RestX" + "Flask-RestX", + "Quart", + "Quart-Schema" ] flask_sqlalchemy = [ - "Flask-SqlAlchemy" + "Flask-SqlAlchemy", ] peewee = [ "peewee>=3.0.0, <4.0" diff --git a/src/suou/__init__.py b/src/suou/__init__.py index 97743c8..4a14073 100644 --- a/src/suou/__init__.py +++ b/src/suou/__init__.py @@ -27,7 +27,7 @@ from .itertools import makelist, kwargs_prefix, ltuple, rtuple, additem from .i18n import I18n, JsonI18n, TomlI18n from .snowflake import Snowflake, SnowflakeGen -__version__ = "0.3.3" +__version__ = "0.4.0-dev26" __all__ = ( 'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase', diff --git a/src/suou/classtools.py b/src/suou/classtools.py index ebe673b..34ad58b 100644 --- a/src/suou/classtools.py +++ b/src/suou/classtools.py @@ -14,10 +14,17 @@ This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ -from typing import Any, Callable, Generic, Iterable, TypeVar +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import Any, Callable, Generic, Iterable, Mapping, TypeVar + +from suou.codecs import StringCase _T = TypeVar('_T') +MISSING = object() + class Wanted(Generic[_T]): """ Placeholder for parameters wanted by Incomplete(). @@ -98,6 +105,78 @@ class Incomplete(Generic[_T]): clsdict[k] = v.instance() return clsdict -__all__ = ( - 'Wanted', 'Incomplete' -) \ No newline at end of file + +class ValueSource(Mapping): + """ + Abstract value source. + """ + pass + + +class ValueProperty(Generic[_T]): + _name: str | None + _srcs: dict[str, str] + _val: Any | MISSING + _default: Any | None + _cast: Callable | None + _required: bool + _pub_name: str | bool = False + _not_found = LookupError + + def __init__(self, /, src: str | None = None, *, + default = None, cast: Callable | None = None, + required: bool = False, public: str | bool = False, + **kwargs + ): + self._srcs = dict() + if src: + self._srcs['default'] = src + self._default = default + self._cast = cast + self._required = required + self._pub_name = public + self._val = MISSING + for k, v in kwargs.items(): + if k.endswith('_src'): + self._srcs[k[:-4]] = v + else: + raise TypeError(f'unknown keyword argument {k!r}') + + def __set_name__(self, owner, name: str, *, src_name: str | None = None): + self._name = name + self._srcs.setdefault('default', src_name or name) + nsrcs = dict() + for k, v in self._srcs.items(): + if v.endswith('?'): + nsrcs[k] = v.rstrip('?') + (src_name or name) + self._srcs.update(nsrcs) + if self._pub_name is True: + self._pub_name = name + def __get__(self, obj: Any, owner = None): + if self._val is MISSING: + v = MISSING + for srckey, src in self._srcs.items(): + if (getter := self._getter(obj, srckey)): + v = getter.get(src, v) + if self._required and (not v or v is MISSING): + raise self._not_found(f'required config {self._srcs['default']} not set!') + if v is MISSING: + v = self._default + if callable(self._cast): + v = self._cast(v) if v is not None else self._cast() + self._val = v + return self._val + + @abstractmethod + def _getter(self, obj: Any, name: str = 'default') -> ValueSource: + pass + + @property + def name(self): + return self._name + + @property + def source(self, /): + return self._srcs['default'] + + diff --git a/src/suou/codecs.py b/src/suou/codecs.py index 2bee255..3efe53f 100644 --- a/src/suou/codecs.py +++ b/src/suou/codecs.py @@ -227,7 +227,7 @@ def jsonencode(obj: dict, *, skipkeys: bool = True, separators: tuple[str, str] ''' return json.dumps(obj, skipkeys=skipkeys, separators=separators, default=_json_default(default), **kwargs) -jsondecode = deprecated('just use json.loads()')(json.loads) +jsondecode: Callable[Any, dict] = deprecated('just use json.loads()')(json.loads) def ssv_list(s: str, *, sep_chars = ',;') -> list[str]: """ diff --git a/src/suou/configparse.py b/src/suou/configparse.py index ace075f..9709b0a 100644 --- a/src/suou/configparse.py +++ b/src/suou/configparse.py @@ -15,41 +15,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ from __future__ import annotations -from abc import abstractmethod + from ast import TypeVar from collections.abc import Mapping from configparser import ConfigParser as _ConfigParser import os -from typing import Any, Callable, Iterable, Iterator +from typing import Any, Callable, Iterator, override from collections import OrderedDict -from .functools import deprecated_alias +from .classtools import ValueSource, ValueProperty +from .functools import deprecated +from .exceptions import MissingConfigError, MissingConfigWarning + -MISSING = object() _T = TypeVar('T') -class MissingConfigError(LookupError): - """ - Config variable not found. - Raised when a config property is marked as required, but no property with - that name is found. - """ - pass - - -class MissingConfigWarning(MissingConfigError, Warning): - """ - A required config property is missing, and the application is assuming a default value. - """ - pass - - -class ConfigSource(Mapping): +class ConfigSource(ValueSource): ''' - Abstract config source. + Abstract config value source. ''' __slots__ = () @@ -78,6 +64,8 @@ class ConfigParserConfigSource(ConfigSource): _cfp: _ConfigParser def __init__(self, cfp: _ConfigParser): + if not isinstance(cfp, _ConfigParser): + raise TypeError(f'a ConfigParser object is required (got {cfp.__class__.__name__!r})') self._cfp = cfp def __getitem__(self, key: str, /) -> str: k1, _, k2 = key.partition('.') @@ -117,7 +105,7 @@ class DictConfigSource(ConfigSource): def __len__(self) -> int: return len(self._d) -class ConfigValue: +class ConfigValue(ValueProperty): """ A single config property. @@ -133,61 +121,43 @@ class ConfigValue: - preserve_case: if True, src is not CAPITALIZED. Useful for parsing from Python dictionaries or ConfigParser's - required: throw an error if empty or not supplied """ - # XXX disabled for https://stackoverflow.com/questions/45864273/slots-conflicts-with-a-class-variable-in-a-generic-class - #__slots__ = ('_srcs', '_val', '_default', '_cast', '_required', '_preserve_case') - - _srcs: dict[str, str] | None + _preserve_case: bool = False - _val: Any | MISSING = MISSING - _default: Any | None - _cast: Callable | None - _required: bool - _pub_name: str | bool = False + _prefix: str | None = None + _not_found = MissingConfigError + def __init__(self, /, src: str | None = None, *, default = None, cast: Callable | None = None, required: bool = False, preserve_case: bool = False, prefix: str | None = None, public: str | bool = False, **kwargs): - self._srcs = dict() self._preserve_case = preserve_case - if src: - self._srcs['default'] = src if preserve_case else src.upper() - elif prefix: - self._srcs['default'] = f'{prefix if preserve_case else prefix.upper}?' - self._default = default - self._cast = cast - self._required = required - self._pub_name = public - for k, v in kwargs.items(): - if k.endswith('_src'): - self._srcs[k[:-4]] = v + if src and not preserve_case: + src = src.upper() + if not src and prefix: + self._prefix = prefix + if not preserve_case: + src = f'{prefix.upper()}?' else: - raise TypeError(f'unknown keyword argument {k!r}') - def __set_name__(self, owner, name: str): - if 'default' not in self._srcs: - self._srcs['default'] = name if self._preserve_case else name.upper() - elif self._srcs['default'].endswith('?'): - self._srcs['default'] = self._srcs['default'].rstrip('?') + (name if self._preserve_case else name.upper() ) + src = f'{prefix}?' + + super().__init__(src, default=default, cast=cast, + required=required, public=public, **kwargs + ) - if self._pub_name is True: - self._pub_name = name + def __set_name__(self, owner, name: str): + src_name = name if self._preserve_case else name.upper() + + super().__set_name__(owner, name, src_name=src_name) + if self._pub_name and isinstance(owner, ConfigOptions): owner.expose(self._pub_name, name) - def __get__(self, obj: ConfigOptions, owner=False): - if self._val is MISSING: - for srckey, src in obj._srcs.items(): - v = src.get(self._srcs[srckey], MISSING) - if self._required and not v: - raise MissingConfigError(f'required config {self._src} not set!') - if v is MISSING: - v = self._default - if callable(self._cast): - v = self._cast(v) if v is not None else self._cast() - self._val = v - return self._val + - @property - def source(self, /): - return self._srcs['default'] + @override + def _getter(self, obj: ConfigOptions, name: str = 'default') -> ConfigSource: + if not isinstance(obj._srcs, Mapping): + raise RuntimeError('attempt to get config value with no source configured') + return obj._srcs.get(name) class ConfigOptions: @@ -216,7 +186,7 @@ class ConfigOptions: if first: self._srcs.move_to_end(key, False) - add_config_source = deprecated_alias(add_source) + add_config_source = deprecated('use add_source() instead')(add_source) def expose(self, public_name: str, attr_name: str | None = None) -> None: ''' diff --git a/src/suou/configparsev0_3.py b/src/suou/configparsev0_3.py new file mode 100644 index 0000000..282e248 --- /dev/null +++ b/src/suou/configparsev0_3.py @@ -0,0 +1,239 @@ +""" +Utilities for parsing config variables. + +BREAKING older, non-generalized version, kept for backwards compability +in case 0.4+ version happens to break. + +WILL BE removed in 0.5.0. + +--- + +Copyright (c) 2025 Sakuragasaki46. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +See LICENSE for the specific language governing permissions and +limitations under the License. + +This software is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +""" + +from __future__ import annotations +from ast import TypeVar +from collections.abc import Mapping +from configparser import ConfigParser as _ConfigParser +import os +from typing import Any, Callable, Iterator +from collections import OrderedDict +import warnings + +from .functools import deprecated +from .exceptions import MissingConfigError, MissingConfigWarning + +warnings.warn('This module will be removed in 0.5.0 and is kept only in case new implementation breaks!\n'\ + 'Do not use unless you know what you are doing.', DeprecationWarning) + +MISSING = object() +_T = TypeVar('T') + + +@deprecated('use configparse') +class ConfigSource(Mapping): + ''' + Abstract config source. + ''' + __slots__ = () + +@deprecated('use configparse') +class EnvConfigSource(ConfigSource): + ''' + Config source from os.environ aka .env + ''' + def __getitem__(self, key: str, /) -> str: + return os.environ[key] + def get(self, key: str, fallback = None, /): + return os.getenv(key, fallback) + def __contains__(self, key: str, /) -> bool: + return key in os.environ + def __iter__(self) -> Iterator[str]: + yield from os.environ + def __len__(self) -> int: + return len(os.environ) + +@deprecated('use configparse') +class ConfigParserConfigSource(ConfigSource): + ''' + Config source from ConfigParser + ''' + __slots__ = ('_cfp', ) + _cfp: _ConfigParser + + def __init__(self, cfp: _ConfigParser): + if not isinstance(cfp, _ConfigParser): + raise TypeError(f'a ConfigParser object is required (got {cfp.__class__.__name__!r})') + self._cfp = cfp + def __getitem__(self, key: str, /) -> str: + k1, _, k2 = key.partition('.') + return self._cfp.get(k1, k2) + def get(self, key: str, fallback = None, /): + k1, _, k2 = key.partition('.') + return self._cfp.get(k1, k2, fallback=fallback) + def __contains__(self, key: str, /) -> bool: + k1, _, k2 = key.partition('.') + return self._cfp.has_option(k1, k2) + def __iter__(self) -> Iterator[str]: + for k1, v1 in self._cfp.items(): + for k2 in v1: + yield f'{k1}.{k2}' + def __len__(self) -> int: + ## XXX might be incorrect but who cares + return sum(len(x) for x in self._cfp) + +@deprecated('use configparse') +class DictConfigSource(ConfigSource): + ''' + Config source from Python mappings. Useful with JSON/TOML config + ''' + __slots__ = ('_d',) + + _d: dict[str, Any] + + def __init__(self, mapping: dict[str, Any]): + self._d = mapping + def __getitem__(self, key: str, /) -> str: + return self._d[key] + def get(self, key: str, fallback: _T = None, /): + return self._d.get(key, fallback) + def __contains__(self, key: str, /) -> bool: + return key in self._d + def __iter__(self) -> Iterator[str]: + yield from self._d + def __len__(self) -> int: + return len(self._d) + +@deprecated('use configparse') +class ConfigValue: + """ + A single config property. + + By default, it is sourced from os.environ — i.e. environment variables, + and property name is upper cased. + + You can specify further sources, if the parent ConfigOptions class + supports them. + + Arguments: + - public: mark value as public, making it available across the app (e.g. in Jinja2 templates). + - prefix: src but for the lazy + - preserve_case: if True, src is not CAPITALIZED. Useful for parsing from Python dictionaries or ConfigParser's + - required: throw an error if empty or not supplied + """ + # XXX disabled per https://stackoverflow.com/questions/45864273/slots-conflicts-with-a-class-variable-in-a-generic-class + #__slots__ = ('_srcs', '_val', '_default', '_cast', '_required', '_preserve_case') + + _srcs: dict[str, str] | None + _preserve_case: bool = False + _val: Any | MISSING = MISSING + _default: Any | None + _cast: Callable | None + _required: bool + _pub_name: str | bool = False + def __init__(self, /, + src: str | None = None, *, default = None, cast: Callable | None = None, + required: bool = False, preserve_case: bool = False, prefix: str | None = None, + public: str | bool = False, **kwargs): + self._srcs = dict() + self._preserve_case = preserve_case + if src: + self._srcs['default'] = src if preserve_case else src.upper() + elif prefix: + self._srcs['default'] = f'{prefix if preserve_case else prefix.upper}?' + self._default = default + self._cast = cast + self._required = required + self._pub_name = public + for k, v in kwargs.items(): + if k.endswith('_src'): + self._srcs[k[:-4]] = v + else: + raise TypeError(f'unknown keyword argument {k!r}') + def __set_name__(self, owner, name: str): + if 'default' not in self._srcs: + self._srcs['default'] = name if self._preserve_case else name.upper() + elif self._srcs['default'].endswith('?'): + self._srcs['default'] = self._srcs['default'].rstrip('?') + (name if self._preserve_case else name.upper() ) + + if self._pub_name is True: + self._pub_name = name + if self._pub_name and isinstance(owner, ConfigOptions): + owner.expose(self._pub_name, name) + def __get__(self, obj: ConfigOptions, owner=False): + if self._val is MISSING: + v = MISSING + for srckey, src in obj._srcs.items(): + if srckey in self._srcs: + v = src.get(self._srcs[srckey], v) + if self._required and (not v or v is MISSING): + raise MissingConfigError(f'required config {self._srcs['default']} not set!') + if v is MISSING: + v = self._default + if callable(self._cast): + v = self._cast(v) if v is not None else self._cast() + self._val = v + return self._val + + @property + def source(self, /): + return self._srcs['default'] + +@deprecated('use configparse') +class ConfigOptions: + """ + Base class for loading config values. + + It is intended to get subclassed; config values must be defined as + ConfigValue() properties. + + Further config sources can be added with .add_source() + """ + + __slots__ = ('_srcs', '_pub') + + _srcs: OrderedDict[str, ConfigSource] + _pub: dict[str, str] + + def __init__(self, /): + self._srcs = OrderedDict( + default = EnvConfigSource() + ) + self._pub = dict() + + def add_source(self, key: str, csrc: ConfigSource, /, *, first: bool = False): + self._srcs[key] = csrc + if first: + self._srcs.move_to_end(key, False) + + add_config_source = deprecated_alias(add_source) + + def expose(self, public_name: str, attr_name: str | None = None) -> None: + ''' + Mark a config value as public. + + Called automatically by ConfigValue.__set_name__(). + ''' + attr_name = attr_name or public_name + self._pub[public_name] = attr_name + + def to_dict(self, /): + d = dict() + for k, v in self._pub.items(): + d[k] = getattr(self, v) + return d + + +__all__ = ( + 'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'EnvConfigSource', 'ConfigParserConfigSource', 'DictConfigSource', 'ConfigSource', 'ConfigValue' +) + + diff --git a/src/suou/exceptions.py b/src/suou/exceptions.py new file mode 100644 index 0000000..bc71037 --- /dev/null +++ b/src/suou/exceptions.py @@ -0,0 +1,33 @@ +""" +Exceptions and throwables for various purposes + +--- + +Copyright (c) 2025 Sakuragasaki46. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +See LICENSE for the specific language governing permissions and +limitations under the License. + +This software is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +""" + + + +class MissingConfigError(LookupError): + """ + Config variable not found. + + Raised when a config property is marked as required, but no property with + that name is found. + """ + pass + + +class MissingConfigWarning(MissingConfigError, Warning): + """ + A required config property is missing, and the application is assuming a default value. + """ + pass \ No newline at end of file diff --git a/src/suou/flask_restx.py b/src/suou/flask_restx.py index 9d4955a..cef777e 100644 --- a/src/suou/flask_restx.py +++ b/src/suou/flask_restx.py @@ -16,10 +16,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Any, Mapping import warnings -from flask import current_app, make_response +from flask import Response, current_app, make_response from flask_restx import Api as _Api -from .codecs import jsonencode +from .codecs import jsondecode, jsonencode, want_bytes, want_str def output_json(data, code, headers=None): @@ -54,10 +54,21 @@ class Api(_Api): Notably, all JSON is whitespace-free and .message is remapped to .error """ def handle_error(self, e): + ### XXX in order for errors to get handled the correct way, import + ### suou.flask_restx.Api() NOT flask_restx.Api() !!!! res = super().handle_error(e) if isinstance(res, Mapping) and 'message' in res: res['error'] = res['message'] del res['message'] + elif isinstance(res, Response): + try: + body = want_str(res.response[0]) + bodj = jsondecode(body) + if 'message' in bodj: + bodj['error'] = bodj.pop('message') + res.response = [want_bytes(jsonencode(bodj))] + except (IndexError, KeyError): + pass return res def __init__(self, *a, **ka): super().__init__(*a, **ka) diff --git a/src/suou/flask_sqlalchemy.py b/src/suou/flask_sqlalchemy.py index 8ef0d12..5af6a8c 100644 --- a/src/suou/flask_sqlalchemy.py +++ b/src/suou/flask_sqlalchemy.py @@ -22,7 +22,7 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import DeclarativeBase, Session from .codecs import want_bytes -from .sqlalchemy import require_auth_base +from .sqlalchemy import AuthSrc, require_auth_base class FlaskAuthSrc(AuthSrc): ''' @@ -35,14 +35,15 @@ class FlaskAuthSrc(AuthSrc): def get_session(self) -> Session: return self.db.session def get_token(self): - return request.authorization.token + if request.authorization: + return request.authorization.token def get_signature(self) -> bytes: sig = request.headers.get('authorization-signature', None) return want_bytes(sig) if sig else None def invalid_exc(self, msg: str = 'validation failed') -> Never: abort(400, msg) def required_exc(self): - abort(401) + abort(401, 'Login required') def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Callable]: """ @@ -51,6 +52,9 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Ca This looks for a token in the Authorization header, validates it, loads the appropriate object, and injects it as the user= parameter. + NOTE: the actual decorator to be used on routes is **auth_required()**, + NOT require_auth() which is the **constructor** for it. + cls is a SQLAlchemy table. db is a flask_sqlalchemy.SQLAlchemy() binding. @@ -62,8 +66,15 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Ca @auth_required(validators=[lambda x: x.is_administrator]) def super_secret_stuff(user): pass + + NOTE: require_auth() DOES NOT work with flask_restx. """ - return partial(require_auth_base, cls=cls, src=FlaskAuthSrc(db)) + def auth_required(**kwargs): + return require_auth_base(cls=cls, src=FlaskAuthSrc(db), **kwargs) + + auth_required.__doc__ = require_auth_base.__doc__ + + return auth_required __all__ = ('require_auth', ) diff --git a/src/suou/forms.py b/src/suou/forms.py index 552aed8..8f5318f 100644 --- a/src/suou/forms.py +++ b/src/suou/forms.py @@ -2,5 +2,7 @@ Form validation, done right. Why this? Why not, let's say, WTForms or Marshmallow? Well, I have my reasons. + +TODO """ diff --git a/src/suou/iding.py b/src/suou/iding.py index 6cfcd5e..2fe2364 100644 --- a/src/suou/iding.py +++ b/src/suou/iding.py @@ -41,7 +41,7 @@ from typing import Iterable, override import warnings from .functools import deprecated -from .codecs import b32lencode, b64encode, cb32encode +from .codecs import b32lencode, b64encode, cb32decode, cb32encode, want_str class SiqType(enum.Enum): @@ -225,6 +225,7 @@ class Siq(int): """ Representation of a SIQ as an integer. """ + def to_bytes(self, length: int = 14, byteorder = 'big', *, signed: bool = False) -> bytes: return super().to_bytes(length, byteorder, signed=signed) @classmethod @@ -236,7 +237,7 @@ class Siq(int): def to_base64(self, length: int = 15, *, strip: bool = True) -> str: return b64encode(self.to_bytes(length), strip=strip) def to_cb32(self) -> str: - return cb32encode(self.to_bytes(15, 'big')) + return cb32encode(self.to_bytes(15, 'big')).lstrip('0') to_crockford = to_cb32 def to_hex(self) -> str: return f'{self:x}' @@ -291,6 +292,10 @@ class Siq(int): raise ValueError('checksum mismatch') return cls(int.from_bytes(b, 'big')) + @classmethod + def from_cb32(cls, val: str | bytes): + return cls.from_bytes(cb32decode(want_str(val).zfill(24))) + def to_mastodon(self, /, domain: str | None = None): return f'@{self:u}{"@" if domain else ""}{domain}' def to_matrix(self, /, domain: str): diff --git a/src/suou/sqlalchemy.py b/src/suou/sqlalchemy.py index 2019424..249b104 100644 --- a/src/suou/sqlalchemy.py +++ b/src/suou/sqlalchemy.py @@ -253,8 +253,7 @@ class AuthSrc(metaclass=ABCMeta): def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', - required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None, - invalid_exc: Callable | None = None, required_exc: Callable | None = None): + required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): ''' Inject the current user into a view, given the Authorization: Bearer header. @@ -275,11 +274,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | except Exception: return None - def _default_invalid(msg: str): + def _default_invalid(msg: str = 'Validation failed'): raise ValueError(msg) - invalid_exc = invalid_exc or _default_invalid - required_exc = required_exc or (lambda: _default_invalid()) + invalid_exc = src.invalid_exc or _default_invalid + required_exc = src.required_exc or (lambda: _default_invalid('Login required')) def decorator(func: Callable): @wraps(func)