diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ce0f75..4011856 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,31 @@ # Changelog +## 0.3.4 + +- Bug fixes in `.flask_sqlalchemy` and `.sqlalchemy` — `require_auth()` is unusable before this point! + +## 0.3.3 + +- Fixed leftovers in `snowflake` module from unchecked code copying — i.e. `SnowflakeGen.generate_one()` used to require an unused typ= parameter +- Fixed a bug in `id_column()` that made it fail to provide a working generator — again, this won't be backported + +## 0.3.2 + +- Fixed bugs in Snowflake generation and serialization of negative values + ## 0.3.0 +- Fixed `cb32encode()` and `b32lencode()` doing wrong padding — **UNSOLVED in 0.2.x** which is out of support, effective immediately +- **Changed behavior** of `kwargs_prefix()` which now removes keys from original mapping by default - Add SQLAlchemy auth loaders i.e. `sqlalchemy.require_auth_base()`, `flask_sqlalchemy`. What auth loaders do is loading user token and signature into app +- `sqlalchemy`: add `parent_children()` and `create_session()` - Implement `UserSigner()` - Improve JSON handling in `flask_restx` - Add base2048 (i.e. [BIP-39](https://github.com/bitcoin/bips/blob/master/bip-0039.mediawiki)) codec -- Add `split_bits()`, `join_bits()`, `ltuple()`, `rtuple()` +- Add `split_bits()`, `join_bits()`, `ltuple()`, `rtuple()`, `ssv_list()`, `additem()` - Add `markdown` extensions +- Add Snowflake manipulation utilities ## 0.2.3 diff --git a/src/suou/__init__.py b/src/suou/__init__.py index 8b87268..51253f2 100644 --- a/src/suou/__init__.py +++ b/src/suou/__init__.py @@ -17,20 +17,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ from .iding import Siq, SiqCache, SiqType, SiqGen -from .codecs import StringCase, cb32encode, cb32decode, jsonencode, want_bytes, want_str, b2048encode, b2048decode +from .codecs import (StringCase, cb32encode, cb32decode, b32lencode, b32ldecode, b64encode, b64decode, b2048encode, b2048decode, + jsonencode, want_bytes, want_str, ssv_list) from .bits import count_ones, mask_shift, split_bits, join_bits from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource from .functools import deprecated, not_implemented from .classtools import Wanted, Incomplete -from .itertools import makelist, kwargs_prefix, ltuple, rtuple +from .itertools import makelist, kwargs_prefix, ltuple, rtuple, additem from .i18n import I18n, JsonI18n, TomlI18n +from .snowflake import Snowflake, SnowflakeGen -__version__ = "0.3.0-dev22" +__version__ = "0.3.4.rc1" __all__ = ( 'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase', 'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource', 'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', 'ltuple', 'rtuple', 'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift', - 'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode' + 'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode', + 'Snowflake', 'SnowflakeGen', 'ssv_list', 'additem', 'b32lencode', 'b32ldecode', 'b64encode', 'b64decode' ) diff --git a/src/suou/codecs.py b/src/suou/codecs.py index 5f74867..2bee255 100644 --- a/src/suou/codecs.py +++ b/src/suou/codecs.py @@ -162,7 +162,7 @@ def cb32decode(val: bytes | str) -> str: ''' Decode bytes from Crockford Base32. ''' - return base64.b32decode(want_bytes(val).upper().translate(CROCKFORD_TO_B32) + b'=' * ((5 - len(val) % 5) % 5)) + return base64.b32decode(want_bytes(val).upper().translate(CROCKFORD_TO_B32) + b'=' * ((8 - len(val) % 8) % 8)) def b32lencode(val: bytes) -> str: ''' @@ -174,7 +174,7 @@ def b32ldecode(val: bytes | str) -> bytes: ''' Decode a lowercase base32 encoded byte sequence. Padding is managed automatically. ''' - return base64.b32decode(want_bytes(val).upper() + b'=' * ((5 - len(val) % 5) % 5)) + return base64.b32decode(want_bytes(val).upper() + b'=' * ((8 - len(val) % 8) % 8)) def b64encode(val: bytes, *, strip: bool = True) -> str: ''' @@ -229,6 +229,35 @@ def jsonencode(obj: dict, *, skipkeys: bool = True, separators: tuple[str, str] jsondecode = deprecated('just use json.loads()')(json.loads) +def ssv_list(s: str, *, sep_chars = ',;') -> list[str]: + """ + Parse values from a Space Separated Values (SSV) string. + + By default, values are split on spaces, commas (,) and semicolons (;), configurable + with sepchars= argument. + + Double quotes (") can be used to allow spaces, commas etc. in values. Doubled double + quotes ("") are parsed as literal double quotes. + + Useful for environment variables: pass it to ConfigValue() as the cast= argument. + """ + sep_re = r'\s+|\s*[' + re.escape(sep_chars) + r']\s*' + parts = s.split('"') + parts[::2] = [re.split(sep_re, x) for x in parts[::2]] + l: list[str] = parts[0].copy() + for i in range(1, len(parts), 2): + p0, *pt = parts[i+1] + # two "strings" sandwiching each other case + if i < len(parts)-2 and parts[i] and parts[i+2] and not p0 and not pt: + p0 = '"' + l[-1] += ('"' if parts[i] == '' else parts[i]) + p0 + l.extend(pt) + if l and l[0] == '': + l.pop(0) + if l and l[-1] == '': + l.pop() + return l + class StringCase(enum.Enum): """ Enum values used by regex validators and storage converters. @@ -237,7 +266,7 @@ class StringCase(enum.Enum): LOWER = case insensitive, force lowercase UPPER = case insensitive, force uppercase IGNORE = case insensitive, leave as is, use lowercase in comparison - IGNORE_UPPER = same as above, but use uppercase il comparison + IGNORE_UPPER = same as above, but use uppercase in comparison """ AS_IS = 0 LOWER = FORCE_LOWER = 1 @@ -264,5 +293,5 @@ class StringCase(enum.Enum): __all__ = ( 'cb32encode', 'cb32decode', 'b32lencode', 'b32ldecode', 'b64encode', 'b64decode', 'jsonencode' - 'StringCase', 'want_bytes', 'want_str', 'jsondecode' + 'StringCase', 'want_bytes', 'want_str', 'jsondecode', 'ssv_list' ) \ No newline at end of file diff --git a/src/suou/flask_restx.py b/src/suou/flask_restx.py index 9d4955a..ecdf3da 100644 --- a/src/suou/flask_restx.py +++ b/src/suou/flask_restx.py @@ -54,7 +54,10 @@ class Api(_Api): Notably, all JSON is whitespace-free and .message is remapped to .error """ def handle_error(self, e): + ### XXX apparently this handle_error does not get called AT ALL. + print(e) res = super().handle_error(e) + print(res) if isinstance(res, Mapping) and 'message' in res: res['error'] = res['message'] del res['message'] diff --git a/src/suou/flask_sqlalchemy.py b/src/suou/flask_sqlalchemy.py index 8ef0d12..508e296 100644 --- a/src/suou/flask_sqlalchemy.py +++ b/src/suou/flask_sqlalchemy.py @@ -22,7 +22,7 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import DeclarativeBase, Session from .codecs import want_bytes -from .sqlalchemy import require_auth_base +from .sqlalchemy import AuthSrc, require_auth_base class FlaskAuthSrc(AuthSrc): ''' @@ -35,22 +35,26 @@ class FlaskAuthSrc(AuthSrc): def get_session(self) -> Session: return self.db.session def get_token(self): - return request.authorization.token + if request.authorization: + return request.authorization.token def get_signature(self) -> bytes: sig = request.headers.get('authorization-signature', None) return want_bytes(sig) if sig else None def invalid_exc(self, msg: str = 'validation failed') -> Never: abort(400, msg) def required_exc(self): - abort(401) + abort(401, 'Login required') -def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Callable]: +def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable: """ Make an auth_required() decorator for Flask views. This looks for a token in the Authorization header, validates it, loads the appropriate object, and injects it as the user= parameter. + NOTE: the actual decorator to be used on routes is **auth_required()**, + NOT require_auth() which is the **constructor** for it. + cls is a SQLAlchemy table. db is a flask_sqlalchemy.SQLAlchemy() binding. @@ -63,7 +67,12 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Ca def super_secret_stuff(user): pass """ - return partial(require_auth_base, cls=cls, src=FlaskAuthSrc(db)) + def auth_required(**kwargs): + return require_auth_base(cls=cls, src=FlaskAuthSrc(db), **kwargs) + + auth_required.__doc__ = require_auth_base.__doc__ + + return auth_required __all__ = ('require_auth', ) diff --git a/src/suou/functools.py b/src/suou/functools.py index eb56d2e..9041f92 100644 --- a/src/suou/functools.py +++ b/src/suou/functools.py @@ -22,7 +22,7 @@ try: from warnings import deprecated except ImportError: # Python <=3.12 does not implement warnings.deprecated - def deprecated(message: str, /, *, category=DeprecationWarning): + def deprecated(message: str, /, *, category=DeprecationWarning, stacklevel:int=1): """ Backport of PEP 702 for Python <=3.12. The stack_level stuff is not reimplemented on purpose because @@ -32,7 +32,7 @@ except ImportError: @wraps(func) def wrapper(*a, **ka): if category is not None: - warnings.warn(message, category) + warnings.warn(message, category, stacklevel=stacklevel) return func(*a, **ka) func.__deprecated__ = True wrapper.__deprecated__ = True diff --git a/src/suou/iding.py b/src/suou/iding.py index dba591c..6cfcd5e 100644 --- a/src/suou/iding.py +++ b/src/suou/iding.py @@ -40,7 +40,7 @@ import os from typing import Iterable, override import warnings -from .functools import not_implemented, deprecated +from .functools import deprecated from .codecs import b32lencode, b64encode, cb32encode @@ -206,7 +206,9 @@ class SiqCache: return self.generator.last_gen_ts def cur_timestamp(self): return self.generator.cur_timestamp() - def __init__(self, generator: SiqGen, typ: SiqType, size: int = 64, max_age: int = 1024): + def __init__(self, generator: SiqGen | str, typ: SiqType, size: int = 64, max_age: int = 1024): + if isinstance(generator, str): + generator = SiqGen(generator) self.generator = generator self.typ = typ self.size = size @@ -220,6 +222,9 @@ class SiqCache: return self._cache.pop(0) class Siq(int): + """ + Representation of a SIQ as an integer. + """ def to_bytes(self, length: int = 14, byteorder = 'big', *, signed: bool = False) -> bytes: return super().to_bytes(length, byteorder, signed=signed) @classmethod @@ -230,17 +235,22 @@ class Siq(int): def to_base64(self, length: int = 15, *, strip: bool = True) -> str: return b64encode(self.to_bytes(length), strip=strip) - def to_cb32(self)-> str: + def to_cb32(self) -> str: return cb32encode(self.to_bytes(15, 'big')) to_crockford = to_cb32 def to_hex(self) -> str: return f'{self:x}' def to_oct(self) -> str: return f'{self:o}' - @deprecated('use str() instead') - def to_dec(self) -> str: - return f'{self}' - + def to_b32l(self) -> str: + """ + This is NOT the URI serializer! + """ + return b32lencode(self.to_bytes(15, 'big')) + def __str__(self) -> str: + return int.__str__(self) + to_dec = deprecated('use str() instead')(__str__) + @override def __format__(self, opt: str, /) -> str: try: @@ -256,7 +266,9 @@ class Siq(int): case '0c': return '0' + self.to_cb32() case 'd' | '': - return int.__str__(self) + return int.__repr__(self) + case 'l': + return self.to_b32l() case 'o' | 'x': return int.__format__(self, opt) case 'u': @@ -287,6 +299,15 @@ class Siq(int): def __repr__(self): return f'{self.__class__.__name__}({super().__repr__()})' + # convenience methods + def timestamp(self): + return (self >> 56) / (1 << 16) + + def shard_id(self): + return (self >> 48) % 256 + + def domain_name(self): + return (self >> 16) % 0xffffffff __all__ = ( 'Siq', 'SiqCache', 'SiqType', 'SiqGen' diff --git a/src/suou/itertools.py b/src/suou/itertools.py index dad51f4..9f80faa 100644 --- a/src/suou/itertools.py +++ b/src/suou/itertools.py @@ -14,7 +14,8 @@ This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ''' -from typing import Any, Iterable, TypeVar +from typing import Any, Iterable, MutableMapping, TypeVar +import warnings _T = TypeVar('_T') @@ -50,12 +51,38 @@ def rtuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple: return seq -def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]: +def kwargs_prefix(it: dict[str, Any], prefix: str, *, remove = True, keep_prefix = False) -> dict[str, Any]: ''' Subset of keyword arguments. Useful for callable wrapping. + + By default, it removes arguments from original kwargs as well. You can prevent by + setting remove=False. + + By default, specified prefix is removed from each key of the returned + dictionary; keep_prefix=True keeps the prefix on keys. ''' - return {k.removeprefix(prefix): v for k, v in it.items() if k.startswith(prefix)} + keys = [k for k in it.keys() if k.startswith(prefix)] + + ka = dict() + for k in keys: + ka[k if keep_prefix else k.removeprefix(prefix)] = it[k] + if remove: + for k in keys: + it.pop(k) + return ka + +def additem(obj: MutableMapping, /, name: str = None): + """ + Syntax sugar for adding a function to a mapping, immediately. + """ + def decorator(func): + key = name or func.__name__ + if key in obj: + warnings.warn(f'mapping does already have item {key!r}') + obj[key] = func + return func + return decorator +__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple', 'additem') -__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple') diff --git a/src/suou/snowflake.py b/src/suou/snowflake.py new file mode 100644 index 0000000..3f9190e --- /dev/null +++ b/src/suou/snowflake.py @@ -0,0 +1,194 @@ +""" +Utilities for Snowflake-like identifiers. + +Here for applications who benefit from their use. I (sakuragasaki46) +recommend using SIQ (.iding) when applicable; there also utilities to +convert snowflakes into SIQ's in .migrate. + +--- + +Copyright (c) 2025 Sakuragasaki46. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +See LICENSE for the specific language governing permissions and +limitations under the License. + +This software is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +""" + + +from __future__ import annotations +import os +from threading import Lock +import time +from typing import override +import warnings + +from .migrate import SnowflakeSiqMigrator +from .iding import SiqType +from .codecs import b32ldecode, b32lencode, b64encode, cb32encode +from .functools import deprecated + + +class SnowflakeGen: + """ + Implements a generator Snowflake ID's (i.e. the ones in use at Twitter / Discord). + + Discord snowflakes are in this format: + tttttttt tttttttt tttttttt tttttttt + tttttttt ttddddds sssspppp pppppppp + + where: + t: timestamp (in milliseconds) — 42 bits + d: local ID — 5 bits + s: shard ID — 5 bits + p: progressive counter — 10 bits + + Converter takes local ID and shard ID as one; latter 8 bits are taken for + the shard ID, while the former 2 are added to timestamp, taking advantage of + more precision — along with up to 2 most significant bits of progressive co + + The constructor takes an epoch argument, since snowflakes, due to + optimization requirements, are based on a different epoch (e.g. + Jan 1, 2015 for Discord); epoch is wanted as seconds since Unix epoch + (i.e. midnight of Jan 1, 1970). + """ + epoch: int + local_id: int + shard_id: int + counter: int + last_gen_ts: int + + TS_ACCURACY = 1000 + + + def __init__(self, epoch: int, local_id: int = 0, shard_id: int | None = None, + last_id: int = 0 + ): + self.epoch = epoch + self.local_id = local_id + self.shard_id = (shard_id or os.getpid()) % 32 + self.counter = 0 + self.last_gen_ts = min(last_id >> 22, self.cur_timestamp()) + def cur_timestamp(self) -> int: + return int((time.time() - self.epoch) * self.TS_ACCURACY) + def generate(self, /, n: int = 1): + """ + Generate one or more snowflakes. + The generated ids are returned as integers. + Bulk generation is supported. + + Returns as an iterator, to allow generation “on the fly”. + To get a scalar or a list, use .generate_one() or next(), or + .generate_list() or list(.generate()), respectively. + + Warning: the function **may block**. + """ + now = self.cur_timestamp() + if now < self.last_gen_ts: + time.sleep((self.last_gen_ts - now) / (1 << 16)) + elif now > self.last_gen_ts: + self.counter = 0 + while n: + if self.counter >= 4096: + while (now := self.cur_timestamp()) <= self.last_gen_ts: + time.sleep(1 / (1 << 16)) + with Lock(): + self.counter %= 1 << 16 + # XXX the lock is here "just in case", MULTITHREADED GENERATION IS NOT ADVISED! + with Lock(): + siq = ( + (now << 22) | + ((self.local_id % 32) << 17) | + ((self.shard_id % 32) << 12) | + (self.counter % (1 << 12)) + ) + n -= 1 + self.counter += 1 + yield siq + def generate_one(self, /) -> int: + return next(self.generate(1)) + def generate_list(self, /, n: int = 1) -> list[int]: + return list(self.generate(n)) + + +class Snowflake(int): + """ + Representation of a Snowflake as an integer. + """ + + def to_bytes(self, length: int = 14, byteorder = "big", *, signed: bool = False) -> bytes: + return super().to_bytes(length, byteorder, signed=signed) + def to_base64(self, length: int = 9, *, strip: bool = True) -> str: + return b64encode(self.to_bytes(length), strip=strip) + def to_cb32(self)-> str: + return cb32encode(self.to_bytes(8, 'big')) + to_crockford = to_cb32 + def to_hex(self) -> str: + return f'{self:x}' + def to_oct(self) -> str: + return f'{self:o}' + def to_b32l(self) -> str: + # PSA Snowflake Base32 representations are padded to 10 bytes! + if self < 0: + return '_' + Snowflake.to_b32l(-self) + return b32lencode(self.to_bytes(10, 'big')).lstrip('a') + + @classmethod + def from_bytes(cls, b: bytes, byteorder = 'big', *, signed: bool = False) -> Snowflake: + if len(b) not in (8, 10): + warnings.warn('Snowflakes are exactly 8 bytes long', BytesWarning) + return super().from_bytes(b, byteorder, signed=signed) + + @classmethod + def from_b32l(cls, val: str) -> Snowflake: + if val.startswith('_'): + ## support for negative Snowflakes + return -cls.from_b32l(val.lstrip('_')) + return Snowflake.from_bytes(b32ldecode(val.rjust(16, 'a'))) + + @override + def __format__(self, opt: str, /) -> str: + try: + return self.format(opt) + except ValueError: + return super().__format__(opt) + def format(self, opt: str, /) -> str: + match opt: + case 'b': + return self.to_base64() + case 'c': + return self.to_cb32() + case '0c': + return '0' + self.to_cb32() + case 'd' | '': + return int.__repr__(self) + case 'l': + return self.to_b32l() + case 'o' | 'x': + return int.__format__(self, opt) + case _: + raise ValueError(f'unknown format: {opt!r}') + + def __str__(self) -> str: + return int.__str__(self) + to_dec = deprecated('use str() instead')(__str__) + + def __repr__(self): + return f'{self.__class__.__name__}({super().__repr__()})' + + def to_siq(self, domain: str, epoch: int, target_type: SiqType, **kwargs): + """ + Convenience method for conversion to SIQ. + + (!) This does not check for existence! Always do the check yourself. + """ + return SnowflakeSiqMigrator(domain, epoch, **kwargs).to_siq(self, target_type) + + + +__all__ = ( + 'Snowflake', 'SnowflakeGen' +) \ No newline at end of file diff --git a/src/suou/sqlalchemy.py b/src/suou/sqlalchemy.py index 3886d0d..8ff32a4 100644 --- a/src/suou/sqlalchemy.py +++ b/src/suou/sqlalchemy.py @@ -18,16 +18,17 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod from functools import wraps -from typing import Any, Callable, Iterable, Never, TypeVar +from typing import Callable, Iterable, Never, TypeVar import warnings -from sqlalchemy import CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, select, text -from sqlalchemy.orm import DeclarativeBase, Session, declarative_base as _declarative_base +from sqlalchemy import BigInteger, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text +from sqlalchemy.orm import DeclarativeBase, Session, declarative_base as _declarative_base, relationship +from .snowflake import SnowflakeGen from .itertools import kwargs_prefix, makelist from .signing import HasSigner, UserSigner from .codecs import StringCase -from .functools import deprecated -from .iding import SiqType, SiqCache +from .functools import deprecated, not_implemented +from .iding import Siq, SiqGen, SiqType, SiqCache from .classtools import Incomplete, Wanted _T = TypeVar('_T') @@ -36,7 +37,7 @@ _T = TypeVar('_T') # Not to be confused with SiqType. IdType = LargeBinary(16) - +@not_implemented def sql_escape(s: str, /, dialect: Dialect) -> str: """ Escape a value for SQL embedding, using SQLAlchemy's literal processors. @@ -49,20 +50,49 @@ def sql_escape(s: str, /, dialect: Dialect) -> str: raise TypeError('invalid data type') -def id_column(typ: SiqType, *, primary_key: bool = True): +def create_session(url: str) -> Session: + """ + Create a session on the fly, given a database URL. Useful for + contextless environments, such as Python REPL. + + Heads up: a function with the same name exists in core sqlalchemy, but behaves + completely differently!! + """ + engine = create_engine(url) + return Session(bind = engine) + +def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs): """ Marks a column which contains a SIQ. """ def new_id_factory(owner: DeclarativeBase) -> Callable: domain_name = owner.metadata.info['domain_name'] - idgen = SiqCache(domain_name, typ) + idgen = SiqCache(SiqGen(domain_name), typ) def new_id() -> bytes: - return idgen.generate().to_bytes() + return Siq(idgen.generate()).to_bytes() return new_id if primary_key: - return Incomplete(Column, IdType, primary_key = True, default = Wanted(new_id_factory)) + return Incomplete(Column, IdType, primary_key = True, default = Wanted(new_id_factory), **kwargs) else: - return Incomplete(Column, IdType, unique = True, nullable = False, default = Wanted(new_id_factory)) + return Incomplete(Column, IdType, unique = True, nullable = False, default = Wanted(new_id_factory), **kwargs) + +def snowflake_column(*, primary_key: bool = True, **kwargs): + """ + Same as id_column() but with snowflakes. + + XXX this is meant ONLY as means of transition; for new stuff, use id_column() and SIQ. + """ + def new_id_factory(owner: DeclarativeBase) -> Callable: + epoch = owner.metadata.info['snowflake_epoch'] + # more arguments will be passed on (?) + idgen = SnowflakeGen(epoch) + def new_id() -> int: + return idgen.generate_one() + return new_id + if primary_key: + return Incomplete(Column, BigInteger, primary_key = True, default = Wanted(new_id_factory), **kwargs) + else: + return Incomplete(Column, BigInteger, unique = True, nullable = False, default = Wanted(new_id_factory), **kwargs) def match_constraint(col_name: str, regex: str, /, dialect: str = 'default', constraint_name: str | None = None) -> CheckConstraint: @@ -99,9 +129,12 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No metadata = dict() if 'info' not in metadata: metadata['info'] = dict() + # snowflake metadata + snowflake_kwargs = kwargs_prefix(kwargs, 'snowflake_', remove=True, keep_prefix=True) metadata['info'].update( domain_name = domain_name, - secret_key = master_secret + secret_key = master_secret, + **snowflake_kwargs ) Base = _declarative_base(metadata=MetaData(**metadata), **kwargs) return Base @@ -160,6 +193,26 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]: return (date_col, acc_col) +def parent_children(keyword: str, /, **kwargs): + """ + Self-referential one-to-many relationship pair. + Parent comes first, children come later. + + keyword is used in back_populates column names: convention over + configuration. Naming it otherwise will BREAK your models. + + Additional keyword arguments can be sourced with parent_ and child_ argument prefixes, + obviously. + """ + + parent_kwargs = kwargs_prefix(kwargs, 'parent_') + child_kwargs = kwargs_prefix(kwargs, 'child_') + + parent = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', **parent_kwargs) + child = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', **child_kwargs) + + return parent, child + def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]: """ Return a table's column given its name. @@ -200,8 +253,7 @@ class AuthSrc(metaclass=ABCMeta): def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', - required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None, - invalid_exc: Callable | None = None, required_exc: Callable | None = None): + required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): ''' Inject the current user into a view, given the Authorization: Bearer header. @@ -222,11 +274,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | except Exception: return None - def _default_invalid(msg: str): + def _default_invalid(msg: str = 'validation failed'): raise ValueError(msg) - invalid_exc = invalid_exc or _default_invalid - required_exc = required_exc or (lambda: _default_invalid()) + invalid_exc = src.invalid_exc or _default_invalid + required_exc = src.required_exc or (lambda: _default_invalid()) def decorator(func: Callable): @wraps(func)