Compare commits
No commits in common. "97194b2b8522ee243050466d48d7f0db37a8210a" and "a127c8815930c7f2e92b047a80c50b43fc63144e" have entirely different histories.
97194b2b85
...
a127c88159
7 changed files with 315 additions and 404 deletions
|
|
@ -1,13 +1,5 @@
|
|||
# Changelog
|
||||
|
||||
## 0.6.0
|
||||
|
||||
...
|
||||
|
||||
## 0.5.3
|
||||
|
||||
...
|
||||
|
||||
## 0.5.2
|
||||
|
||||
- Fixed poorly handled merge conflict leaving `.sqlalchemy` modulem unusable
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from .validators import matches
|
|||
from .redact import redact_url_password
|
||||
from .http import WantsContentType
|
||||
|
||||
__version__ = "0.6.0-dev35"
|
||||
__version__ = "0.5.2"
|
||||
|
||||
__all__ = (
|
||||
'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue',
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from types import EllipsisType
|
||||
from typing import Any, Callable, Generic, Iterable, Mapping, TypeVar
|
||||
import logging
|
||||
|
||||
|
|
@ -25,10 +24,7 @@ _T = TypeVar('_T')
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MissingType(object):
|
||||
__slots__ = ()
|
||||
|
||||
MISSING = MissingType()
|
||||
MISSING = object()
|
||||
|
||||
def _not_missing(v) -> bool:
|
||||
return v and v is not MISSING
|
||||
|
|
@ -47,10 +43,10 @@ class Wanted(Generic[_T]):
|
|||
Owner class will call .__set_name__() on the parent Incomplete instance;
|
||||
the __set_name__ parameters (owner class and name) will be passed down here.
|
||||
"""
|
||||
_target: Callable | str | None | EllipsisType
|
||||
def __init__(self, getter: Callable | str | None | EllipsisType):
|
||||
_target: Callable | str | None | Ellipsis
|
||||
def __init__(self, getter: Callable | str | None | Ellipsis):
|
||||
self._target = getter
|
||||
def __call__(self, owner: type, name: str | None = None) -> _T | str | None:
|
||||
def __call__(self, owner: type, name: str | None = None) -> _T:
|
||||
if self._target is None or self._target is Ellipsis:
|
||||
return name
|
||||
elif isinstance(self._target, str):
|
||||
|
|
@ -71,10 +67,10 @@ class Incomplete(Generic[_T]):
|
|||
Missing arguments must be passed in the appropriate positions
|
||||
(positional or keyword) as a Wanted() object.
|
||||
"""
|
||||
_obj: Callable[..., _T]
|
||||
_obj = Callable[Any, _T]
|
||||
_args: Iterable
|
||||
_kwargs: dict
|
||||
def __init__(self, obj: Callable[..., _T] | Wanted, *args, **kwargs):
|
||||
def __init__(self, obj: Callable[Any, _T] | Wanted, *args, **kwargs):
|
||||
if isinstance(obj, Wanted):
|
||||
self._obj = lambda x: x
|
||||
self._args = (obj, )
|
||||
|
|
@ -124,7 +120,7 @@ class ValueSource(Mapping):
|
|||
class ValueProperty(Generic[_T]):
|
||||
_name: str | None
|
||||
_srcs: dict[str, str]
|
||||
_val: Any | MissingType
|
||||
_val: Any | MISSING
|
||||
_default: Any | None
|
||||
_cast: Callable | None
|
||||
_required: bool
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
"""
|
||||
Utilities for SQLAlchemy; ORM
|
||||
|
||||
NEW 0.6.0
|
||||
Utilities for SQLAlchemy
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -16,38 +14,53 @@ This software is distributed on an "AS IS" BASIS,
|
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from binascii import Incomplete
|
||||
from typing import Any, Callable
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from functools import wraps
|
||||
from typing import Callable, Iterable, Never, TypeVar
|
||||
import warnings
|
||||
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Column, Date, ForeignKey, LargeBinary, MetaData, SmallInteger, String, text
|
||||
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, declarative_base as _declarative_base, relationship
|
||||
from sqlalchemy.types import TypeEngine
|
||||
from suou.classtools import Wanted
|
||||
from suou.codecs import StringCase
|
||||
from suou.iding import Siq, SiqCache, SiqGen, SiqType
|
||||
from suou.itertools import kwargs_prefix
|
||||
from suou.snowflake import SnowflakeGen
|
||||
from suou.sqlalchemy import IdType
|
||||
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
|
||||
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
|
||||
|
||||
from .snowflake import SnowflakeGen
|
||||
from .itertools import kwargs_prefix, makelist
|
||||
from .signing import HasSigner, UserSigner
|
||||
from .codecs import StringCase
|
||||
from .functools import deprecated, not_implemented
|
||||
from .iding import Siq, SiqGen, SiqType, SiqCache
|
||||
from .classtools import Incomplete, Wanted
|
||||
|
||||
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# SIQs are 14 bytes long. Storage is padded for alignment
|
||||
# Not to be confused with SiqType.
|
||||
IdType: type[LargeBinary] = LargeBinary(16)
|
||||
|
||||
@not_implemented
|
||||
def sql_escape(s: str, /, dialect: Dialect) -> str:
|
||||
"""
|
||||
Return a table's column given its name.
|
||||
Escape a value for SQL embedding, using SQLAlchemy's literal processors.
|
||||
Requires a dialect argument.
|
||||
|
||||
XXX does it belong outside any scopes?
|
||||
XXX this function is not mature yet, do not use
|
||||
"""
|
||||
if isinstance(col, Incomplete):
|
||||
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
|
||||
elif isinstance(col, Column):
|
||||
return col
|
||||
elif isinstance(col, str):
|
||||
return getattr(cls, col)
|
||||
else:
|
||||
raise TypeError
|
||||
if isinstance(s, str):
|
||||
return String().literal_processor(dialect=dialect)(s)
|
||||
raise TypeError('invalid data type')
|
||||
|
||||
|
||||
def create_session(url: str) -> Session:
|
||||
"""
|
||||
Create a session on the fly, given a database URL. Useful for
|
||||
contextless environments, such as Python REPL.
|
||||
|
||||
Heads up: a function with the same name exists in core sqlalchemy, but behaves
|
||||
completely differently!!
|
||||
"""
|
||||
engine = create_engine(url)
|
||||
return Session(bind = engine)
|
||||
|
||||
def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs):
|
||||
"""
|
||||
Marks a column which contains a SIQ.
|
||||
|
|
@ -106,6 +119,7 @@ def match_column(length: int, regex: str, /, case: StringCase = StringCase.AS_IS
|
|||
return Incomplete(Column, String(length), Wanted(lambda x, n: match_constraint(n, regex, #dialect=x.metadata.engine.dialect.name,
|
||||
constraint_name=constraint_name or f'{x.__tablename__}_{n}_valid')), *args, **kwargs)
|
||||
|
||||
|
||||
def bool_column(value: bool = False, nullable: bool = False, **kwargs) -> Column[bool]:
|
||||
"""
|
||||
Column for a single boolean value.
|
||||
|
|
@ -134,11 +148,33 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No
|
|||
)
|
||||
Base = _declarative_base(metadata=MetaData(**metadata), **kwargs)
|
||||
return Base
|
||||
entity_base = warnings.deprecated('use declarative_base() instead')(declarative_base)
|
||||
entity_base = deprecated('use declarative_base() instead')(declarative_base)
|
||||
|
||||
|
||||
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
|
||||
"""
|
||||
Generate a user signing function.
|
||||
|
||||
def author_pair(fk_name: str, *, id_type: type | TypeEngine = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]:
|
||||
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
|
||||
and a user secret.
|
||||
"""
|
||||
if isinstance(id_attr, Column):
|
||||
id_val = id_attr
|
||||
elif isinstance(id_attr, str):
|
||||
id_val = Wanted(id_attr)
|
||||
if isinstance(secret_attr, Column):
|
||||
secret_val = secret_attr
|
||||
elif isinstance(secret_attr, str):
|
||||
secret_val = Wanted(secret_attr)
|
||||
def token_signer_factory(owner: DeclarativeBase, name: str):
|
||||
def my_signer(self):
|
||||
return UserSigner(owner.metadata.info['secret_key'], id_val.__get__(self, owner), secret_val.__get__(self, owner))
|
||||
my_signer.__name__ = name
|
||||
return my_signer
|
||||
return Incomplete(Wanted(token_signer_factory))
|
||||
|
||||
|
||||
def author_pair(fk_name: str, *, id_type: type = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]:
|
||||
"""
|
||||
Return an owner ID/signature column pair, for authenticated values.
|
||||
"""
|
||||
|
|
@ -167,8 +203,7 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]:
|
|||
return (date_col, acc_col)
|
||||
|
||||
|
||||
|
||||
def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]:
|
||||
def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship], Incomplete[Relationship]]:
|
||||
"""
|
||||
Self-referential one-to-many relationship pair.
|
||||
Parent comes first, children come later.
|
||||
|
|
@ -185,8 +220,8 @@ def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Inco
|
|||
parent_kwargs = kwargs_prefix(kwargs, 'parent_')
|
||||
child_kwargs = kwargs_prefix(kwargs, 'child_')
|
||||
|
||||
parent: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs)
|
||||
child: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs)
|
||||
parent = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs)
|
||||
child = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs)
|
||||
|
||||
return parent, child
|
||||
|
||||
|
|
@ -232,3 +267,93 @@ def bound_fk(target: str | Column | InstrumentedAttribute, typ: _T = None, **kwa
|
|||
|
||||
return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs)
|
||||
|
||||
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
||||
"""
|
||||
Return a table's column given its name.
|
||||
|
||||
XXX does it belong outside any scopes?
|
||||
"""
|
||||
if isinstance(col, Incomplete):
|
||||
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
|
||||
elif isinstance(col, Column):
|
||||
return col
|
||||
elif isinstance(col, str):
|
||||
return getattr(cls, col)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
## Utilities for use in web apps below
|
||||
|
||||
class AuthSrc(metaclass=ABCMeta):
|
||||
'''
|
||||
AuthSrc object required for require_auth_base().
|
||||
|
||||
This is an abstract class and is NOT usable directly.
|
||||
|
||||
This is not part of the public API
|
||||
'''
|
||||
def required_exc(self) -> Never:
|
||||
raise ValueError('required field missing')
|
||||
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
||||
raise ValueError(msg)
|
||||
@abstractmethod
|
||||
def get_session(self) -> Session:
|
||||
pass
|
||||
def get_user(self, getter: Callable):
|
||||
return getter(self.get_token())
|
||||
@abstractmethod
|
||||
def get_token(self):
|
||||
pass
|
||||
@abstractmethod
|
||||
def get_signature(self):
|
||||
pass
|
||||
|
||||
|
||||
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
||||
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
|
||||
'''
|
||||
Inject the current user into a view, given the Authorization: Bearer header.
|
||||
|
||||
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
|
||||
'''
|
||||
col = want_column(cls, column)
|
||||
validators = makelist(validators)
|
||||
|
||||
def get_user(token) -> DeclarativeBase:
|
||||
if token is None:
|
||||
return None
|
||||
tok_parts = UserSigner.split_token(token)
|
||||
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
|
||||
try:
|
||||
signer: UserSigner = user.signer()
|
||||
signer.unsign(token)
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _default_invalid(msg: str = 'Validation failed'):
|
||||
raise ValueError(msg)
|
||||
|
||||
invalid_exc = src.invalid_exc or _default_invalid
|
||||
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(*a, **ka):
|
||||
ka[dest] = get_user(src.get_token())
|
||||
if not ka[dest] and required:
|
||||
required_exc()
|
||||
if signed:
|
||||
ka[sig_dest] = src.get_signature()
|
||||
for valid in validators:
|
||||
if not valid(ka[dest]):
|
||||
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
||||
return func(*a, **ka)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = (
|
||||
'IdType', 'id_column', 'entity_base', 'declarative_base', 'token_signer', 'match_column', 'match_constraint',
|
||||
'author_pair', 'age_pair', 'require_auth_base', 'bound_fk', 'unbound_fk', 'want_column'
|
||||
)
|
||||
|
|
@ -1,169 +0,0 @@
|
|||
"""
|
||||
Utilities for SQLAlchemy
|
||||
|
||||
---
|
||||
|
||||
Copyright (c) 2025 Sakuragasaki46.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
See LICENSE for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
This software is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Iterable, Never, TypeVar
|
||||
import warnings
|
||||
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
|
||||
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from ..snowflake import SnowflakeGen
|
||||
from ..itertools import kwargs_prefix, makelist
|
||||
from ..signing import HasSigner, UserSigner
|
||||
from ..codecs import StringCase
|
||||
from ..functools import deprecated, not_implemented
|
||||
from ..iding import Siq, SiqGen, SiqType, SiqCache
|
||||
from ..classtools import Incomplete, Wanted
|
||||
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# SIQs are 14 bytes long. Storage is padded for alignment
|
||||
# Not to be confused with SiqType.
|
||||
IdType: TypeEngine = LargeBinary(16)
|
||||
|
||||
def create_session(url: str) -> Session:
|
||||
"""
|
||||
Create a session on the fly, given a database URL. Useful for
|
||||
contextless environments, such as Python REPL.
|
||||
|
||||
Heads up: a function with the same name exists in core sqlalchemy, but behaves
|
||||
completely differently!!
|
||||
"""
|
||||
engine = create_engine(url)
|
||||
return Session(bind = engine)
|
||||
|
||||
|
||||
|
||||
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
|
||||
"""
|
||||
Generate a user signing function.
|
||||
|
||||
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
|
||||
and a user secret.
|
||||
"""
|
||||
id_val: Column | Wanted[Column]
|
||||
if isinstance(id_attr, Column):
|
||||
id_val = id_attr
|
||||
elif isinstance(id_attr, str):
|
||||
id_val = Wanted(id_attr)
|
||||
if isinstance(secret_attr, Column):
|
||||
secret_val = secret_attr
|
||||
elif isinstance(secret_attr, str):
|
||||
secret_val = Wanted(secret_attr)
|
||||
def token_signer_factory(owner: DeclarativeBase, name: str):
|
||||
def my_signer(self):
|
||||
return UserSigner(
|
||||
owner.metadata.info['secret_key'],
|
||||
id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue]
|
||||
)
|
||||
my_signer.__name__ = name
|
||||
return my_signer
|
||||
return Incomplete(Wanted(token_signer_factory))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Utilities for use in web apps below
|
||||
|
||||
@deprecated('not part of the public API and not even working')
|
||||
class AuthSrc(metaclass=ABCMeta):
|
||||
'''
|
||||
AuthSrc object required for require_auth_base().
|
||||
|
||||
This is an abstract class and is NOT usable directly.
|
||||
|
||||
This is not part of the public API
|
||||
'''
|
||||
def required_exc(self) -> Never:
|
||||
raise ValueError('required field missing')
|
||||
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
||||
raise ValueError(msg)
|
||||
@abstractmethod
|
||||
def get_session(self) -> Session:
|
||||
pass
|
||||
def get_user(self, getter: Callable):
|
||||
return getter(self.get_token())
|
||||
@abstractmethod
|
||||
def get_token(self):
|
||||
pass
|
||||
@abstractmethod
|
||||
def get_signature(self):
|
||||
pass
|
||||
|
||||
|
||||
@deprecated('not working and too complex to use')
|
||||
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
||||
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
|
||||
'''
|
||||
Inject the current user into a view, given the Authorization: Bearer header.
|
||||
|
||||
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
|
||||
'''
|
||||
col = want_column(cls, column)
|
||||
validators = makelist(validators)
|
||||
|
||||
def get_user(token) -> DeclarativeBase:
|
||||
if token is None:
|
||||
return None
|
||||
tok_parts = UserSigner.split_token(token)
|
||||
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
|
||||
try:
|
||||
signer: UserSigner = user.signer()
|
||||
signer.unsign(token)
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _default_invalid(msg: str = 'Validation failed'):
|
||||
raise ValueError(msg)
|
||||
|
||||
invalid_exc = src.invalid_exc or _default_invalid
|
||||
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(*a, **ka):
|
||||
ka[dest] = get_user(src.get_token())
|
||||
if not ka[dest] and required:
|
||||
required_exc()
|
||||
if signed:
|
||||
ka[sig_dest] = src.get_signature()
|
||||
for valid in validators:
|
||||
if not valid(ka[dest]):
|
||||
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
||||
return func(*a, **ka)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
from .asyncio import SQLAlchemy, AsyncSelectPagination, async_query
|
||||
from .orm import id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, author_pair, age_pair, bound_fk, unbound_fk, want_column
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = (
|
||||
'IdType', 'id_column', 'snowflake_column', 'entity_base', 'declarative_base', 'token_signer',
|
||||
'match_column', 'match_constraint', 'bool_column', 'parent_children',
|
||||
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
|
||||
# .asyncio
|
||||
'SQLAlchemy', 'AsyncSelectPagination', 'async_query'
|
||||
)
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
|
||||
"""
|
||||
Helpers for asynchronous use of SQLAlchemy.
|
||||
|
||||
NEW 0.5.0; moved to current location 0.6.0
|
||||
|
||||
---
|
||||
|
||||
Copyright (c) 2025 Sakuragasaki46.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
See LICENSE for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
This software is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
|
||||
|
||||
from sqlalchemy import Engine, Select, func, select
|
||||
from sqlalchemy.orm import DeclarativeBase, lazyload
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from suou.exceptions import NotFoundError
|
||||
|
||||
class SQLAlchemy:
|
||||
"""
|
||||
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||
eligible for async environments
|
||||
|
||||
NEW 0.5.0
|
||||
"""
|
||||
base: DeclarativeBase
|
||||
engine: AsyncEngine
|
||||
_sessions: list[AsyncSession]
|
||||
NotFound = NotFoundError
|
||||
|
||||
def __init__(self, model_class: DeclarativeBase):
|
||||
self.base = model_class
|
||||
self.engine = None
|
||||
self._sessions = []
|
||||
def bind(self, url: str):
|
||||
self.engine = create_async_engine(url)
|
||||
def _ensure_engine(self):
|
||||
if self.engine is None:
|
||||
raise RuntimeError('database is not connected')
|
||||
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
||||
self._ensure_engine()
|
||||
## XXX is it accurate?
|
||||
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
||||
self._sessions.append(s)
|
||||
return s
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
return await self.begin()
|
||||
async def __aexit__(self, e1, e2, e3):
|
||||
## XXX is it accurate?
|
||||
s = self._sessions.pop()
|
||||
if e1:
|
||||
await s.rollback()
|
||||
else:
|
||||
await s.commit()
|
||||
await s.close()
|
||||
async def paginate(self, select: Select, *,
|
||||
page: int | None = None, per_page: int | None = None,
|
||||
max_per_page: int | None = None, error_out: bool = True,
|
||||
count: bool = True) -> AsyncSelectPagination:
|
||||
"""
|
||||
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
|
||||
"""
|
||||
async with self as session:
|
||||
return AsyncSelectPagination(
|
||||
select = select,
|
||||
session = session,
|
||||
page = page,
|
||||
per_page=per_page, max_per_page=max_per_page,
|
||||
error_out=self.NotFound if error_out else None, count=count
|
||||
)
|
||||
async def create_all(self, *, checkfirst = True):
|
||||
"""
|
||||
Initialize database
|
||||
"""
|
||||
self._ensure_engine()
|
||||
self.base.metadata.create_all(
|
||||
self.engine, checkfirst=checkfirst
|
||||
)
|
||||
|
||||
|
||||
|
||||
class AsyncSelectPagination(Pagination):
|
||||
"""
|
||||
flask_sqlalchemy.SelectPagination but asynchronous.
|
||||
|
||||
Pagination is not part of the public API, therefore expect that it may break
|
||||
"""
|
||||
|
||||
async def _query_items(self) -> list:
|
||||
select_q: Select = self._query_args["select"]
|
||||
select = select_q.limit(self.per_page).offset(self._query_offset)
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select)).scalars()
|
||||
return out
|
||||
|
||||
async def _query_count(self) -> int:
|
||||
select_q: Select = self._query_args["select"]
|
||||
sub = select_q.options(lazyload("*")).order_by(None).subquery()
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
|
||||
return out
|
||||
|
||||
def __init__(self,
|
||||
page: int | None = None,
|
||||
per_page: int | None = None,
|
||||
max_per_page: int | None = 100,
|
||||
error_out: Exception | None = NotFoundError,
|
||||
count: bool = True,
|
||||
**kwargs):
|
||||
## XXX flask-sqlalchemy says Pagination() is not public API.
|
||||
## Things may break; beware.
|
||||
self._query_args = kwargs
|
||||
page, per_page = self._prepare_page_args(
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
max_per_page=max_per_page,
|
||||
error_out=error_out,
|
||||
)
|
||||
|
||||
self.page: int = page
|
||||
"""The current page."""
|
||||
|
||||
self.per_page: int = per_page
|
||||
"""The maximum number of items on a page."""
|
||||
|
||||
self.max_per_page: int | None = max_per_page
|
||||
"""The maximum allowed value for ``per_page``."""
|
||||
|
||||
self.items = None
|
||||
self.total = None
|
||||
self.error_out = error_out
|
||||
self.has_count = count
|
||||
|
||||
async def __aiter__(self):
|
||||
self.items = await self._query_items()
|
||||
if self.items is None:
|
||||
raise RuntimeError('query returned None')
|
||||
if not self.items and self.page != 1 and self.error_out:
|
||||
raise self.error_out
|
||||
if self.has_count:
|
||||
self.total = await self._query_count()
|
||||
for i in self.items:
|
||||
yield i
|
||||
|
||||
|
||||
def async_query(db: SQLAlchemy, multi: False):
|
||||
"""
|
||||
Wraps a query returning function into an executor coroutine.
|
||||
|
||||
The query function remains available as the .q or .query attribute.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def executor(*args, **kwargs):
|
||||
async with db as session:
|
||||
result = await session.execute(func(*args, **kwargs))
|
||||
return result.scalars() if multi else result.scalar()
|
||||
executor.query = executor.q = func
|
||||
return executor
|
||||
return decorator
|
||||
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Helpers for asynchronous use of SQLAlchemy.
|
||||
|
||||
NEW 0.5.0; MOVED to sqlalchemy.asyncio in 0.6.0
|
||||
NEW 0.5.0
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -17,16 +17,159 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
|
||||
from .functools import deprecated
|
||||
|
||||
from sqlalchemy import Engine, Select, func, select
|
||||
from sqlalchemy.orm import DeclarativeBase, lazyload
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from suou.exceptions import NotFoundError
|
||||
|
||||
class SQLAlchemy:
|
||||
"""
|
||||
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||
eligible for async environments
|
||||
|
||||
NEW 0.5.0
|
||||
"""
|
||||
base: DeclarativeBase
|
||||
engine: AsyncEngine
|
||||
_sessions: list[AsyncSession]
|
||||
NotFound = NotFoundError
|
||||
|
||||
def __init__(self, model_class: DeclarativeBase):
|
||||
self.base = model_class
|
||||
self.engine = None
|
||||
self._sessions = []
|
||||
def bind(self, url: str):
|
||||
self.engine = create_async_engine(url)
|
||||
def _ensure_engine(self):
|
||||
if self.engine is None:
|
||||
raise RuntimeError('database is not connected')
|
||||
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
||||
self._ensure_engine()
|
||||
## XXX is it accurate?
|
||||
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
||||
self._sessions.append(s)
|
||||
return s
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
return await self.begin()
|
||||
async def __aexit__(self, e1, e2, e3):
|
||||
## XXX is it accurate?
|
||||
s = self._sessions.pop()
|
||||
if e1:
|
||||
await s.rollback()
|
||||
else:
|
||||
await s.commit()
|
||||
await s.close()
|
||||
async def paginate(self, select: Select, *,
|
||||
page: int | None = None, per_page: int | None = None,
|
||||
max_per_page: int | None = None, error_out: bool = True,
|
||||
count: bool = True) -> AsyncSelectPagination:
|
||||
"""
|
||||
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
|
||||
"""
|
||||
async with self as session:
|
||||
return AsyncSelectPagination(
|
||||
select = select,
|
||||
session = session,
|
||||
page = page,
|
||||
per_page=per_page, max_per_page=max_per_page,
|
||||
error_out=self.NotFound if error_out else None, count=count
|
||||
)
|
||||
async def create_all(self, *, checkfirst = True):
|
||||
"""
|
||||
Initialize database
|
||||
"""
|
||||
self._ensure_engine()
|
||||
self.base.metadata.create_all(
|
||||
self.engine, checkfirst=checkfirst
|
||||
)
|
||||
|
||||
|
||||
|
||||
from .sqlalchemy.asyncio import SQLAlchemy, AsyncSelectPagination, async_query
|
||||
class AsyncSelectPagination(Pagination):
|
||||
"""
|
||||
flask_sqlalchemy.SelectPagination but asynchronous.
|
||||
|
||||
SQLAlchemy = deprecated('import from suou.sqlalchemy.asyncio instead')(SQLAlchemy)
|
||||
AsyncSelectPagination = deprecated('import from suou.sqlalchemy.asyncio instead')(AsyncSelectPagination)
|
||||
async_query = deprecated('import from suou.sqlalchemy.asyncio instead')(async_query)
|
||||
Pagination is not part of the public API, therefore expect that it may break
|
||||
"""
|
||||
|
||||
async def _query_items(self) -> list:
|
||||
select_q: Select = self._query_args["select"]
|
||||
select = select_q.limit(self.per_page).offset(self._query_offset)
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select)).scalars()
|
||||
return out
|
||||
|
||||
async def _query_count(self) -> int:
|
||||
select_q: Select = self._query_args["select"]
|
||||
sub = select_q.options(lazyload("*")).order_by(None).subquery()
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
|
||||
return out
|
||||
|
||||
def __init__(self,
|
||||
page: int | None = None,
|
||||
per_page: int | None = None,
|
||||
max_per_page: int | None = 100,
|
||||
error_out: Exception | None = NotFoundError,
|
||||
count: bool = True,
|
||||
**kwargs):
|
||||
## XXX flask-sqlalchemy says Pagination() is not public API.
|
||||
## Things may break; beware.
|
||||
self._query_args = kwargs
|
||||
page, per_page = self._prepare_page_args(
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
max_per_page=max_per_page,
|
||||
error_out=error_out,
|
||||
)
|
||||
|
||||
self.page: int = page
|
||||
"""The current page."""
|
||||
|
||||
self.per_page: int = per_page
|
||||
"""The maximum number of items on a page."""
|
||||
|
||||
self.max_per_page: int | None = max_per_page
|
||||
"""The maximum allowed value for ``per_page``."""
|
||||
|
||||
self.items = None
|
||||
self.total = None
|
||||
self.error_out = error_out
|
||||
self.has_count = count
|
||||
|
||||
async def __aiter__(self):
|
||||
self.items = await self._query_items()
|
||||
if self.items is None:
|
||||
raise RuntimeError('query returned None')
|
||||
if not self.items and self.page != 1 and self.error_out:
|
||||
raise self.error_out
|
||||
if self.has_count:
|
||||
self.total = await self._query_count()
|
||||
for i in self.items:
|
||||
yield i
|
||||
|
||||
|
||||
def async_query(db: SQLAlchemy, multi: False):
|
||||
"""
|
||||
Wraps a query returning function into an executor coroutine.
|
||||
|
||||
The query function remains available as the .q or .query attribute.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def executor(*args, **kwargs):
|
||||
async with db as session:
|
||||
result = await session.execute(func(*args, **kwargs))
|
||||
return result.scalars() if multi else result.scalar()
|
||||
executor.query = executor.q = func
|
||||
return executor
|
||||
return decorator
|
||||
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||
__all__ = ('SQLAlchemy', 'async_query')
|
||||
Loading…
Add table
Add a link
Reference in a new issue