add abstract get_signature() to sqlalchemy.AuthSrc(), itertools.ltuple() and rtuple() and UserSigner.sign_object()

This commit is contained in:
Yusur 2025-06-11 01:16:29 +02:00
parent 9f75d983ba
commit 38eac17109
6 changed files with 63 additions and 22 deletions

View file

@ -4,13 +4,15 @@ authors = [
{ name = "Sakuragasaki46" }
]
dynamic = [ "version" ]
requires-python = ">=3.10"
license = "Apache-2.0"
readme = "README.md"
dependencies = [
"itsdangerous",
"toml"
]
requires-python = ">=3.10"
license = "Apache-2.0"
readme = "README.md"
# - further devdependencies below - #
# - publishing -
classifiers = [
@ -28,6 +30,7 @@ classifiers = [
Repository = "https://github.com/sakuragasaki46/suou"
[project.optional-dependencies]
# the below are all dev dependencies (and probably already installed)
sqlalchemy = [
"SQLAlchemy>=2.0.0"
]

View file

@ -22,7 +22,7 @@ from .bits import count_ones, mask_shift, split_bits, join_bits
from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource
from .functools import deprecated, not_implemented
from .classtools import Wanted, Incomplete
from .itertools import makelist, kwargs_prefix
from .itertools import makelist, kwargs_prefix, ltuple, rtuple
from .i18n import I18n, JsonI18n, TomlI18n
__version__ = "0.3.0-dev21"
@ -30,7 +30,7 @@ __version__ = "0.3.0-dev21"
__all__ = (
'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase',
'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource',
'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode',
'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', 'ltuple', 'rtuple',
'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift',
'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode'
)

View file

@ -21,6 +21,7 @@ from flask import abort, request
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import DeclarativeBase, Session
from .codecs import want_bytes
from .sqlalchemy import require_auth_base
class FlaskAuthSrc(AuthSrc):
@ -35,7 +36,9 @@ class FlaskAuthSrc(AuthSrc):
return self.db.session
def get_token(self):
return request.authorization.token
def get_signature(self) -> bytes:
sig = request.headers.get('authorization-signature', None)
return want_bytes(sig) if sig else None
def invalid_exc(self, msg: str = 'validation failed') -> Never:
abort(400, msg)
def required_exc(self):

View file

@ -14,8 +14,9 @@ This software is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
'''
from typing import Any, Iterable
from typing import Any, Iterable, TypeVar
_T = TypeVar('_T')
def makelist(l: Any) -> list:
'''
@ -30,6 +31,25 @@ def makelist(l: Any) -> list:
else:
return [l]
def ltuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple:
"""
Truncate an iterable into a fixed size tuple, if necessary padding it.
"""
seq = tuple(seq)[:size]
if len(seq) < size:
seq = seq + (pad,) * (size - len(seq))
return seq
def rtuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple:
"""
Same as rtuple() but the padding and truncation is made right to left.
"""
seq = tuple(seq)[-size:]
if len(seq) < size:
seq = (pad,) * (size - len(seq)) + seq
return seq
def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]:
'''
Subset of keyword arguments. Useful for callable wrapping.
@ -38,4 +58,4 @@ def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]:
__all__ = ('makelist', 'kwargs_prefix')
__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple')

View file

@ -16,10 +16,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABC
from base64 import b64decode
from typing import Any, Callable
from typing import Any, Callable, Sequence
from itsdangerous import TimestampSigner
from .codecs import want_str
from suou.itertools import rtuple
from .functools import not_implemented
from .codecs import jsondecode, jsonencode, want_bytes, want_str
from .iding import Siq
class UserSigner(TimestampSigner):
@ -36,6 +39,20 @@ class UserSigner(TimestampSigner):
def split_token(cls, /, token: str | bytes) :
a, b, c = want_str(token).rsplit('.', 2)
return b64decode(a), b, b64decode(c)
def sign_object(self, obj: dict, /, *, encoder=jsonencode, **kwargs):
"""
Return a signed JSON payload of an object.
MUST be passed as a dict: ser/deser it's not the signer's job.
"""
return self.sign(encoder(obj), **kwargs)
def unsign_object(self, payload: str | bytes, /, *, decoder=jsondecode, **kwargs):
"""
Unsign and parse a JSON object signed payload. Returns a dict.
"""
return decoder(self.unsign(payload, **kwargs))
def split_signed(self, payload: str | bytes) -> Sequence[bytes]:
return rtuple(want_bytes(payload).rsplit(b'.', 2), 3, b'')
class HasSigner(ABC):

View file

@ -194,22 +194,18 @@ class AuthSrc(metaclass=ABCMeta):
@abstractmethod
def get_token(self):
pass
@abstractmethod
def get_signature(self):
pass
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Callable[[Callable[[], _T]], Any], session_src: Callable[[], Session],
column: str | Column[_T] = 'id', dest: str = 'user', required: bool = False, validators: Callable | Iterable[Callable] | None = None,
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None,
invalid_exc: Callable | None = None, required_exc: Callable | None = None):
'''
Inject the current user into a view, given the Authorization: Bearer header.
For portability reasons, this is a partial, two-component function.
The value_src() callable takes a callable as its only and required argument, and the supplied
callable, when called, returns the value of the column.
The session_src() callable takes no argument, and returns a Session object.
XXX maybe a class is better than a thousand callables?
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
'''
col = want_column(cls, column)
validators = makelist(validators)
@ -218,7 +214,7 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Ca
if token is None:
return None
tok_parts = UserSigner.split_token(token)
user: HasSigner = session_src().execute(select(cls).where(col == tok_parts[0])).scalar()
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
try:
signer: UserSigner = user.signer()
signer.unsign(token)
@ -235,9 +231,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Ca
def decorator(func: Callable):
@wraps(func)
def wrapper(*a, **ka):
ka[dest] = value_src(get_user)
ka[dest] = get_user(src.get_token())
if not ka[dest] and required:
required_exc()
if signed:
ka[sig_dest] = src.get_signature()
for valid in validators:
if not valid(ka[dest]):
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))