diff --git a/pyproject.toml b/pyproject.toml index ca602dd..d452ebd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,13 +4,15 @@ authors = [ { name = "Sakuragasaki46" } ] dynamic = [ "version" ] +requires-python = ">=3.10" +license = "Apache-2.0" +readme = "README.md" + dependencies = [ "itsdangerous", "toml" ] -requires-python = ">=3.10" -license = "Apache-2.0" -readme = "README.md" +# - further devdependencies below - # # - publishing - classifiers = [ @@ -28,6 +30,7 @@ classifiers = [ Repository = "https://github.com/sakuragasaki46/suou" [project.optional-dependencies] +# the below are all dev dependencies (and probably already installed) sqlalchemy = [ "SQLAlchemy>=2.0.0" ] diff --git a/src/suou/__init__.py b/src/suou/__init__.py index 5caf6cf..6328936 100644 --- a/src/suou/__init__.py +++ b/src/suou/__init__.py @@ -22,7 +22,7 @@ from .bits import count_ones, mask_shift, split_bits, join_bits from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource from .functools import deprecated, not_implemented from .classtools import Wanted, Incomplete -from .itertools import makelist, kwargs_prefix +from .itertools import makelist, kwargs_prefix, ltuple, rtuple from .i18n import I18n, JsonI18n, TomlI18n __version__ = "0.3.0-dev21" @@ -30,7 +30,7 @@ __version__ = "0.3.0-dev21" __all__ = ( 'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase', 'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource', - 'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', + 'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', 'ltuple', 'rtuple', 'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift', 'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode' ) diff --git a/src/suou/flask_sqlalchemy.py b/src/suou/flask_sqlalchemy.py index dfe05af..8ef0d12 100644 --- a/src/suou/flask_sqlalchemy.py +++ b/src/suou/flask_sqlalchemy.py @@ -21,6 +21,7 @@ from flask import abort, request from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import DeclarativeBase, Session +from .codecs import want_bytes from .sqlalchemy import require_auth_base class FlaskAuthSrc(AuthSrc): @@ -35,7 +36,9 @@ class FlaskAuthSrc(AuthSrc): return self.db.session def get_token(self): return request.authorization.token - + def get_signature(self) -> bytes: + sig = request.headers.get('authorization-signature', None) + return want_bytes(sig) if sig else None def invalid_exc(self, msg: str = 'validation failed') -> Never: abort(400, msg) def required_exc(self): diff --git a/src/suou/itertools.py b/src/suou/itertools.py index 99a5477..dad51f4 100644 --- a/src/suou/itertools.py +++ b/src/suou/itertools.py @@ -14,8 +14,9 @@ This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ''' -from typing import Any, Iterable +from typing import Any, Iterable, TypeVar +_T = TypeVar('_T') def makelist(l: Any) -> list: ''' @@ -30,6 +31,25 @@ def makelist(l: Any) -> list: else: return [l] +def ltuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple: + """ + Truncate an iterable into a fixed size tuple, if necessary padding it. + """ + seq = tuple(seq)[:size] + if len(seq) < size: + seq = seq + (pad,) * (size - len(seq)) + return seq + +def rtuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple: + """ + Same as rtuple() but the padding and truncation is made right to left. + """ + seq = tuple(seq)[-size:] + if len(seq) < size: + seq = (pad,) * (size - len(seq)) + seq + return seq + + def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]: ''' Subset of keyword arguments. Useful for callable wrapping. @@ -38,4 +58,4 @@ def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]: -__all__ = ('makelist', 'kwargs_prefix') +__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple') diff --git a/src/suou/signing.py b/src/suou/signing.py index 02a7e2a..6f2dcb9 100644 --- a/src/suou/signing.py +++ b/src/suou/signing.py @@ -16,10 +16,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from abc import ABC from base64 import b64decode -from typing import Any, Callable +from typing import Any, Callable, Sequence from itsdangerous import TimestampSigner -from .codecs import want_str +from suou.itertools import rtuple + +from .functools import not_implemented +from .codecs import jsondecode, jsonencode, want_bytes, want_str from .iding import Siq class UserSigner(TimestampSigner): @@ -36,6 +39,20 @@ class UserSigner(TimestampSigner): def split_token(cls, /, token: str | bytes) : a, b, c = want_str(token).rsplit('.', 2) return b64decode(a), b, b64decode(c) + def sign_object(self, obj: dict, /, *, encoder=jsonencode, **kwargs): + """ + Return a signed JSON payload of an object. + + MUST be passed as a dict: ser/deser it's not the signer's job. + """ + return self.sign(encoder(obj), **kwargs) + def unsign_object(self, payload: str | bytes, /, *, decoder=jsondecode, **kwargs): + """ + Unsign and parse a JSON object signed payload. Returns a dict. + """ + return decoder(self.unsign(payload, **kwargs)) + def split_signed(self, payload: str | bytes) -> Sequence[bytes]: + return rtuple(want_bytes(payload).rsplit(b'.', 2), 3, b'') class HasSigner(ABC): diff --git a/src/suou/sqlalchemy.py b/src/suou/sqlalchemy.py index 0d0b83d..3886d0d 100644 --- a/src/suou/sqlalchemy.py +++ b/src/suou/sqlalchemy.py @@ -194,22 +194,18 @@ class AuthSrc(metaclass=ABCMeta): @abstractmethod def get_token(self): pass + @abstractmethod + def get_signature(self): + pass -def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Callable[[Callable[[], _T]], Any], session_src: Callable[[], Session], - column: str | Column[_T] = 'id', dest: str = 'user', required: bool = False, validators: Callable | Iterable[Callable] | None = None, +def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', + required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None, invalid_exc: Callable | None = None, required_exc: Callable | None = None): ''' Inject the current user into a view, given the Authorization: Bearer header. - For portability reasons, this is a partial, two-component function. - - The value_src() callable takes a callable as its only and required argument, and the supplied - callable, when called, returns the value of the column. - - The session_src() callable takes no argument, and returns a Session object. - - XXX maybe a class is better than a thousand callables? + For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object. ''' col = want_column(cls, column) validators = makelist(validators) @@ -218,7 +214,7 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Ca if token is None: return None tok_parts = UserSigner.split_token(token) - user: HasSigner = session_src().execute(select(cls).where(col == tok_parts[0])).scalar() + user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar() try: signer: UserSigner = user.signer() signer.unsign(token) @@ -235,9 +231,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Ca def decorator(func: Callable): @wraps(func) def wrapper(*a, **ka): - ka[dest] = value_src(get_user) + ka[dest] = get_user(src.get_token()) if not ka[dest] and required: required_exc() + if signed: + ka[sig_dest] = src.get_signature() for valid in validators: if not valid(ka[dest]): invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))