add abstract get_signature() to sqlalchemy.AuthSrc(), itertools.ltuple() and rtuple() and UserSigner.sign_object()
This commit is contained in:
parent
9f75d983ba
commit
38eac17109
6 changed files with 63 additions and 22 deletions
|
|
@ -4,13 +4,15 @@ authors = [
|
||||||
{ name = "Sakuragasaki46" }
|
{ name = "Sakuragasaki46" }
|
||||||
]
|
]
|
||||||
dynamic = [ "version" ]
|
dynamic = [ "version" ]
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
license = "Apache-2.0"
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itsdangerous",
|
"itsdangerous",
|
||||||
"toml"
|
"toml"
|
||||||
]
|
]
|
||||||
requires-python = ">=3.10"
|
# - further devdependencies below - #
|
||||||
license = "Apache-2.0"
|
|
||||||
readme = "README.md"
|
|
||||||
|
|
||||||
# - publishing -
|
# - publishing -
|
||||||
classifiers = [
|
classifiers = [
|
||||||
|
|
@ -28,6 +30,7 @@ classifiers = [
|
||||||
Repository = "https://github.com/sakuragasaki46/suou"
|
Repository = "https://github.com/sakuragasaki46/suou"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
# the below are all dev dependencies (and probably already installed)
|
||||||
sqlalchemy = [
|
sqlalchemy = [
|
||||||
"SQLAlchemy>=2.0.0"
|
"SQLAlchemy>=2.0.0"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from .bits import count_ones, mask_shift, split_bits, join_bits
|
||||||
from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource
|
from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource
|
||||||
from .functools import deprecated, not_implemented
|
from .functools import deprecated, not_implemented
|
||||||
from .classtools import Wanted, Incomplete
|
from .classtools import Wanted, Incomplete
|
||||||
from .itertools import makelist, kwargs_prefix
|
from .itertools import makelist, kwargs_prefix, ltuple, rtuple
|
||||||
from .i18n import I18n, JsonI18n, TomlI18n
|
from .i18n import I18n, JsonI18n, TomlI18n
|
||||||
|
|
||||||
__version__ = "0.3.0-dev21"
|
__version__ = "0.3.0-dev21"
|
||||||
|
|
@ -30,7 +30,7 @@ __version__ = "0.3.0-dev21"
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase',
|
'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase',
|
||||||
'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource',
|
'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource',
|
||||||
'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode',
|
'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', 'ltuple', 'rtuple',
|
||||||
'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift',
|
'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift',
|
||||||
'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode'
|
'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from flask import abort, request
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
from sqlalchemy.orm import DeclarativeBase, Session
|
from sqlalchemy.orm import DeclarativeBase, Session
|
||||||
|
|
||||||
|
from .codecs import want_bytes
|
||||||
from .sqlalchemy import require_auth_base
|
from .sqlalchemy import require_auth_base
|
||||||
|
|
||||||
class FlaskAuthSrc(AuthSrc):
|
class FlaskAuthSrc(AuthSrc):
|
||||||
|
|
@ -35,7 +36,9 @@ class FlaskAuthSrc(AuthSrc):
|
||||||
return self.db.session
|
return self.db.session
|
||||||
def get_token(self):
|
def get_token(self):
|
||||||
return request.authorization.token
|
return request.authorization.token
|
||||||
|
def get_signature(self) -> bytes:
|
||||||
|
sig = request.headers.get('authorization-signature', None)
|
||||||
|
return want_bytes(sig) if sig else None
|
||||||
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
||||||
abort(400, msg)
|
abort(400, msg)
|
||||||
def required_exc(self):
|
def required_exc(self):
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,9 @@ This software is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable, TypeVar
|
||||||
|
|
||||||
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
def makelist(l: Any) -> list:
|
def makelist(l: Any) -> list:
|
||||||
'''
|
'''
|
||||||
|
|
@ -30,6 +31,25 @@ def makelist(l: Any) -> list:
|
||||||
else:
|
else:
|
||||||
return [l]
|
return [l]
|
||||||
|
|
||||||
|
def ltuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Truncate an iterable into a fixed size tuple, if necessary padding it.
|
||||||
|
"""
|
||||||
|
seq = tuple(seq)[:size]
|
||||||
|
if len(seq) < size:
|
||||||
|
seq = seq + (pad,) * (size - len(seq))
|
||||||
|
return seq
|
||||||
|
|
||||||
|
def rtuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Same as rtuple() but the padding and truncation is made right to left.
|
||||||
|
"""
|
||||||
|
seq = tuple(seq)[-size:]
|
||||||
|
if len(seq) < size:
|
||||||
|
seq = (pad,) * (size - len(seq)) + seq
|
||||||
|
return seq
|
||||||
|
|
||||||
|
|
||||||
def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]:
|
def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]:
|
||||||
'''
|
'''
|
||||||
Subset of keyword arguments. Useful for callable wrapping.
|
Subset of keyword arguments. Useful for callable wrapping.
|
||||||
|
|
@ -38,4 +58,4 @@ def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('makelist', 'kwargs_prefix')
|
__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple')
|
||||||
|
|
|
||||||
|
|
@ -16,10 +16,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable, Sequence
|
||||||
from itsdangerous import TimestampSigner
|
from itsdangerous import TimestampSigner
|
||||||
|
|
||||||
from .codecs import want_str
|
from suou.itertools import rtuple
|
||||||
|
|
||||||
|
from .functools import not_implemented
|
||||||
|
from .codecs import jsondecode, jsonencode, want_bytes, want_str
|
||||||
from .iding import Siq
|
from .iding import Siq
|
||||||
|
|
||||||
class UserSigner(TimestampSigner):
|
class UserSigner(TimestampSigner):
|
||||||
|
|
@ -36,6 +39,20 @@ class UserSigner(TimestampSigner):
|
||||||
def split_token(cls, /, token: str | bytes) :
|
def split_token(cls, /, token: str | bytes) :
|
||||||
a, b, c = want_str(token).rsplit('.', 2)
|
a, b, c = want_str(token).rsplit('.', 2)
|
||||||
return b64decode(a), b, b64decode(c)
|
return b64decode(a), b, b64decode(c)
|
||||||
|
def sign_object(self, obj: dict, /, *, encoder=jsonencode, **kwargs):
|
||||||
|
"""
|
||||||
|
Return a signed JSON payload of an object.
|
||||||
|
|
||||||
|
MUST be passed as a dict: ser/deser it's not the signer's job.
|
||||||
|
"""
|
||||||
|
return self.sign(encoder(obj), **kwargs)
|
||||||
|
def unsign_object(self, payload: str | bytes, /, *, decoder=jsondecode, **kwargs):
|
||||||
|
"""
|
||||||
|
Unsign and parse a JSON object signed payload. Returns a dict.
|
||||||
|
"""
|
||||||
|
return decoder(self.unsign(payload, **kwargs))
|
||||||
|
def split_signed(self, payload: str | bytes) -> Sequence[bytes]:
|
||||||
|
return rtuple(want_bytes(payload).rsplit(b'.', 2), 3, b'')
|
||||||
|
|
||||||
|
|
||||||
class HasSigner(ABC):
|
class HasSigner(ABC):
|
||||||
|
|
|
||||||
|
|
@ -194,22 +194,18 @@ class AuthSrc(metaclass=ABCMeta):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_token(self):
|
def get_token(self):
|
||||||
pass
|
pass
|
||||||
|
@abstractmethod
|
||||||
|
def get_signature(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Callable[[Callable[[], _T]], Any], session_src: Callable[[], Session],
|
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
||||||
column: str | Column[_T] = 'id', dest: str = 'user', required: bool = False, validators: Callable | Iterable[Callable] | None = None,
|
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None,
|
||||||
invalid_exc: Callable | None = None, required_exc: Callable | None = None):
|
invalid_exc: Callable | None = None, required_exc: Callable | None = None):
|
||||||
'''
|
'''
|
||||||
Inject the current user into a view, given the Authorization: Bearer header.
|
Inject the current user into a view, given the Authorization: Bearer header.
|
||||||
|
|
||||||
For portability reasons, this is a partial, two-component function.
|
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
|
||||||
|
|
||||||
The value_src() callable takes a callable as its only and required argument, and the supplied
|
|
||||||
callable, when called, returns the value of the column.
|
|
||||||
|
|
||||||
The session_src() callable takes no argument, and returns a Session object.
|
|
||||||
|
|
||||||
XXX maybe a class is better than a thousand callables?
|
|
||||||
'''
|
'''
|
||||||
col = want_column(cls, column)
|
col = want_column(cls, column)
|
||||||
validators = makelist(validators)
|
validators = makelist(validators)
|
||||||
|
|
@ -218,7 +214,7 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Ca
|
||||||
if token is None:
|
if token is None:
|
||||||
return None
|
return None
|
||||||
tok_parts = UserSigner.split_token(token)
|
tok_parts = UserSigner.split_token(token)
|
||||||
user: HasSigner = session_src().execute(select(cls).where(col == tok_parts[0])).scalar()
|
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
|
||||||
try:
|
try:
|
||||||
signer: UserSigner = user.signer()
|
signer: UserSigner = user.signer()
|
||||||
signer.unsign(token)
|
signer.unsign(token)
|
||||||
|
|
@ -235,9 +231,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, value_src: Ca
|
||||||
def decorator(func: Callable):
|
def decorator(func: Callable):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*a, **ka):
|
def wrapper(*a, **ka):
|
||||||
ka[dest] = value_src(get_user)
|
ka[dest] = get_user(src.get_token())
|
||||||
if not ka[dest] and required:
|
if not ka[dest] and required:
|
||||||
required_exc()
|
required_exc()
|
||||||
|
if signed:
|
||||||
|
ka[sig_dest] = src.get_signature()
|
||||||
for valid in validators:
|
for valid in validators:
|
||||||
if not valid(ka[dest]):
|
if not valid(ka[dest]):
|
||||||
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue