diff --git a/CHANGELOG.md b/CHANGELOG.md index 4011856..6ce0f75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,31 +1,14 @@ # Changelog -## 0.3.4 - -- Bug fixes in `.flask_sqlalchemy` and `.sqlalchemy` — `require_auth()` is unusable before this point! - -## 0.3.3 - -- Fixed leftovers in `snowflake` module from unchecked code copying — i.e. `SnowflakeGen.generate_one()` used to require an unused typ= parameter -- Fixed a bug in `id_column()` that made it fail to provide a working generator — again, this won't be backported - -## 0.3.2 - -- Fixed bugs in Snowflake generation and serialization of negative values - ## 0.3.0 -- Fixed `cb32encode()` and `b32lencode()` doing wrong padding — **UNSOLVED in 0.2.x** which is out of support, effective immediately -- **Changed behavior** of `kwargs_prefix()` which now removes keys from original mapping by default - Add SQLAlchemy auth loaders i.e. `sqlalchemy.require_auth_base()`, `flask_sqlalchemy`. What auth loaders do is loading user token and signature into app -- `sqlalchemy`: add `parent_children()` and `create_session()` - Implement `UserSigner()` - Improve JSON handling in `flask_restx` - Add base2048 (i.e. [BIP-39](https://github.com/bitcoin/bips/blob/master/bip-0039.mediawiki)) codec -- Add `split_bits()`, `join_bits()`, `ltuple()`, `rtuple()`, `ssv_list()`, `additem()` +- Add `split_bits()`, `join_bits()`, `ltuple()`, `rtuple()` - Add `markdown` extensions -- Add Snowflake manipulation utilities ## 0.2.3 diff --git a/src/suou/__init__.py b/src/suou/__init__.py index 51253f2..8b87268 100644 --- a/src/suou/__init__.py +++ b/src/suou/__init__.py @@ -17,23 +17,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ from .iding import Siq, SiqCache, SiqType, SiqGen -from .codecs import (StringCase, cb32encode, cb32decode, b32lencode, b32ldecode, b64encode, b64decode, b2048encode, b2048decode, - jsonencode, want_bytes, want_str, ssv_list) +from .codecs import StringCase, cb32encode, cb32decode, jsonencode, want_bytes, want_str, b2048encode, b2048decode from .bits import count_ones, mask_shift, split_bits, join_bits from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource from .functools import deprecated, not_implemented from .classtools import Wanted, Incomplete -from .itertools import makelist, kwargs_prefix, ltuple, rtuple, additem +from .itertools import makelist, kwargs_prefix, ltuple, rtuple from .i18n import I18n, JsonI18n, TomlI18n -from .snowflake import Snowflake, SnowflakeGen -__version__ = "0.3.4.rc1" +__version__ = "0.3.0-dev22" __all__ = ( 'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase', 'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource', 'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', 'ltuple', 'rtuple', 'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift', - 'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode', - 'Snowflake', 'SnowflakeGen', 'ssv_list', 'additem', 'b32lencode', 'b32ldecode', 'b64encode', 'b64decode' + 'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode' ) diff --git a/src/suou/codecs.py b/src/suou/codecs.py index 2bee255..5f74867 100644 --- a/src/suou/codecs.py +++ b/src/suou/codecs.py @@ -162,7 +162,7 @@ def cb32decode(val: bytes | str) -> str: ''' Decode bytes from Crockford Base32. ''' - return base64.b32decode(want_bytes(val).upper().translate(CROCKFORD_TO_B32) + b'=' * ((8 - len(val) % 8) % 8)) + return base64.b32decode(want_bytes(val).upper().translate(CROCKFORD_TO_B32) + b'=' * ((5 - len(val) % 5) % 5)) def b32lencode(val: bytes) -> str: ''' @@ -174,7 +174,7 @@ def b32ldecode(val: bytes | str) -> bytes: ''' Decode a lowercase base32 encoded byte sequence. Padding is managed automatically. ''' - return base64.b32decode(want_bytes(val).upper() + b'=' * ((8 - len(val) % 8) % 8)) + return base64.b32decode(want_bytes(val).upper() + b'=' * ((5 - len(val) % 5) % 5)) def b64encode(val: bytes, *, strip: bool = True) -> str: ''' @@ -229,35 +229,6 @@ def jsonencode(obj: dict, *, skipkeys: bool = True, separators: tuple[str, str] jsondecode = deprecated('just use json.loads()')(json.loads) -def ssv_list(s: str, *, sep_chars = ',;') -> list[str]: - """ - Parse values from a Space Separated Values (SSV) string. - - By default, values are split on spaces, commas (,) and semicolons (;), configurable - with sepchars= argument. - - Double quotes (") can be used to allow spaces, commas etc. in values. Doubled double - quotes ("") are parsed as literal double quotes. - - Useful for environment variables: pass it to ConfigValue() as the cast= argument. - """ - sep_re = r'\s+|\s*[' + re.escape(sep_chars) + r']\s*' - parts = s.split('"') - parts[::2] = [re.split(sep_re, x) for x in parts[::2]] - l: list[str] = parts[0].copy() - for i in range(1, len(parts), 2): - p0, *pt = parts[i+1] - # two "strings" sandwiching each other case - if i < len(parts)-2 and parts[i] and parts[i+2] and not p0 and not pt: - p0 = '"' - l[-1] += ('"' if parts[i] == '' else parts[i]) + p0 - l.extend(pt) - if l and l[0] == '': - l.pop(0) - if l and l[-1] == '': - l.pop() - return l - class StringCase(enum.Enum): """ Enum values used by regex validators and storage converters. @@ -266,7 +237,7 @@ class StringCase(enum.Enum): LOWER = case insensitive, force lowercase UPPER = case insensitive, force uppercase IGNORE = case insensitive, leave as is, use lowercase in comparison - IGNORE_UPPER = same as above, but use uppercase in comparison + IGNORE_UPPER = same as above, but use uppercase il comparison """ AS_IS = 0 LOWER = FORCE_LOWER = 1 @@ -293,5 +264,5 @@ class StringCase(enum.Enum): __all__ = ( 'cb32encode', 'cb32decode', 'b32lencode', 'b32ldecode', 'b64encode', 'b64decode', 'jsonencode' - 'StringCase', 'want_bytes', 'want_str', 'jsondecode', 'ssv_list' + 'StringCase', 'want_bytes', 'want_str', 'jsondecode' ) \ No newline at end of file diff --git a/src/suou/flask_restx.py b/src/suou/flask_restx.py index ecdf3da..9d4955a 100644 --- a/src/suou/flask_restx.py +++ b/src/suou/flask_restx.py @@ -54,10 +54,7 @@ class Api(_Api): Notably, all JSON is whitespace-free and .message is remapped to .error """ def handle_error(self, e): - ### XXX apparently this handle_error does not get called AT ALL. - print(e) res = super().handle_error(e) - print(res) if isinstance(res, Mapping) and 'message' in res: res['error'] = res['message'] del res['message'] diff --git a/src/suou/flask_sqlalchemy.py b/src/suou/flask_sqlalchemy.py index 508e296..8ef0d12 100644 --- a/src/suou/flask_sqlalchemy.py +++ b/src/suou/flask_sqlalchemy.py @@ -22,7 +22,7 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import DeclarativeBase, Session from .codecs import want_bytes -from .sqlalchemy import AuthSrc, require_auth_base +from .sqlalchemy import require_auth_base class FlaskAuthSrc(AuthSrc): ''' @@ -35,26 +35,22 @@ class FlaskAuthSrc(AuthSrc): def get_session(self) -> Session: return self.db.session def get_token(self): - if request.authorization: - return request.authorization.token + return request.authorization.token def get_signature(self) -> bytes: sig = request.headers.get('authorization-signature', None) return want_bytes(sig) if sig else None def invalid_exc(self, msg: str = 'validation failed') -> Never: abort(400, msg) def required_exc(self): - abort(401, 'Login required') + abort(401) -def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable: +def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Callable]: """ Make an auth_required() decorator for Flask views. This looks for a token in the Authorization header, validates it, loads the appropriate object, and injects it as the user= parameter. - NOTE: the actual decorator to be used on routes is **auth_required()**, - NOT require_auth() which is the **constructor** for it. - cls is a SQLAlchemy table. db is a flask_sqlalchemy.SQLAlchemy() binding. @@ -67,12 +63,7 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable: def super_secret_stuff(user): pass """ - def auth_required(**kwargs): - return require_auth_base(cls=cls, src=FlaskAuthSrc(db), **kwargs) - - auth_required.__doc__ = require_auth_base.__doc__ - - return auth_required + return partial(require_auth_base, cls=cls, src=FlaskAuthSrc(db)) __all__ = ('require_auth', ) diff --git a/src/suou/functools.py b/src/suou/functools.py index 9041f92..eb56d2e 100644 --- a/src/suou/functools.py +++ b/src/suou/functools.py @@ -22,7 +22,7 @@ try: from warnings import deprecated except ImportError: # Python <=3.12 does not implement warnings.deprecated - def deprecated(message: str, /, *, category=DeprecationWarning, stacklevel:int=1): + def deprecated(message: str, /, *, category=DeprecationWarning): """ Backport of PEP 702 for Python <=3.12. The stack_level stuff is not reimplemented on purpose because @@ -32,7 +32,7 @@ except ImportError: @wraps(func) def wrapper(*a, **ka): if category is not None: - warnings.warn(message, category, stacklevel=stacklevel) + warnings.warn(message, category) return func(*a, **ka) func.__deprecated__ = True wrapper.__deprecated__ = True diff --git a/src/suou/iding.py b/src/suou/iding.py index 6cfcd5e..dba591c 100644 --- a/src/suou/iding.py +++ b/src/suou/iding.py @@ -40,7 +40,7 @@ import os from typing import Iterable, override import warnings -from .functools import deprecated +from .functools import not_implemented, deprecated from .codecs import b32lencode, b64encode, cb32encode @@ -206,9 +206,7 @@ class SiqCache: return self.generator.last_gen_ts def cur_timestamp(self): return self.generator.cur_timestamp() - def __init__(self, generator: SiqGen | str, typ: SiqType, size: int = 64, max_age: int = 1024): - if isinstance(generator, str): - generator = SiqGen(generator) + def __init__(self, generator: SiqGen, typ: SiqType, size: int = 64, max_age: int = 1024): self.generator = generator self.typ = typ self.size = size @@ -222,9 +220,6 @@ class SiqCache: return self._cache.pop(0) class Siq(int): - """ - Representation of a SIQ as an integer. - """ def to_bytes(self, length: int = 14, byteorder = 'big', *, signed: bool = False) -> bytes: return super().to_bytes(length, byteorder, signed=signed) @classmethod @@ -235,22 +230,17 @@ class Siq(int): def to_base64(self, length: int = 15, *, strip: bool = True) -> str: return b64encode(self.to_bytes(length), strip=strip) - def to_cb32(self) -> str: + def to_cb32(self)-> str: return cb32encode(self.to_bytes(15, 'big')) to_crockford = to_cb32 def to_hex(self) -> str: return f'{self:x}' def to_oct(self) -> str: return f'{self:o}' - def to_b32l(self) -> str: - """ - This is NOT the URI serializer! - """ - return b32lencode(self.to_bytes(15, 'big')) - def __str__(self) -> str: - return int.__str__(self) - to_dec = deprecated('use str() instead')(__str__) - + @deprecated('use str() instead') + def to_dec(self) -> str: + return f'{self}' + @override def __format__(self, opt: str, /) -> str: try: @@ -266,9 +256,7 @@ class Siq(int): case '0c': return '0' + self.to_cb32() case 'd' | '': - return int.__repr__(self) - case 'l': - return self.to_b32l() + return int.__str__(self) case 'o' | 'x': return int.__format__(self, opt) case 'u': @@ -299,15 +287,6 @@ class Siq(int): def __repr__(self): return f'{self.__class__.__name__}({super().__repr__()})' - # convenience methods - def timestamp(self): - return (self >> 56) / (1 << 16) - - def shard_id(self): - return (self >> 48) % 256 - - def domain_name(self): - return (self >> 16) % 0xffffffff __all__ = ( 'Siq', 'SiqCache', 'SiqType', 'SiqGen' diff --git a/src/suou/itertools.py b/src/suou/itertools.py index 9f80faa..dad51f4 100644 --- a/src/suou/itertools.py +++ b/src/suou/itertools.py @@ -14,8 +14,7 @@ This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ''' -from typing import Any, Iterable, MutableMapping, TypeVar -import warnings +from typing import Any, Iterable, TypeVar _T = TypeVar('_T') @@ -51,38 +50,12 @@ def rtuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple: return seq -def kwargs_prefix(it: dict[str, Any], prefix: str, *, remove = True, keep_prefix = False) -> dict[str, Any]: +def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]: ''' Subset of keyword arguments. Useful for callable wrapping. - - By default, it removes arguments from original kwargs as well. You can prevent by - setting remove=False. - - By default, specified prefix is removed from each key of the returned - dictionary; keep_prefix=True keeps the prefix on keys. ''' - keys = [k for k in it.keys() if k.startswith(prefix)] - - ka = dict() - for k in keys: - ka[k if keep_prefix else k.removeprefix(prefix)] = it[k] - if remove: - for k in keys: - it.pop(k) - return ka - -def additem(obj: MutableMapping, /, name: str = None): - """ - Syntax sugar for adding a function to a mapping, immediately. - """ - def decorator(func): - key = name or func.__name__ - if key in obj: - warnings.warn(f'mapping does already have item {key!r}') - obj[key] = func - return func - return decorator + return {k.removeprefix(prefix): v for k, v in it.items() if k.startswith(prefix)} -__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple', 'additem') +__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple') diff --git a/src/suou/snowflake.py b/src/suou/snowflake.py deleted file mode 100644 index 3f9190e..0000000 --- a/src/suou/snowflake.py +++ /dev/null @@ -1,194 +0,0 @@ -""" -Utilities for Snowflake-like identifiers. - -Here for applications who benefit from their use. I (sakuragasaki46) -recommend using SIQ (.iding) when applicable; there also utilities to -convert snowflakes into SIQ's in .migrate. - ---- - -Copyright (c) 2025 Sakuragasaki46. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -See LICENSE for the specific language governing permissions and -limitations under the License. - -This software is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -""" - - -from __future__ import annotations -import os -from threading import Lock -import time -from typing import override -import warnings - -from .migrate import SnowflakeSiqMigrator -from .iding import SiqType -from .codecs import b32ldecode, b32lencode, b64encode, cb32encode -from .functools import deprecated - - -class SnowflakeGen: - """ - Implements a generator Snowflake ID's (i.e. the ones in use at Twitter / Discord). - - Discord snowflakes are in this format: - tttttttt tttttttt tttttttt tttttttt - tttttttt ttddddds sssspppp pppppppp - - where: - t: timestamp (in milliseconds) — 42 bits - d: local ID — 5 bits - s: shard ID — 5 bits - p: progressive counter — 10 bits - - Converter takes local ID and shard ID as one; latter 8 bits are taken for - the shard ID, while the former 2 are added to timestamp, taking advantage of - more precision — along with up to 2 most significant bits of progressive co - - The constructor takes an epoch argument, since snowflakes, due to - optimization requirements, are based on a different epoch (e.g. - Jan 1, 2015 for Discord); epoch is wanted as seconds since Unix epoch - (i.e. midnight of Jan 1, 1970). - """ - epoch: int - local_id: int - shard_id: int - counter: int - last_gen_ts: int - - TS_ACCURACY = 1000 - - - def __init__(self, epoch: int, local_id: int = 0, shard_id: int | None = None, - last_id: int = 0 - ): - self.epoch = epoch - self.local_id = local_id - self.shard_id = (shard_id or os.getpid()) % 32 - self.counter = 0 - self.last_gen_ts = min(last_id >> 22, self.cur_timestamp()) - def cur_timestamp(self) -> int: - return int((time.time() - self.epoch) * self.TS_ACCURACY) - def generate(self, /, n: int = 1): - """ - Generate one or more snowflakes. - The generated ids are returned as integers. - Bulk generation is supported. - - Returns as an iterator, to allow generation “on the fly”. - To get a scalar or a list, use .generate_one() or next(), or - .generate_list() or list(.generate()), respectively. - - Warning: the function **may block**. - """ - now = self.cur_timestamp() - if now < self.last_gen_ts: - time.sleep((self.last_gen_ts - now) / (1 << 16)) - elif now > self.last_gen_ts: - self.counter = 0 - while n: - if self.counter >= 4096: - while (now := self.cur_timestamp()) <= self.last_gen_ts: - time.sleep(1 / (1 << 16)) - with Lock(): - self.counter %= 1 << 16 - # XXX the lock is here "just in case", MULTITHREADED GENERATION IS NOT ADVISED! - with Lock(): - siq = ( - (now << 22) | - ((self.local_id % 32) << 17) | - ((self.shard_id % 32) << 12) | - (self.counter % (1 << 12)) - ) - n -= 1 - self.counter += 1 - yield siq - def generate_one(self, /) -> int: - return next(self.generate(1)) - def generate_list(self, /, n: int = 1) -> list[int]: - return list(self.generate(n)) - - -class Snowflake(int): - """ - Representation of a Snowflake as an integer. - """ - - def to_bytes(self, length: int = 14, byteorder = "big", *, signed: bool = False) -> bytes: - return super().to_bytes(length, byteorder, signed=signed) - def to_base64(self, length: int = 9, *, strip: bool = True) -> str: - return b64encode(self.to_bytes(length), strip=strip) - def to_cb32(self)-> str: - return cb32encode(self.to_bytes(8, 'big')) - to_crockford = to_cb32 - def to_hex(self) -> str: - return f'{self:x}' - def to_oct(self) -> str: - return f'{self:o}' - def to_b32l(self) -> str: - # PSA Snowflake Base32 representations are padded to 10 bytes! - if self < 0: - return '_' + Snowflake.to_b32l(-self) - return b32lencode(self.to_bytes(10, 'big')).lstrip('a') - - @classmethod - def from_bytes(cls, b: bytes, byteorder = 'big', *, signed: bool = False) -> Snowflake: - if len(b) not in (8, 10): - warnings.warn('Snowflakes are exactly 8 bytes long', BytesWarning) - return super().from_bytes(b, byteorder, signed=signed) - - @classmethod - def from_b32l(cls, val: str) -> Snowflake: - if val.startswith('_'): - ## support for negative Snowflakes - return -cls.from_b32l(val.lstrip('_')) - return Snowflake.from_bytes(b32ldecode(val.rjust(16, 'a'))) - - @override - def __format__(self, opt: str, /) -> str: - try: - return self.format(opt) - except ValueError: - return super().__format__(opt) - def format(self, opt: str, /) -> str: - match opt: - case 'b': - return self.to_base64() - case 'c': - return self.to_cb32() - case '0c': - return '0' + self.to_cb32() - case 'd' | '': - return int.__repr__(self) - case 'l': - return self.to_b32l() - case 'o' | 'x': - return int.__format__(self, opt) - case _: - raise ValueError(f'unknown format: {opt!r}') - - def __str__(self) -> str: - return int.__str__(self) - to_dec = deprecated('use str() instead')(__str__) - - def __repr__(self): - return f'{self.__class__.__name__}({super().__repr__()})' - - def to_siq(self, domain: str, epoch: int, target_type: SiqType, **kwargs): - """ - Convenience method for conversion to SIQ. - - (!) This does not check for existence! Always do the check yourself. - """ - return SnowflakeSiqMigrator(domain, epoch, **kwargs).to_siq(self, target_type) - - - -__all__ = ( - 'Snowflake', 'SnowflakeGen' -) \ No newline at end of file diff --git a/src/suou/sqlalchemy.py b/src/suou/sqlalchemy.py index 8ff32a4..3886d0d 100644 --- a/src/suou/sqlalchemy.py +++ b/src/suou/sqlalchemy.py @@ -18,17 +18,16 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod from functools import wraps -from typing import Callable, Iterable, Never, TypeVar +from typing import Any, Callable, Iterable, Never, TypeVar import warnings -from sqlalchemy import BigInteger, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text -from sqlalchemy.orm import DeclarativeBase, Session, declarative_base as _declarative_base, relationship +from sqlalchemy import CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, select, text +from sqlalchemy.orm import DeclarativeBase, Session, declarative_base as _declarative_base -from .snowflake import SnowflakeGen from .itertools import kwargs_prefix, makelist from .signing import HasSigner, UserSigner from .codecs import StringCase -from .functools import deprecated, not_implemented -from .iding import Siq, SiqGen, SiqType, SiqCache +from .functools import deprecated +from .iding import SiqType, SiqCache from .classtools import Incomplete, Wanted _T = TypeVar('_T') @@ -37,7 +36,7 @@ _T = TypeVar('_T') # Not to be confused with SiqType. IdType = LargeBinary(16) -@not_implemented + def sql_escape(s: str, /, dialect: Dialect) -> str: """ Escape a value for SQL embedding, using SQLAlchemy's literal processors. @@ -50,49 +49,20 @@ def sql_escape(s: str, /, dialect: Dialect) -> str: raise TypeError('invalid data type') -def create_session(url: str) -> Session: - """ - Create a session on the fly, given a database URL. Useful for - contextless environments, such as Python REPL. - - Heads up: a function with the same name exists in core sqlalchemy, but behaves - completely differently!! - """ - engine = create_engine(url) - return Session(bind = engine) - -def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs): +def id_column(typ: SiqType, *, primary_key: bool = True): """ Marks a column which contains a SIQ. """ def new_id_factory(owner: DeclarativeBase) -> Callable: domain_name = owner.metadata.info['domain_name'] - idgen = SiqCache(SiqGen(domain_name), typ) + idgen = SiqCache(domain_name, typ) def new_id() -> bytes: - return Siq(idgen.generate()).to_bytes() + return idgen.generate().to_bytes() return new_id if primary_key: - return Incomplete(Column, IdType, primary_key = True, default = Wanted(new_id_factory), **kwargs) + return Incomplete(Column, IdType, primary_key = True, default = Wanted(new_id_factory)) else: - return Incomplete(Column, IdType, unique = True, nullable = False, default = Wanted(new_id_factory), **kwargs) - -def snowflake_column(*, primary_key: bool = True, **kwargs): - """ - Same as id_column() but with snowflakes. - - XXX this is meant ONLY as means of transition; for new stuff, use id_column() and SIQ. - """ - def new_id_factory(owner: DeclarativeBase) -> Callable: - epoch = owner.metadata.info['snowflake_epoch'] - # more arguments will be passed on (?) - idgen = SnowflakeGen(epoch) - def new_id() -> int: - return idgen.generate_one() - return new_id - if primary_key: - return Incomplete(Column, BigInteger, primary_key = True, default = Wanted(new_id_factory), **kwargs) - else: - return Incomplete(Column, BigInteger, unique = True, nullable = False, default = Wanted(new_id_factory), **kwargs) + return Incomplete(Column, IdType, unique = True, nullable = False, default = Wanted(new_id_factory)) def match_constraint(col_name: str, regex: str, /, dialect: str = 'default', constraint_name: str | None = None) -> CheckConstraint: @@ -129,12 +99,9 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No metadata = dict() if 'info' not in metadata: metadata['info'] = dict() - # snowflake metadata - snowflake_kwargs = kwargs_prefix(kwargs, 'snowflake_', remove=True, keep_prefix=True) metadata['info'].update( domain_name = domain_name, - secret_key = master_secret, - **snowflake_kwargs + secret_key = master_secret ) Base = _declarative_base(metadata=MetaData(**metadata), **kwargs) return Base @@ -193,26 +160,6 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]: return (date_col, acc_col) -def parent_children(keyword: str, /, **kwargs): - """ - Self-referential one-to-many relationship pair. - Parent comes first, children come later. - - keyword is used in back_populates column names: convention over - configuration. Naming it otherwise will BREAK your models. - - Additional keyword arguments can be sourced with parent_ and child_ argument prefixes, - obviously. - """ - - parent_kwargs = kwargs_prefix(kwargs, 'parent_') - child_kwargs = kwargs_prefix(kwargs, 'child_') - - parent = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', **parent_kwargs) - child = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', **child_kwargs) - - return parent, child - def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]: """ Return a table's column given its name. @@ -253,7 +200,8 @@ class AuthSrc(metaclass=ABCMeta): def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', - required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): + required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None, + invalid_exc: Callable | None = None, required_exc: Callable | None = None): ''' Inject the current user into a view, given the Authorization: Bearer header. @@ -274,11 +222,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | except Exception: return None - def _default_invalid(msg: str = 'validation failed'): + def _default_invalid(msg: str): raise ValueError(msg) - invalid_exc = src.invalid_exc or _default_invalid - required_exc = src.required_exc or (lambda: _default_invalid()) + invalid_exc = invalid_exc or _default_invalid + required_exc = required_exc or (lambda: _default_invalid()) def decorator(func: Callable): @wraps(func)