split sqlalchemy modules
This commit is contained in:
parent
f7807ff05a
commit
97194b2b85
4 changed files with 379 additions and 307 deletions
169
src/suou/sqlalchemy/__init__.py
Normal file
169
src/suou/sqlalchemy/__init__.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
Utilities for SQLAlchemy
|
||||
|
||||
---
|
||||
|
||||
Copyright (c) 2025 Sakuragasaki46.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
See LICENSE for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
This software is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Iterable, Never, TypeVar
|
||||
import warnings
|
||||
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
|
||||
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from ..snowflake import SnowflakeGen
|
||||
from ..itertools import kwargs_prefix, makelist
|
||||
from ..signing import HasSigner, UserSigner
|
||||
from ..codecs import StringCase
|
||||
from ..functools import deprecated, not_implemented
|
||||
from ..iding import Siq, SiqGen, SiqType, SiqCache
|
||||
from ..classtools import Incomplete, Wanted
|
||||
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# SIQs are 14 bytes long. Storage is padded for alignment
|
||||
# Not to be confused with SiqType.
|
||||
IdType: TypeEngine = LargeBinary(16)
|
||||
|
||||
def create_session(url: str) -> Session:
|
||||
"""
|
||||
Create a session on the fly, given a database URL. Useful for
|
||||
contextless environments, such as Python REPL.
|
||||
|
||||
Heads up: a function with the same name exists in core sqlalchemy, but behaves
|
||||
completely differently!!
|
||||
"""
|
||||
engine = create_engine(url)
|
||||
return Session(bind = engine)
|
||||
|
||||
|
||||
|
||||
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
|
||||
"""
|
||||
Generate a user signing function.
|
||||
|
||||
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
|
||||
and a user secret.
|
||||
"""
|
||||
id_val: Column | Wanted[Column]
|
||||
if isinstance(id_attr, Column):
|
||||
id_val = id_attr
|
||||
elif isinstance(id_attr, str):
|
||||
id_val = Wanted(id_attr)
|
||||
if isinstance(secret_attr, Column):
|
||||
secret_val = secret_attr
|
||||
elif isinstance(secret_attr, str):
|
||||
secret_val = Wanted(secret_attr)
|
||||
def token_signer_factory(owner: DeclarativeBase, name: str):
|
||||
def my_signer(self):
|
||||
return UserSigner(
|
||||
owner.metadata.info['secret_key'],
|
||||
id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue]
|
||||
)
|
||||
my_signer.__name__ = name
|
||||
return my_signer
|
||||
return Incomplete(Wanted(token_signer_factory))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Utilities for use in web apps below
|
||||
|
||||
@deprecated('not part of the public API and not even working')
|
||||
class AuthSrc(metaclass=ABCMeta):
|
||||
'''
|
||||
AuthSrc object required for require_auth_base().
|
||||
|
||||
This is an abstract class and is NOT usable directly.
|
||||
|
||||
This is not part of the public API
|
||||
'''
|
||||
def required_exc(self) -> Never:
|
||||
raise ValueError('required field missing')
|
||||
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
||||
raise ValueError(msg)
|
||||
@abstractmethod
|
||||
def get_session(self) -> Session:
|
||||
pass
|
||||
def get_user(self, getter: Callable):
|
||||
return getter(self.get_token())
|
||||
@abstractmethod
|
||||
def get_token(self):
|
||||
pass
|
||||
@abstractmethod
|
||||
def get_signature(self):
|
||||
pass
|
||||
|
||||
|
||||
@deprecated('not working and too complex to use')
|
||||
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
||||
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
|
||||
'''
|
||||
Inject the current user into a view, given the Authorization: Bearer header.
|
||||
|
||||
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
|
||||
'''
|
||||
col = want_column(cls, column)
|
||||
validators = makelist(validators)
|
||||
|
||||
def get_user(token) -> DeclarativeBase:
|
||||
if token is None:
|
||||
return None
|
||||
tok_parts = UserSigner.split_token(token)
|
||||
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
|
||||
try:
|
||||
signer: UserSigner = user.signer()
|
||||
signer.unsign(token)
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _default_invalid(msg: str = 'Validation failed'):
|
||||
raise ValueError(msg)
|
||||
|
||||
invalid_exc = src.invalid_exc or _default_invalid
|
||||
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(*a, **ka):
|
||||
ka[dest] = get_user(src.get_token())
|
||||
if not ka[dest] and required:
|
||||
required_exc()
|
||||
if signed:
|
||||
ka[sig_dest] = src.get_signature()
|
||||
for valid in validators:
|
||||
if not valid(ka[dest]):
|
||||
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
||||
return func(*a, **ka)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
from .asyncio import SQLAlchemy, AsyncSelectPagination, async_query
|
||||
from .orm import id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, author_pair, age_pair, bound_fk, unbound_fk, want_column
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = (
|
||||
'IdType', 'id_column', 'snowflake_column', 'entity_base', 'declarative_base', 'token_signer',
|
||||
'match_column', 'match_constraint', 'bool_column', 'parent_children',
|
||||
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
|
||||
# .asyncio
|
||||
'SQLAlchemy', 'AsyncSelectPagination', 'async_query'
|
||||
)
|
||||
176
src/suou/sqlalchemy/asyncio.py
Normal file
176
src/suou/sqlalchemy/asyncio.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
|
||||
"""
|
||||
Helpers for asynchronous use of SQLAlchemy.
|
||||
|
||||
NEW 0.5.0; moved to current location 0.6.0
|
||||
|
||||
---
|
||||
|
||||
Copyright (c) 2025 Sakuragasaki46.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
See LICENSE for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
This software is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
|
||||
|
||||
from sqlalchemy import Engine, Select, func, select
|
||||
from sqlalchemy.orm import DeclarativeBase, lazyload
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from suou.exceptions import NotFoundError
|
||||
|
||||
class SQLAlchemy:
|
||||
"""
|
||||
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||
eligible for async environments
|
||||
|
||||
NEW 0.5.0
|
||||
"""
|
||||
base: DeclarativeBase
|
||||
engine: AsyncEngine
|
||||
_sessions: list[AsyncSession]
|
||||
NotFound = NotFoundError
|
||||
|
||||
def __init__(self, model_class: DeclarativeBase):
|
||||
self.base = model_class
|
||||
self.engine = None
|
||||
self._sessions = []
|
||||
def bind(self, url: str):
|
||||
self.engine = create_async_engine(url)
|
||||
def _ensure_engine(self):
|
||||
if self.engine is None:
|
||||
raise RuntimeError('database is not connected')
|
||||
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
||||
self._ensure_engine()
|
||||
## XXX is it accurate?
|
||||
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
||||
self._sessions.append(s)
|
||||
return s
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
return await self.begin()
|
||||
async def __aexit__(self, e1, e2, e3):
|
||||
## XXX is it accurate?
|
||||
s = self._sessions.pop()
|
||||
if e1:
|
||||
await s.rollback()
|
||||
else:
|
||||
await s.commit()
|
||||
await s.close()
|
||||
async def paginate(self, select: Select, *,
|
||||
page: int | None = None, per_page: int | None = None,
|
||||
max_per_page: int | None = None, error_out: bool = True,
|
||||
count: bool = True) -> AsyncSelectPagination:
|
||||
"""
|
||||
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
|
||||
"""
|
||||
async with self as session:
|
||||
return AsyncSelectPagination(
|
||||
select = select,
|
||||
session = session,
|
||||
page = page,
|
||||
per_page=per_page, max_per_page=max_per_page,
|
||||
error_out=self.NotFound if error_out else None, count=count
|
||||
)
|
||||
async def create_all(self, *, checkfirst = True):
|
||||
"""
|
||||
Initialize database
|
||||
"""
|
||||
self._ensure_engine()
|
||||
self.base.metadata.create_all(
|
||||
self.engine, checkfirst=checkfirst
|
||||
)
|
||||
|
||||
|
||||
|
||||
class AsyncSelectPagination(Pagination):
|
||||
"""
|
||||
flask_sqlalchemy.SelectPagination but asynchronous.
|
||||
|
||||
Pagination is not part of the public API, therefore expect that it may break
|
||||
"""
|
||||
|
||||
async def _query_items(self) -> list:
|
||||
select_q: Select = self._query_args["select"]
|
||||
select = select_q.limit(self.per_page).offset(self._query_offset)
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select)).scalars()
|
||||
return out
|
||||
|
||||
async def _query_count(self) -> int:
|
||||
select_q: Select = self._query_args["select"]
|
||||
sub = select_q.options(lazyload("*")).order_by(None).subquery()
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
|
||||
return out
|
||||
|
||||
def __init__(self,
|
||||
page: int | None = None,
|
||||
per_page: int | None = None,
|
||||
max_per_page: int | None = 100,
|
||||
error_out: Exception | None = NotFoundError,
|
||||
count: bool = True,
|
||||
**kwargs):
|
||||
## XXX flask-sqlalchemy says Pagination() is not public API.
|
||||
## Things may break; beware.
|
||||
self._query_args = kwargs
|
||||
page, per_page = self._prepare_page_args(
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
max_per_page=max_per_page,
|
||||
error_out=error_out,
|
||||
)
|
||||
|
||||
self.page: int = page
|
||||
"""The current page."""
|
||||
|
||||
self.per_page: int = per_page
|
||||
"""The maximum number of items on a page."""
|
||||
|
||||
self.max_per_page: int | None = max_per_page
|
||||
"""The maximum allowed value for ``per_page``."""
|
||||
|
||||
self.items = None
|
||||
self.total = None
|
||||
self.error_out = error_out
|
||||
self.has_count = count
|
||||
|
||||
async def __aiter__(self):
|
||||
self.items = await self._query_items()
|
||||
if self.items is None:
|
||||
raise RuntimeError('query returned None')
|
||||
if not self.items and self.page != 1 and self.error_out:
|
||||
raise self.error_out
|
||||
if self.has_count:
|
||||
self.total = await self._query_count()
|
||||
for i in self.items:
|
||||
yield i
|
||||
|
||||
|
||||
def async_query(db: SQLAlchemy, multi: False):
|
||||
"""
|
||||
Wraps a query returning function into an executor coroutine.
|
||||
|
||||
The query function remains available as the .q or .query attribute.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def executor(*args, **kwargs):
|
||||
async with db as session:
|
||||
result = await session.execute(func(*args, **kwargs))
|
||||
return result.scalars() if multi else result.scalar()
|
||||
executor.query = executor.q = func
|
||||
return executor
|
||||
return decorator
|
||||
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
"""
|
||||
Utilities for SQLAlchemy
|
||||
Utilities for SQLAlchemy; ORM
|
||||
|
||||
NEW 0.6.0
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -14,54 +16,38 @@ This software is distributed on an "AS IS" BASIS,
|
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Iterable, Never, TypeVar
|
||||
|
||||
from binascii import Incomplete
|
||||
from typing import Any, Callable
|
||||
import warnings
|
||||
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
|
||||
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
|
||||
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Column, Date, ForeignKey, LargeBinary, MetaData, SmallInteger, String, text
|
||||
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, declarative_base as _declarative_base, relationship
|
||||
from sqlalchemy.types import TypeEngine
|
||||
from suou.classtools import Wanted
|
||||
from suou.codecs import StringCase
|
||||
from suou.iding import Siq, SiqCache, SiqGen, SiqType
|
||||
from suou.itertools import kwargs_prefix
|
||||
from suou.snowflake import SnowflakeGen
|
||||
from suou.sqlalchemy import IdType
|
||||
|
||||
from .snowflake import SnowflakeGen
|
||||
from .itertools import kwargs_prefix, makelist
|
||||
from .signing import HasSigner, UserSigner
|
||||
from .codecs import StringCase
|
||||
from .functools import deprecated, not_implemented
|
||||
from .iding import Siq, SiqGen, SiqType, SiqCache
|
||||
from .classtools import Incomplete, Wanted
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# SIQs are 14 bytes long. Storage is padded for alignment
|
||||
# Not to be confused with SiqType.
|
||||
IdType: TypeEngine = LargeBinary(16)
|
||||
|
||||
@not_implemented
|
||||
def sql_escape(s: str, /, dialect: Dialect) -> str:
|
||||
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
||||
"""
|
||||
Escape a value for SQL embedding, using SQLAlchemy's literal processors.
|
||||
Requires a dialect argument.
|
||||
Return a table's column given its name.
|
||||
|
||||
XXX this function is not mature yet, do not use
|
||||
XXX does it belong outside any scopes?
|
||||
"""
|
||||
if isinstance(s, str):
|
||||
return String().literal_processor(dialect=dialect)(s)
|
||||
raise TypeError('invalid data type')
|
||||
if isinstance(col, Incomplete):
|
||||
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
|
||||
elif isinstance(col, Column):
|
||||
return col
|
||||
elif isinstance(col, str):
|
||||
return getattr(cls, col)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
|
||||
def create_session(url: str) -> Session:
|
||||
"""
|
||||
Create a session on the fly, given a database URL. Useful for
|
||||
contextless environments, such as Python REPL.
|
||||
|
||||
Heads up: a function with the same name exists in core sqlalchemy, but behaves
|
||||
completely differently!!
|
||||
"""
|
||||
engine = create_engine(url)
|
||||
return Session(bind = engine)
|
||||
|
||||
def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs):
|
||||
"""
|
||||
Marks a column which contains a SIQ.
|
||||
|
|
@ -120,7 +106,6 @@ def match_column(length: int, regex: str, /, case: StringCase = StringCase.AS_IS
|
|||
return Incomplete(Column, String(length), Wanted(lambda x, n: match_constraint(n, regex, #dialect=x.metadata.engine.dialect.name,
|
||||
constraint_name=constraint_name or f'{x.__tablename__}_{n}_valid')), *args, **kwargs)
|
||||
|
||||
|
||||
def bool_column(value: bool = False, nullable: bool = False, **kwargs) -> Column[bool]:
|
||||
"""
|
||||
Column for a single boolean value.
|
||||
|
|
@ -149,35 +134,9 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No
|
|||
)
|
||||
Base = _declarative_base(metadata=MetaData(**metadata), **kwargs)
|
||||
return Base
|
||||
entity_base = deprecated('use declarative_base() instead')(declarative_base)
|
||||
entity_base = warnings.deprecated('use declarative_base() instead')(declarative_base)
|
||||
|
||||
|
||||
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
|
||||
"""
|
||||
Generate a user signing function.
|
||||
|
||||
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
|
||||
and a user secret.
|
||||
"""
|
||||
id_val: Column | Wanted[Column]
|
||||
if isinstance(id_attr, Column):
|
||||
id_val = id_attr
|
||||
elif isinstance(id_attr, str):
|
||||
id_val = Wanted(id_attr)
|
||||
if isinstance(secret_attr, Column):
|
||||
secret_val = secret_attr
|
||||
elif isinstance(secret_attr, str):
|
||||
secret_val = Wanted(secret_attr)
|
||||
def token_signer_factory(owner: DeclarativeBase, name: str):
|
||||
def my_signer(self):
|
||||
return UserSigner(
|
||||
owner.metadata.info['secret_key'],
|
||||
id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue]
|
||||
)
|
||||
my_signer.__name__ = name
|
||||
return my_signer
|
||||
return Incomplete(Wanted(token_signer_factory))
|
||||
|
||||
|
||||
def author_pair(fk_name: str, *, id_type: type | TypeEngine = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]:
|
||||
"""
|
||||
|
|
@ -208,6 +167,7 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]:
|
|||
return (date_col, acc_col)
|
||||
|
||||
|
||||
|
||||
def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]:
|
||||
"""
|
||||
Self-referential one-to-many relationship pair.
|
||||
|
|
@ -272,93 +232,3 @@ def bound_fk(target: str | Column | InstrumentedAttribute, typ: _T = None, **kwa
|
|||
|
||||
return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs)
|
||||
|
||||
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
||||
"""
|
||||
Return a table's column given its name.
|
||||
|
||||
XXX does it belong outside any scopes?
|
||||
"""
|
||||
if isinstance(col, Incomplete):
|
||||
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
|
||||
elif isinstance(col, Column):
|
||||
return col
|
||||
elif isinstance(col, str):
|
||||
return getattr(cls, col)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
## Utilities for use in web apps below
|
||||
|
||||
class AuthSrc(metaclass=ABCMeta):
|
||||
'''
|
||||
AuthSrc object required for require_auth_base().
|
||||
|
||||
This is an abstract class and is NOT usable directly.
|
||||
|
||||
This is not part of the public API
|
||||
'''
|
||||
def required_exc(self) -> Never:
|
||||
raise ValueError('required field missing')
|
||||
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
||||
raise ValueError(msg)
|
||||
@abstractmethod
|
||||
def get_session(self) -> Session:
|
||||
pass
|
||||
def get_user(self, getter: Callable):
|
||||
return getter(self.get_token())
|
||||
@abstractmethod
|
||||
def get_token(self):
|
||||
pass
|
||||
@abstractmethod
|
||||
def get_signature(self):
|
||||
pass
|
||||
|
||||
|
||||
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
||||
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
|
||||
'''
|
||||
Inject the current user into a view, given the Authorization: Bearer header.
|
||||
|
||||
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
|
||||
'''
|
||||
col = want_column(cls, column)
|
||||
validators = makelist(validators)
|
||||
|
||||
def get_user(token) -> DeclarativeBase:
|
||||
if token is None:
|
||||
return None
|
||||
tok_parts = UserSigner.split_token(token)
|
||||
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
|
||||
try:
|
||||
signer: UserSigner = user.signer()
|
||||
signer.unsign(token)
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _default_invalid(msg: str = 'Validation failed'):
|
||||
raise ValueError(msg)
|
||||
|
||||
invalid_exc = src.invalid_exc or _default_invalid
|
||||
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(*a, **ka):
|
||||
ka[dest] = get_user(src.get_token())
|
||||
if not ka[dest] and required:
|
||||
required_exc()
|
||||
if signed:
|
||||
ka[sig_dest] = src.get_signature()
|
||||
for valid in validators:
|
||||
if not valid(ka[dest]):
|
||||
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
||||
return func(*a, **ka)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = (
|
||||
'IdType', 'id_column', 'entity_base', 'declarative_base', 'token_signer', 'match_column', 'match_constraint',
|
||||
'author_pair', 'age_pair', 'require_auth_base', 'bound_fk', 'unbound_fk', 'want_column'
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Helpers for asynchronous use of SQLAlchemy.
|
||||
|
||||
NEW 0.5.0
|
||||
NEW 0.5.0; MOVED to sqlalchemy.asyncio in 0.6.0
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -17,159 +17,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
|
||||
|
||||
from sqlalchemy import Engine, Select, func, select
|
||||
from sqlalchemy.orm import DeclarativeBase, lazyload
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from suou.exceptions import NotFoundError
|
||||
|
||||
class SQLAlchemy:
|
||||
"""
|
||||
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||
eligible for async environments
|
||||
|
||||
NEW 0.5.0
|
||||
"""
|
||||
base: DeclarativeBase
|
||||
engine: AsyncEngine
|
||||
_sessions: list[AsyncSession]
|
||||
NotFound = NotFoundError
|
||||
|
||||
def __init__(self, model_class: DeclarativeBase):
|
||||
self.base = model_class
|
||||
self.engine = None
|
||||
self._sessions = []
|
||||
def bind(self, url: str):
|
||||
self.engine = create_async_engine(url)
|
||||
def _ensure_engine(self):
|
||||
if self.engine is None:
|
||||
raise RuntimeError('database is not connected')
|
||||
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
||||
self._ensure_engine()
|
||||
## XXX is it accurate?
|
||||
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
||||
self._sessions.append(s)
|
||||
return s
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
return await self.begin()
|
||||
async def __aexit__(self, e1, e2, e3):
|
||||
## XXX is it accurate?
|
||||
s = self._sessions.pop()
|
||||
if e1:
|
||||
await s.rollback()
|
||||
else:
|
||||
await s.commit()
|
||||
await s.close()
|
||||
async def paginate(self, select: Select, *,
|
||||
page: int | None = None, per_page: int | None = None,
|
||||
max_per_page: int | None = None, error_out: bool = True,
|
||||
count: bool = True) -> AsyncSelectPagination:
|
||||
"""
|
||||
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
|
||||
"""
|
||||
async with self as session:
|
||||
return AsyncSelectPagination(
|
||||
select = select,
|
||||
session = session,
|
||||
page = page,
|
||||
per_page=per_page, max_per_page=max_per_page,
|
||||
error_out=self.NotFound if error_out else None, count=count
|
||||
)
|
||||
async def create_all(self, *, checkfirst = True):
|
||||
"""
|
||||
Initialize database
|
||||
"""
|
||||
self._ensure_engine()
|
||||
self.base.metadata.create_all(
|
||||
self.engine, checkfirst=checkfirst
|
||||
)
|
||||
from .functools import deprecated
|
||||
|
||||
|
||||
|
||||
class AsyncSelectPagination(Pagination):
|
||||
"""
|
||||
flask_sqlalchemy.SelectPagination but asynchronous.
|
||||
from .sqlalchemy.asyncio import SQLAlchemy, AsyncSelectPagination, async_query
|
||||
|
||||
Pagination is not part of the public API, therefore expect that it may break
|
||||
"""
|
||||
|
||||
async def _query_items(self) -> list:
|
||||
select_q: Select = self._query_args["select"]
|
||||
select = select_q.limit(self.per_page).offset(self._query_offset)
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select)).scalars()
|
||||
return out
|
||||
|
||||
async def _query_count(self) -> int:
|
||||
select_q: Select = self._query_args["select"]
|
||||
sub = select_q.options(lazyload("*")).order_by(None).subquery()
|
||||
session: AsyncSession = self._query_args["session"]
|
||||
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
|
||||
return out
|
||||
|
||||
def __init__(self,
|
||||
page: int | None = None,
|
||||
per_page: int | None = None,
|
||||
max_per_page: int | None = 100,
|
||||
error_out: Exception | None = NotFoundError,
|
||||
count: bool = True,
|
||||
**kwargs):
|
||||
## XXX flask-sqlalchemy says Pagination() is not public API.
|
||||
## Things may break; beware.
|
||||
self._query_args = kwargs
|
||||
page, per_page = self._prepare_page_args(
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
max_per_page=max_per_page,
|
||||
error_out=error_out,
|
||||
)
|
||||
|
||||
self.page: int = page
|
||||
"""The current page."""
|
||||
|
||||
self.per_page: int = per_page
|
||||
"""The maximum number of items on a page."""
|
||||
|
||||
self.max_per_page: int | None = max_per_page
|
||||
"""The maximum allowed value for ``per_page``."""
|
||||
|
||||
self.items = None
|
||||
self.total = None
|
||||
self.error_out = error_out
|
||||
self.has_count = count
|
||||
|
||||
async def __aiter__(self):
|
||||
self.items = await self._query_items()
|
||||
if self.items is None:
|
||||
raise RuntimeError('query returned None')
|
||||
if not self.items and self.page != 1 and self.error_out:
|
||||
raise self.error_out
|
||||
if self.has_count:
|
||||
self.total = await self._query_count()
|
||||
for i in self.items:
|
||||
yield i
|
||||
|
||||
|
||||
def async_query(db: SQLAlchemy, multi: False):
|
||||
"""
|
||||
Wraps a query returning function into an executor coroutine.
|
||||
|
||||
The query function remains available as the .q or .query attribute.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def executor(*args, **kwargs):
|
||||
async with db as session:
|
||||
result = await session.execute(func(*args, **kwargs))
|
||||
return result.scalars() if multi else result.scalar()
|
||||
executor.query = executor.q = func
|
||||
return executor
|
||||
return decorator
|
||||
|
||||
SQLAlchemy = deprecated('import from suou.sqlalchemy.asyncio instead')(SQLAlchemy)
|
||||
AsyncSelectPagination = deprecated('import from suou.sqlalchemy.asyncio instead')(AsyncSelectPagination)
|
||||
async_query = deprecated('import from suou.sqlalchemy.asyncio instead')(async_query)
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = ('SQLAlchemy', 'async_query')
|
||||
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||
Loading…
Add table
Add a link
Reference in a new issue