Compare commits
10 commits
d6e54f192f
...
04ce86a43e
| Author | SHA1 | Date | |
|---|---|---|---|
| 04ce86a43e | |||
| 1c2bd11212 | |||
| bc4ea9b101 | |||
| 3d03cc00fa | |||
| e5d9c8e4a6 | |||
| 01d0464da2 | |||
| 121fbe83b0 | |||
| 1d6d5d72f8 | |||
| e615cbb628 | |||
| 946973f732 |
10 changed files with 400 additions and 45 deletions
19
CHANGELOG.md
19
CHANGELOG.md
|
|
@ -1,14 +1,31 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## 0.3.4
|
||||||
|
|
||||||
|
- Bug fixes in `.flask_sqlalchemy` and `.sqlalchemy` — `require_auth()` is unusable before this point!
|
||||||
|
|
||||||
|
## 0.3.3
|
||||||
|
|
||||||
|
- Fixed leftovers in `snowflake` module from unchecked code copying — i.e. `SnowflakeGen.generate_one()` used to require an unused typ= parameter
|
||||||
|
- Fixed a bug in `id_column()` that made it fail to provide a working generator — again, this won't be backported
|
||||||
|
|
||||||
|
## 0.3.2
|
||||||
|
|
||||||
|
- Fixed bugs in Snowflake generation and serialization of negative values
|
||||||
|
|
||||||
## 0.3.0
|
## 0.3.0
|
||||||
|
|
||||||
|
- Fixed `cb32encode()` and `b32lencode()` doing wrong padding — **UNSOLVED in 0.2.x** which is out of support, effective immediately
|
||||||
|
- **Changed behavior** of `kwargs_prefix()` which now removes keys from original mapping by default
|
||||||
- Add SQLAlchemy auth loaders i.e. `sqlalchemy.require_auth_base()`, `flask_sqlalchemy`.
|
- Add SQLAlchemy auth loaders i.e. `sqlalchemy.require_auth_base()`, `flask_sqlalchemy`.
|
||||||
What auth loaders do is loading user token and signature into app
|
What auth loaders do is loading user token and signature into app
|
||||||
|
- `sqlalchemy`: add `parent_children()` and `create_session()`
|
||||||
- Implement `UserSigner()`
|
- Implement `UserSigner()`
|
||||||
- Improve JSON handling in `flask_restx`
|
- Improve JSON handling in `flask_restx`
|
||||||
- Add base2048 (i.e. [BIP-39](https://github.com/bitcoin/bips/blob/master/bip-0039.mediawiki)) codec
|
- Add base2048 (i.e. [BIP-39](https://github.com/bitcoin/bips/blob/master/bip-0039.mediawiki)) codec
|
||||||
- Add `split_bits()`, `join_bits()`, `ltuple()`, `rtuple()`
|
- Add `split_bits()`, `join_bits()`, `ltuple()`, `rtuple()`, `ssv_list()`, `additem()`
|
||||||
- Add `markdown` extensions
|
- Add `markdown` extensions
|
||||||
|
- Add Snowflake manipulation utilities
|
||||||
|
|
||||||
## 0.2.3
|
## 0.2.3
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,20 +17,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .iding import Siq, SiqCache, SiqType, SiqGen
|
from .iding import Siq, SiqCache, SiqType, SiqGen
|
||||||
from .codecs import StringCase, cb32encode, cb32decode, jsonencode, want_bytes, want_str, b2048encode, b2048decode
|
from .codecs import (StringCase, cb32encode, cb32decode, b32lencode, b32ldecode, b64encode, b64decode, b2048encode, b2048decode,
|
||||||
|
jsonencode, want_bytes, want_str, ssv_list)
|
||||||
from .bits import count_ones, mask_shift, split_bits, join_bits
|
from .bits import count_ones, mask_shift, split_bits, join_bits
|
||||||
from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource
|
from .configparse import MissingConfigError, MissingConfigWarning, ConfigOptions, ConfigParserConfigSource, ConfigSource, DictConfigSource, ConfigValue, EnvConfigSource
|
||||||
from .functools import deprecated, not_implemented
|
from .functools import deprecated, not_implemented
|
||||||
from .classtools import Wanted, Incomplete
|
from .classtools import Wanted, Incomplete
|
||||||
from .itertools import makelist, kwargs_prefix, ltuple, rtuple
|
from .itertools import makelist, kwargs_prefix, ltuple, rtuple, additem
|
||||||
from .i18n import I18n, JsonI18n, TomlI18n
|
from .i18n import I18n, JsonI18n, TomlI18n
|
||||||
|
from .snowflake import Snowflake, SnowflakeGen
|
||||||
|
|
||||||
__version__ = "0.3.0-dev22"
|
__version__ = "0.3.4.rc1"
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase',
|
'Siq', 'SiqCache', 'SiqType', 'SiqGen', 'StringCase',
|
||||||
'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource',
|
'MissingConfigError', 'MissingConfigWarning', 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', 'EnvConfigSource', 'DictConfigSource',
|
||||||
'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', 'ltuple', 'rtuple',
|
'deprecated', 'not_implemented', 'Wanted', 'Incomplete', 'jsonencode', 'ltuple', 'rtuple',
|
||||||
'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift',
|
'makelist', 'kwargs_prefix', 'I18n', 'JsonI18n', 'TomlI18n', 'cb32encode', 'cb32decode', 'count_ones', 'mask_shift',
|
||||||
'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode'
|
'want_bytes', 'want_str', 'version', 'b2048encode', 'split_bits', 'join_bits', 'b2048decode',
|
||||||
|
'Snowflake', 'SnowflakeGen', 'ssv_list', 'additem', 'b32lencode', 'b32ldecode', 'b64encode', 'b64decode'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -162,7 +162,7 @@ def cb32decode(val: bytes | str) -> str:
|
||||||
'''
|
'''
|
||||||
Decode bytes from Crockford Base32.
|
Decode bytes from Crockford Base32.
|
||||||
'''
|
'''
|
||||||
return base64.b32decode(want_bytes(val).upper().translate(CROCKFORD_TO_B32) + b'=' * ((5 - len(val) % 5) % 5))
|
return base64.b32decode(want_bytes(val).upper().translate(CROCKFORD_TO_B32) + b'=' * ((8 - len(val) % 8) % 8))
|
||||||
|
|
||||||
def b32lencode(val: bytes) -> str:
|
def b32lencode(val: bytes) -> str:
|
||||||
'''
|
'''
|
||||||
|
|
@ -174,7 +174,7 @@ def b32ldecode(val: bytes | str) -> bytes:
|
||||||
'''
|
'''
|
||||||
Decode a lowercase base32 encoded byte sequence. Padding is managed automatically.
|
Decode a lowercase base32 encoded byte sequence. Padding is managed automatically.
|
||||||
'''
|
'''
|
||||||
return base64.b32decode(want_bytes(val).upper() + b'=' * ((5 - len(val) % 5) % 5))
|
return base64.b32decode(want_bytes(val).upper() + b'=' * ((8 - len(val) % 8) % 8))
|
||||||
|
|
||||||
def b64encode(val: bytes, *, strip: bool = True) -> str:
|
def b64encode(val: bytes, *, strip: bool = True) -> str:
|
||||||
'''
|
'''
|
||||||
|
|
@ -229,6 +229,35 @@ def jsonencode(obj: dict, *, skipkeys: bool = True, separators: tuple[str, str]
|
||||||
|
|
||||||
jsondecode = deprecated('just use json.loads()')(json.loads)
|
jsondecode = deprecated('just use json.loads()')(json.loads)
|
||||||
|
|
||||||
|
def ssv_list(s: str, *, sep_chars = ',;') -> list[str]:
|
||||||
|
"""
|
||||||
|
Parse values from a Space Separated Values (SSV) string.
|
||||||
|
|
||||||
|
By default, values are split on spaces, commas (,) and semicolons (;), configurable
|
||||||
|
with sepchars= argument.
|
||||||
|
|
||||||
|
Double quotes (") can be used to allow spaces, commas etc. in values. Doubled double
|
||||||
|
quotes ("") are parsed as literal double quotes.
|
||||||
|
|
||||||
|
Useful for environment variables: pass it to ConfigValue() as the cast= argument.
|
||||||
|
"""
|
||||||
|
sep_re = r'\s+|\s*[' + re.escape(sep_chars) + r']\s*'
|
||||||
|
parts = s.split('"')
|
||||||
|
parts[::2] = [re.split(sep_re, x) for x in parts[::2]]
|
||||||
|
l: list[str] = parts[0].copy()
|
||||||
|
for i in range(1, len(parts), 2):
|
||||||
|
p0, *pt = parts[i+1]
|
||||||
|
# two "strings" sandwiching each other case
|
||||||
|
if i < len(parts)-2 and parts[i] and parts[i+2] and not p0 and not pt:
|
||||||
|
p0 = '"'
|
||||||
|
l[-1] += ('"' if parts[i] == '' else parts[i]) + p0
|
||||||
|
l.extend(pt)
|
||||||
|
if l and l[0] == '':
|
||||||
|
l.pop(0)
|
||||||
|
if l and l[-1] == '':
|
||||||
|
l.pop()
|
||||||
|
return l
|
||||||
|
|
||||||
class StringCase(enum.Enum):
|
class StringCase(enum.Enum):
|
||||||
"""
|
"""
|
||||||
Enum values used by regex validators and storage converters.
|
Enum values used by regex validators and storage converters.
|
||||||
|
|
@ -237,7 +266,7 @@ class StringCase(enum.Enum):
|
||||||
LOWER = case insensitive, force lowercase
|
LOWER = case insensitive, force lowercase
|
||||||
UPPER = case insensitive, force uppercase
|
UPPER = case insensitive, force uppercase
|
||||||
IGNORE = case insensitive, leave as is, use lowercase in comparison
|
IGNORE = case insensitive, leave as is, use lowercase in comparison
|
||||||
IGNORE_UPPER = same as above, but use uppercase il comparison
|
IGNORE_UPPER = same as above, but use uppercase in comparison
|
||||||
"""
|
"""
|
||||||
AS_IS = 0
|
AS_IS = 0
|
||||||
LOWER = FORCE_LOWER = 1
|
LOWER = FORCE_LOWER = 1
|
||||||
|
|
@ -264,5 +293,5 @@ class StringCase(enum.Enum):
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'cb32encode', 'cb32decode', 'b32lencode', 'b32ldecode', 'b64encode', 'b64decode', 'jsonencode'
|
'cb32encode', 'cb32decode', 'b32lencode', 'b32ldecode', 'b64encode', 'b64decode', 'jsonencode'
|
||||||
'StringCase', 'want_bytes', 'want_str', 'jsondecode'
|
'StringCase', 'want_bytes', 'want_str', 'jsondecode', 'ssv_list'
|
||||||
)
|
)
|
||||||
|
|
@ -54,7 +54,10 @@ class Api(_Api):
|
||||||
Notably, all JSON is whitespace-free and .message is remapped to .error
|
Notably, all JSON is whitespace-free and .message is remapped to .error
|
||||||
"""
|
"""
|
||||||
def handle_error(self, e):
|
def handle_error(self, e):
|
||||||
|
### XXX apparently this handle_error does not get called AT ALL.
|
||||||
|
print(e)
|
||||||
res = super().handle_error(e)
|
res = super().handle_error(e)
|
||||||
|
print(res)
|
||||||
if isinstance(res, Mapping) and 'message' in res:
|
if isinstance(res, Mapping) and 'message' in res:
|
||||||
res['error'] = res['message']
|
res['error'] = res['message']
|
||||||
del res['message']
|
del res['message']
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from flask_sqlalchemy import SQLAlchemy
|
||||||
from sqlalchemy.orm import DeclarativeBase, Session
|
from sqlalchemy.orm import DeclarativeBase, Session
|
||||||
|
|
||||||
from .codecs import want_bytes
|
from .codecs import want_bytes
|
||||||
from .sqlalchemy import require_auth_base
|
from .sqlalchemy import AuthSrc, require_auth_base
|
||||||
|
|
||||||
class FlaskAuthSrc(AuthSrc):
|
class FlaskAuthSrc(AuthSrc):
|
||||||
'''
|
'''
|
||||||
|
|
@ -35,6 +35,7 @@ class FlaskAuthSrc(AuthSrc):
|
||||||
def get_session(self) -> Session:
|
def get_session(self) -> Session:
|
||||||
return self.db.session
|
return self.db.session
|
||||||
def get_token(self):
|
def get_token(self):
|
||||||
|
if request.authorization:
|
||||||
return request.authorization.token
|
return request.authorization.token
|
||||||
def get_signature(self) -> bytes:
|
def get_signature(self) -> bytes:
|
||||||
sig = request.headers.get('authorization-signature', None)
|
sig = request.headers.get('authorization-signature', None)
|
||||||
|
|
@ -42,15 +43,18 @@ class FlaskAuthSrc(AuthSrc):
|
||||||
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
||||||
abort(400, msg)
|
abort(400, msg)
|
||||||
def required_exc(self):
|
def required_exc(self):
|
||||||
abort(401)
|
abort(401, 'Login required')
|
||||||
|
|
||||||
def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Callable]:
|
def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable:
|
||||||
"""
|
"""
|
||||||
Make an auth_required() decorator for Flask views.
|
Make an auth_required() decorator for Flask views.
|
||||||
|
|
||||||
This looks for a token in the Authorization header, validates it, loads the
|
This looks for a token in the Authorization header, validates it, loads the
|
||||||
appropriate object, and injects it as the user= parameter.
|
appropriate object, and injects it as the user= parameter.
|
||||||
|
|
||||||
|
NOTE: the actual decorator to be used on routes is **auth_required()**,
|
||||||
|
NOT require_auth() which is the **constructor** for it.
|
||||||
|
|
||||||
cls is a SQLAlchemy table.
|
cls is a SQLAlchemy table.
|
||||||
db is a flask_sqlalchemy.SQLAlchemy() binding.
|
db is a flask_sqlalchemy.SQLAlchemy() binding.
|
||||||
|
|
||||||
|
|
@ -63,7 +67,12 @@ def require_auth(cls: type[DeclarativeBase], db: SQLAlchemy) -> Callable[Any, Ca
|
||||||
def super_secret_stuff(user):
|
def super_secret_stuff(user):
|
||||||
pass
|
pass
|
||||||
"""
|
"""
|
||||||
return partial(require_auth_base, cls=cls, src=FlaskAuthSrc(db))
|
def auth_required(**kwargs):
|
||||||
|
return require_auth_base(cls=cls, src=FlaskAuthSrc(db), **kwargs)
|
||||||
|
|
||||||
|
auth_required.__doc__ = require_auth_base.__doc__
|
||||||
|
|
||||||
|
return auth_required
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('require_auth', )
|
__all__ = ('require_auth', )
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ try:
|
||||||
from warnings import deprecated
|
from warnings import deprecated
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Python <=3.12 does not implement warnings.deprecated
|
# Python <=3.12 does not implement warnings.deprecated
|
||||||
def deprecated(message: str, /, *, category=DeprecationWarning):
|
def deprecated(message: str, /, *, category=DeprecationWarning, stacklevel:int=1):
|
||||||
"""
|
"""
|
||||||
Backport of PEP 702 for Python <=3.12.
|
Backport of PEP 702 for Python <=3.12.
|
||||||
The stack_level stuff is not reimplemented on purpose because
|
The stack_level stuff is not reimplemented on purpose because
|
||||||
|
|
@ -32,7 +32,7 @@ except ImportError:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*a, **ka):
|
def wrapper(*a, **ka):
|
||||||
if category is not None:
|
if category is not None:
|
||||||
warnings.warn(message, category)
|
warnings.warn(message, category, stacklevel=stacklevel)
|
||||||
return func(*a, **ka)
|
return func(*a, **ka)
|
||||||
func.__deprecated__ = True
|
func.__deprecated__ = True
|
||||||
wrapper.__deprecated__ = True
|
wrapper.__deprecated__ = True
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ import os
|
||||||
from typing import Iterable, override
|
from typing import Iterable, override
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .functools import not_implemented, deprecated
|
from .functools import deprecated
|
||||||
from .codecs import b32lencode, b64encode, cb32encode
|
from .codecs import b32lencode, b64encode, cb32encode
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -206,7 +206,9 @@ class SiqCache:
|
||||||
return self.generator.last_gen_ts
|
return self.generator.last_gen_ts
|
||||||
def cur_timestamp(self):
|
def cur_timestamp(self):
|
||||||
return self.generator.cur_timestamp()
|
return self.generator.cur_timestamp()
|
||||||
def __init__(self, generator: SiqGen, typ: SiqType, size: int = 64, max_age: int = 1024):
|
def __init__(self, generator: SiqGen | str, typ: SiqType, size: int = 64, max_age: int = 1024):
|
||||||
|
if isinstance(generator, str):
|
||||||
|
generator = SiqGen(generator)
|
||||||
self.generator = generator
|
self.generator = generator
|
||||||
self.typ = typ
|
self.typ = typ
|
||||||
self.size = size
|
self.size = size
|
||||||
|
|
@ -220,6 +222,9 @@ class SiqCache:
|
||||||
return self._cache.pop(0)
|
return self._cache.pop(0)
|
||||||
|
|
||||||
class Siq(int):
|
class Siq(int):
|
||||||
|
"""
|
||||||
|
Representation of a SIQ as an integer.
|
||||||
|
"""
|
||||||
def to_bytes(self, length: int = 14, byteorder = 'big', *, signed: bool = False) -> bytes:
|
def to_bytes(self, length: int = 14, byteorder = 'big', *, signed: bool = False) -> bytes:
|
||||||
return super().to_bytes(length, byteorder, signed=signed)
|
return super().to_bytes(length, byteorder, signed=signed)
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -237,9 +242,14 @@ class Siq(int):
|
||||||
return f'{self:x}'
|
return f'{self:x}'
|
||||||
def to_oct(self) -> str:
|
def to_oct(self) -> str:
|
||||||
return f'{self:o}'
|
return f'{self:o}'
|
||||||
@deprecated('use str() instead')
|
def to_b32l(self) -> str:
|
||||||
def to_dec(self) -> str:
|
"""
|
||||||
return f'{self}'
|
This is NOT the URI serializer!
|
||||||
|
"""
|
||||||
|
return b32lencode(self.to_bytes(15, 'big'))
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return int.__str__(self)
|
||||||
|
to_dec = deprecated('use str() instead')(__str__)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def __format__(self, opt: str, /) -> str:
|
def __format__(self, opt: str, /) -> str:
|
||||||
|
|
@ -256,7 +266,9 @@ class Siq(int):
|
||||||
case '0c':
|
case '0c':
|
||||||
return '0' + self.to_cb32()
|
return '0' + self.to_cb32()
|
||||||
case 'd' | '':
|
case 'd' | '':
|
||||||
return int.__str__(self)
|
return int.__repr__(self)
|
||||||
|
case 'l':
|
||||||
|
return self.to_b32l()
|
||||||
case 'o' | 'x':
|
case 'o' | 'x':
|
||||||
return int.__format__(self, opt)
|
return int.__format__(self, opt)
|
||||||
case 'u':
|
case 'u':
|
||||||
|
|
@ -287,6 +299,15 @@ class Siq(int):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({super().__repr__()})'
|
return f'{self.__class__.__name__}({super().__repr__()})'
|
||||||
|
|
||||||
|
# convenience methods
|
||||||
|
def timestamp(self):
|
||||||
|
return (self >> 56) / (1 << 16)
|
||||||
|
|
||||||
|
def shard_id(self):
|
||||||
|
return (self >> 48) % 256
|
||||||
|
|
||||||
|
def domain_name(self):
|
||||||
|
return (self >> 16) % 0xffffffff
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'Siq', 'SiqCache', 'SiqType', 'SiqGen'
|
'Siq', 'SiqCache', 'SiqType', 'SiqGen'
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,8 @@ This software is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
from typing import Any, Iterable, TypeVar
|
from typing import Any, Iterable, MutableMapping, TypeVar
|
||||||
|
import warnings
|
||||||
|
|
||||||
_T = TypeVar('_T')
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
|
@ -50,12 +51,38 @@ def rtuple(seq: Iterable[_T], size: int, /, pad = None) -> tuple:
|
||||||
return seq
|
return seq
|
||||||
|
|
||||||
|
|
||||||
def kwargs_prefix(it: dict[str, Any], prefix: str) -> dict[str, Any]:
|
def kwargs_prefix(it: dict[str, Any], prefix: str, *, remove = True, keep_prefix = False) -> dict[str, Any]:
|
||||||
'''
|
'''
|
||||||
Subset of keyword arguments. Useful for callable wrapping.
|
Subset of keyword arguments. Useful for callable wrapping.
|
||||||
|
|
||||||
|
By default, it removes arguments from original kwargs as well. You can prevent by
|
||||||
|
setting remove=False.
|
||||||
|
|
||||||
|
By default, specified prefix is removed from each key of the returned
|
||||||
|
dictionary; keep_prefix=True keeps the prefix on keys.
|
||||||
'''
|
'''
|
||||||
return {k.removeprefix(prefix): v for k, v in it.items() if k.startswith(prefix)}
|
keys = [k for k in it.keys() if k.startswith(prefix)]
|
||||||
|
|
||||||
|
ka = dict()
|
||||||
|
for k in keys:
|
||||||
|
ka[k if keep_prefix else k.removeprefix(prefix)] = it[k]
|
||||||
|
if remove:
|
||||||
|
for k in keys:
|
||||||
|
it.pop(k)
|
||||||
|
return ka
|
||||||
|
|
||||||
|
def additem(obj: MutableMapping, /, name: str = None):
|
||||||
|
"""
|
||||||
|
Syntax sugar for adding a function to a mapping, immediately.
|
||||||
|
"""
|
||||||
|
def decorator(func):
|
||||||
|
key = name or func.__name__
|
||||||
|
if key in obj:
|
||||||
|
warnings.warn(f'mapping does already have item {key!r}')
|
||||||
|
obj[key] = func
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple', 'additem')
|
||||||
|
|
||||||
__all__ = ('makelist', 'kwargs_prefix', 'ltuple', 'rtuple')
|
|
||||||
|
|
|
||||||
194
src/suou/snowflake.py
Normal file
194
src/suou/snowflake.py
Normal file
|
|
@ -0,0 +1,194 @@
|
||||||
|
"""
|
||||||
|
Utilities for Snowflake-like identifiers.
|
||||||
|
|
||||||
|
Here for applications who benefit from their use. I (sakuragasaki46)
|
||||||
|
recommend using SIQ (.iding) when applicable; there also utilities to
|
||||||
|
convert snowflakes into SIQ's in .migrate.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Copyright (c) 2025 Sakuragasaki46.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
See LICENSE for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
This software is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import os
|
||||||
|
from threading import Lock
|
||||||
|
import time
|
||||||
|
from typing import override
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from .migrate import SnowflakeSiqMigrator
|
||||||
|
from .iding import SiqType
|
||||||
|
from .codecs import b32ldecode, b32lencode, b64encode, cb32encode
|
||||||
|
from .functools import deprecated
|
||||||
|
|
||||||
|
|
||||||
|
class SnowflakeGen:
|
||||||
|
"""
|
||||||
|
Implements a generator Snowflake ID's (i.e. the ones in use at Twitter / Discord).
|
||||||
|
|
||||||
|
Discord snowflakes are in this format:
|
||||||
|
tttttttt tttttttt tttttttt tttttttt
|
||||||
|
tttttttt ttddddds sssspppp pppppppp
|
||||||
|
|
||||||
|
where:
|
||||||
|
t: timestamp (in milliseconds) — 42 bits
|
||||||
|
d: local ID — 5 bits
|
||||||
|
s: shard ID — 5 bits
|
||||||
|
p: progressive counter — 10 bits
|
||||||
|
|
||||||
|
Converter takes local ID and shard ID as one; latter 8 bits are taken for
|
||||||
|
the shard ID, while the former 2 are added to timestamp, taking advantage of
|
||||||
|
more precision — along with up to 2 most significant bits of progressive co
|
||||||
|
|
||||||
|
The constructor takes an epoch argument, since snowflakes, due to
|
||||||
|
optimization requirements, are based on a different epoch (e.g.
|
||||||
|
Jan 1, 2015 for Discord); epoch is wanted as seconds since Unix epoch
|
||||||
|
(i.e. midnight of Jan 1, 1970).
|
||||||
|
"""
|
||||||
|
epoch: int
|
||||||
|
local_id: int
|
||||||
|
shard_id: int
|
||||||
|
counter: int
|
||||||
|
last_gen_ts: int
|
||||||
|
|
||||||
|
TS_ACCURACY = 1000
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, epoch: int, local_id: int = 0, shard_id: int | None = None,
|
||||||
|
last_id: int = 0
|
||||||
|
):
|
||||||
|
self.epoch = epoch
|
||||||
|
self.local_id = local_id
|
||||||
|
self.shard_id = (shard_id or os.getpid()) % 32
|
||||||
|
self.counter = 0
|
||||||
|
self.last_gen_ts = min(last_id >> 22, self.cur_timestamp())
|
||||||
|
def cur_timestamp(self) -> int:
|
||||||
|
return int((time.time() - self.epoch) * self.TS_ACCURACY)
|
||||||
|
def generate(self, /, n: int = 1):
|
||||||
|
"""
|
||||||
|
Generate one or more snowflakes.
|
||||||
|
The generated ids are returned as integers.
|
||||||
|
Bulk generation is supported.
|
||||||
|
|
||||||
|
Returns as an iterator, to allow generation “on the fly”.
|
||||||
|
To get a scalar or a list, use .generate_one() or next(), or
|
||||||
|
.generate_list() or list(.generate()), respectively.
|
||||||
|
|
||||||
|
Warning: the function **may block**.
|
||||||
|
"""
|
||||||
|
now = self.cur_timestamp()
|
||||||
|
if now < self.last_gen_ts:
|
||||||
|
time.sleep((self.last_gen_ts - now) / (1 << 16))
|
||||||
|
elif now > self.last_gen_ts:
|
||||||
|
self.counter = 0
|
||||||
|
while n:
|
||||||
|
if self.counter >= 4096:
|
||||||
|
while (now := self.cur_timestamp()) <= self.last_gen_ts:
|
||||||
|
time.sleep(1 / (1 << 16))
|
||||||
|
with Lock():
|
||||||
|
self.counter %= 1 << 16
|
||||||
|
# XXX the lock is here "just in case", MULTITHREADED GENERATION IS NOT ADVISED!
|
||||||
|
with Lock():
|
||||||
|
siq = (
|
||||||
|
(now << 22) |
|
||||||
|
((self.local_id % 32) << 17) |
|
||||||
|
((self.shard_id % 32) << 12) |
|
||||||
|
(self.counter % (1 << 12))
|
||||||
|
)
|
||||||
|
n -= 1
|
||||||
|
self.counter += 1
|
||||||
|
yield siq
|
||||||
|
def generate_one(self, /) -> int:
|
||||||
|
return next(self.generate(1))
|
||||||
|
def generate_list(self, /, n: int = 1) -> list[int]:
|
||||||
|
return list(self.generate(n))
|
||||||
|
|
||||||
|
|
||||||
|
class Snowflake(int):
|
||||||
|
"""
|
||||||
|
Representation of a Snowflake as an integer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def to_bytes(self, length: int = 14, byteorder = "big", *, signed: bool = False) -> bytes:
|
||||||
|
return super().to_bytes(length, byteorder, signed=signed)
|
||||||
|
def to_base64(self, length: int = 9, *, strip: bool = True) -> str:
|
||||||
|
return b64encode(self.to_bytes(length), strip=strip)
|
||||||
|
def to_cb32(self)-> str:
|
||||||
|
return cb32encode(self.to_bytes(8, 'big'))
|
||||||
|
to_crockford = to_cb32
|
||||||
|
def to_hex(self) -> str:
|
||||||
|
return f'{self:x}'
|
||||||
|
def to_oct(self) -> str:
|
||||||
|
return f'{self:o}'
|
||||||
|
def to_b32l(self) -> str:
|
||||||
|
# PSA Snowflake Base32 representations are padded to 10 bytes!
|
||||||
|
if self < 0:
|
||||||
|
return '_' + Snowflake.to_b32l(-self)
|
||||||
|
return b32lencode(self.to_bytes(10, 'big')).lstrip('a')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, b: bytes, byteorder = 'big', *, signed: bool = False) -> Snowflake:
|
||||||
|
if len(b) not in (8, 10):
|
||||||
|
warnings.warn('Snowflakes are exactly 8 bytes long', BytesWarning)
|
||||||
|
return super().from_bytes(b, byteorder, signed=signed)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_b32l(cls, val: str) -> Snowflake:
|
||||||
|
if val.startswith('_'):
|
||||||
|
## support for negative Snowflakes
|
||||||
|
return -cls.from_b32l(val.lstrip('_'))
|
||||||
|
return Snowflake.from_bytes(b32ldecode(val.rjust(16, 'a')))
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __format__(self, opt: str, /) -> str:
|
||||||
|
try:
|
||||||
|
return self.format(opt)
|
||||||
|
except ValueError:
|
||||||
|
return super().__format__(opt)
|
||||||
|
def format(self, opt: str, /) -> str:
|
||||||
|
match opt:
|
||||||
|
case 'b':
|
||||||
|
return self.to_base64()
|
||||||
|
case 'c':
|
||||||
|
return self.to_cb32()
|
||||||
|
case '0c':
|
||||||
|
return '0' + self.to_cb32()
|
||||||
|
case 'd' | '':
|
||||||
|
return int.__repr__(self)
|
||||||
|
case 'l':
|
||||||
|
return self.to_b32l()
|
||||||
|
case 'o' | 'x':
|
||||||
|
return int.__format__(self, opt)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f'unknown format: {opt!r}')
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return int.__str__(self)
|
||||||
|
to_dec = deprecated('use str() instead')(__str__)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({super().__repr__()})'
|
||||||
|
|
||||||
|
def to_siq(self, domain: str, epoch: int, target_type: SiqType, **kwargs):
|
||||||
|
"""
|
||||||
|
Convenience method for conversion to SIQ.
|
||||||
|
|
||||||
|
(!) This does not check for existence! Always do the check yourself.
|
||||||
|
"""
|
||||||
|
return SnowflakeSiqMigrator(domain, epoch, **kwargs).to_siq(self, target_type)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
'Snowflake', 'SnowflakeGen'
|
||||||
|
)
|
||||||
|
|
@ -18,16 +18,17 @@ from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, Iterable, Never, TypeVar
|
from typing import Callable, Iterable, Never, TypeVar
|
||||||
import warnings
|
import warnings
|
||||||
from sqlalchemy import CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, select, text
|
from sqlalchemy import BigInteger, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
|
||||||
from sqlalchemy.orm import DeclarativeBase, Session, declarative_base as _declarative_base
|
from sqlalchemy.orm import DeclarativeBase, Session, declarative_base as _declarative_base, relationship
|
||||||
|
|
||||||
|
from .snowflake import SnowflakeGen
|
||||||
from .itertools import kwargs_prefix, makelist
|
from .itertools import kwargs_prefix, makelist
|
||||||
from .signing import HasSigner, UserSigner
|
from .signing import HasSigner, UserSigner
|
||||||
from .codecs import StringCase
|
from .codecs import StringCase
|
||||||
from .functools import deprecated
|
from .functools import deprecated, not_implemented
|
||||||
from .iding import SiqType, SiqCache
|
from .iding import Siq, SiqGen, SiqType, SiqCache
|
||||||
from .classtools import Incomplete, Wanted
|
from .classtools import Incomplete, Wanted
|
||||||
|
|
||||||
_T = TypeVar('_T')
|
_T = TypeVar('_T')
|
||||||
|
|
@ -36,7 +37,7 @@ _T = TypeVar('_T')
|
||||||
# Not to be confused with SiqType.
|
# Not to be confused with SiqType.
|
||||||
IdType = LargeBinary(16)
|
IdType = LargeBinary(16)
|
||||||
|
|
||||||
|
@not_implemented
|
||||||
def sql_escape(s: str, /, dialect: Dialect) -> str:
|
def sql_escape(s: str, /, dialect: Dialect) -> str:
|
||||||
"""
|
"""
|
||||||
Escape a value for SQL embedding, using SQLAlchemy's literal processors.
|
Escape a value for SQL embedding, using SQLAlchemy's literal processors.
|
||||||
|
|
@ -49,20 +50,49 @@ def sql_escape(s: str, /, dialect: Dialect) -> str:
|
||||||
raise TypeError('invalid data type')
|
raise TypeError('invalid data type')
|
||||||
|
|
||||||
|
|
||||||
def id_column(typ: SiqType, *, primary_key: bool = True):
|
def create_session(url: str) -> Session:
|
||||||
|
"""
|
||||||
|
Create a session on the fly, given a database URL. Useful for
|
||||||
|
contextless environments, such as Python REPL.
|
||||||
|
|
||||||
|
Heads up: a function with the same name exists in core sqlalchemy, but behaves
|
||||||
|
completely differently!!
|
||||||
|
"""
|
||||||
|
engine = create_engine(url)
|
||||||
|
return Session(bind = engine)
|
||||||
|
|
||||||
|
def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Marks a column which contains a SIQ.
|
Marks a column which contains a SIQ.
|
||||||
"""
|
"""
|
||||||
def new_id_factory(owner: DeclarativeBase) -> Callable:
|
def new_id_factory(owner: DeclarativeBase) -> Callable:
|
||||||
domain_name = owner.metadata.info['domain_name']
|
domain_name = owner.metadata.info['domain_name']
|
||||||
idgen = SiqCache(domain_name, typ)
|
idgen = SiqCache(SiqGen(domain_name), typ)
|
||||||
def new_id() -> bytes:
|
def new_id() -> bytes:
|
||||||
return idgen.generate().to_bytes()
|
return Siq(idgen.generate()).to_bytes()
|
||||||
return new_id
|
return new_id
|
||||||
if primary_key:
|
if primary_key:
|
||||||
return Incomplete(Column, IdType, primary_key = True, default = Wanted(new_id_factory))
|
return Incomplete(Column, IdType, primary_key = True, default = Wanted(new_id_factory), **kwargs)
|
||||||
else:
|
else:
|
||||||
return Incomplete(Column, IdType, unique = True, nullable = False, default = Wanted(new_id_factory))
|
return Incomplete(Column, IdType, unique = True, nullable = False, default = Wanted(new_id_factory), **kwargs)
|
||||||
|
|
||||||
|
def snowflake_column(*, primary_key: bool = True, **kwargs):
|
||||||
|
"""
|
||||||
|
Same as id_column() but with snowflakes.
|
||||||
|
|
||||||
|
XXX this is meant ONLY as means of transition; for new stuff, use id_column() and SIQ.
|
||||||
|
"""
|
||||||
|
def new_id_factory(owner: DeclarativeBase) -> Callable:
|
||||||
|
epoch = owner.metadata.info['snowflake_epoch']
|
||||||
|
# more arguments will be passed on (?)
|
||||||
|
idgen = SnowflakeGen(epoch)
|
||||||
|
def new_id() -> int:
|
||||||
|
return idgen.generate_one()
|
||||||
|
return new_id
|
||||||
|
if primary_key:
|
||||||
|
return Incomplete(Column, BigInteger, primary_key = True, default = Wanted(new_id_factory), **kwargs)
|
||||||
|
else:
|
||||||
|
return Incomplete(Column, BigInteger, unique = True, nullable = False, default = Wanted(new_id_factory), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def match_constraint(col_name: str, regex: str, /, dialect: str = 'default', constraint_name: str | None = None) -> CheckConstraint:
|
def match_constraint(col_name: str, regex: str, /, dialect: str = 'default', constraint_name: str | None = None) -> CheckConstraint:
|
||||||
|
|
@ -99,9 +129,12 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No
|
||||||
metadata = dict()
|
metadata = dict()
|
||||||
if 'info' not in metadata:
|
if 'info' not in metadata:
|
||||||
metadata['info'] = dict()
|
metadata['info'] = dict()
|
||||||
|
# snowflake metadata
|
||||||
|
snowflake_kwargs = kwargs_prefix(kwargs, 'snowflake_', remove=True, keep_prefix=True)
|
||||||
metadata['info'].update(
|
metadata['info'].update(
|
||||||
domain_name = domain_name,
|
domain_name = domain_name,
|
||||||
secret_key = master_secret
|
secret_key = master_secret,
|
||||||
|
**snowflake_kwargs
|
||||||
)
|
)
|
||||||
Base = _declarative_base(metadata=MetaData(**metadata), **kwargs)
|
Base = _declarative_base(metadata=MetaData(**metadata), **kwargs)
|
||||||
return Base
|
return Base
|
||||||
|
|
@ -160,6 +193,26 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]:
|
||||||
return (date_col, acc_col)
|
return (date_col, acc_col)
|
||||||
|
|
||||||
|
|
||||||
|
def parent_children(keyword: str, /, **kwargs):
|
||||||
|
"""
|
||||||
|
Self-referential one-to-many relationship pair.
|
||||||
|
Parent comes first, children come later.
|
||||||
|
|
||||||
|
keyword is used in back_populates column names: convention over
|
||||||
|
configuration. Naming it otherwise will BREAK your models.
|
||||||
|
|
||||||
|
Additional keyword arguments can be sourced with parent_ and child_ argument prefixes,
|
||||||
|
obviously.
|
||||||
|
"""
|
||||||
|
|
||||||
|
parent_kwargs = kwargs_prefix(kwargs, 'parent_')
|
||||||
|
child_kwargs = kwargs_prefix(kwargs, 'child_')
|
||||||
|
|
||||||
|
parent = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', **parent_kwargs)
|
||||||
|
child = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', **child_kwargs)
|
||||||
|
|
||||||
|
return parent, child
|
||||||
|
|
||||||
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
||||||
"""
|
"""
|
||||||
Return a table's column given its name.
|
Return a table's column given its name.
|
||||||
|
|
@ -200,8 +253,7 @@ class AuthSrc(metaclass=ABCMeta):
|
||||||
|
|
||||||
|
|
||||||
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
||||||
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None,
|
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
|
||||||
invalid_exc: Callable | None = None, required_exc: Callable | None = None):
|
|
||||||
'''
|
'''
|
||||||
Inject the current user into a view, given the Authorization: Bearer header.
|
Inject the current user into a view, given the Authorization: Bearer header.
|
||||||
|
|
||||||
|
|
@ -222,11 +274,11 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str |
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _default_invalid(msg: str):
|
def _default_invalid(msg: str = 'validation failed'):
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
invalid_exc = invalid_exc or _default_invalid
|
invalid_exc = src.invalid_exc or _default_invalid
|
||||||
required_exc = required_exc or (lambda: _default_invalid())
|
required_exc = src.required_exc or (lambda: _default_invalid())
|
||||||
|
|
||||||
def decorator(func: Callable):
|
def decorator(func: Callable):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue