Compare commits

..

No commits in common. "97194b2b8522ee243050466d48d7f0db37a8210a" and "a127c8815930c7f2e92b047a80c50b43fc63144e" have entirely different histories.

7 changed files with 315 additions and 404 deletions

View file

@ -1,13 +1,5 @@
# Changelog
## 0.6.0
...
## 0.5.3
...
## 0.5.2
- Fixed poorly handled merge conflict leaving `.sqlalchemy` modulem unusable

View file

@ -34,7 +34,7 @@ from .validators import matches
from .redact import redact_url_password
from .http import WantsContentType
__version__ = "0.6.0-dev35"
__version__ = "0.5.2"
__all__ = (
'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue',

View file

@ -17,7 +17,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from __future__ import annotations
from abc import abstractmethod
from types import EllipsisType
from typing import Any, Callable, Generic, Iterable, Mapping, TypeVar
import logging
@ -25,10 +24,7 @@ _T = TypeVar('_T')
logger = logging.getLogger(__name__)
class MissingType(object):
__slots__ = ()
MISSING = MissingType()
MISSING = object()
def _not_missing(v) -> bool:
return v and v is not MISSING
@ -47,10 +43,10 @@ class Wanted(Generic[_T]):
Owner class will call .__set_name__() on the parent Incomplete instance;
the __set_name__ parameters (owner class and name) will be passed down here.
"""
_target: Callable | str | None | EllipsisType
def __init__(self, getter: Callable | str | None | EllipsisType):
_target: Callable | str | None | Ellipsis
def __init__(self, getter: Callable | str | None | Ellipsis):
self._target = getter
def __call__(self, owner: type, name: str | None = None) -> _T | str | None:
def __call__(self, owner: type, name: str | None = None) -> _T:
if self._target is None or self._target is Ellipsis:
return name
elif isinstance(self._target, str):
@ -71,10 +67,10 @@ class Incomplete(Generic[_T]):
Missing arguments must be passed in the appropriate positions
(positional or keyword) as a Wanted() object.
"""
_obj: Callable[..., _T]
_obj = Callable[Any, _T]
_args: Iterable
_kwargs: dict
def __init__(self, obj: Callable[..., _T] | Wanted, *args, **kwargs):
def __init__(self, obj: Callable[Any, _T] | Wanted, *args, **kwargs):
if isinstance(obj, Wanted):
self._obj = lambda x: x
self._args = (obj, )
@ -124,7 +120,7 @@ class ValueSource(Mapping):
class ValueProperty(Generic[_T]):
_name: str | None
_srcs: dict[str, str]
_val: Any | MissingType
_val: Any | MISSING
_default: Any | None
_cast: Callable | None
_required: bool

View file

@ -1,7 +1,5 @@
"""
Utilities for SQLAlchemy; ORM
NEW 0.6.0
Utilities for SQLAlchemy
---
@ -16,38 +14,53 @@ This software is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
from __future__ import annotations
from binascii import Incomplete
from typing import Any, Callable
from abc import ABCMeta, abstractmethod
from functools import wraps
from typing import Callable, Iterable, Never, TypeVar
import warnings
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Column, Date, ForeignKey, LargeBinary, MetaData, SmallInteger, String, text
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, declarative_base as _declarative_base, relationship
from sqlalchemy.types import TypeEngine
from suou.classtools import Wanted
from suou.codecs import StringCase
from suou.iding import Siq, SiqCache, SiqGen, SiqType
from suou.itertools import kwargs_prefix
from suou.snowflake import SnowflakeGen
from suou.sqlalchemy import IdType
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
from .snowflake import SnowflakeGen
from .itertools import kwargs_prefix, makelist
from .signing import HasSigner, UserSigner
from .codecs import StringCase
from .functools import deprecated, not_implemented
from .iding import Siq, SiqGen, SiqType, SiqCache
from .classtools import Incomplete, Wanted
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
_T = TypeVar('_T')
# SIQs are 14 bytes long. Storage is padded for alignment
# Not to be confused with SiqType.
IdType: type[LargeBinary] = LargeBinary(16)
@not_implemented
def sql_escape(s: str, /, dialect: Dialect) -> str:
"""
Return a table's column given its name.
Escape a value for SQL embedding, using SQLAlchemy's literal processors.
Requires a dialect argument.
XXX does it belong outside any scopes?
XXX this function is not mature yet, do not use
"""
if isinstance(col, Incomplete):
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
elif isinstance(col, Column):
return col
elif isinstance(col, str):
return getattr(cls, col)
else:
raise TypeError
if isinstance(s, str):
return String().literal_processor(dialect=dialect)(s)
raise TypeError('invalid data type')
def create_session(url: str) -> Session:
"""
Create a session on the fly, given a database URL. Useful for
contextless environments, such as Python REPL.
Heads up: a function with the same name exists in core sqlalchemy, but behaves
completely differently!!
"""
engine = create_engine(url)
return Session(bind = engine)
def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs):
"""
Marks a column which contains a SIQ.
@ -106,6 +119,7 @@ def match_column(length: int, regex: str, /, case: StringCase = StringCase.AS_IS
return Incomplete(Column, String(length), Wanted(lambda x, n: match_constraint(n, regex, #dialect=x.metadata.engine.dialect.name,
constraint_name=constraint_name or f'{x.__tablename__}_{n}_valid')), *args, **kwargs)
def bool_column(value: bool = False, nullable: bool = False, **kwargs) -> Column[bool]:
"""
Column for a single boolean value.
@ -134,11 +148,33 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No
)
Base = _declarative_base(metadata=MetaData(**metadata), **kwargs)
return Base
entity_base = warnings.deprecated('use declarative_base() instead')(declarative_base)
entity_base = deprecated('use declarative_base() instead')(declarative_base)
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
"""
Generate a user signing function.
def author_pair(fk_name: str, *, id_type: type | TypeEngine = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]:
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
and a user secret.
"""
if isinstance(id_attr, Column):
id_val = id_attr
elif isinstance(id_attr, str):
id_val = Wanted(id_attr)
if isinstance(secret_attr, Column):
secret_val = secret_attr
elif isinstance(secret_attr, str):
secret_val = Wanted(secret_attr)
def token_signer_factory(owner: DeclarativeBase, name: str):
def my_signer(self):
return UserSigner(owner.metadata.info['secret_key'], id_val.__get__(self, owner), secret_val.__get__(self, owner))
my_signer.__name__ = name
return my_signer
return Incomplete(Wanted(token_signer_factory))
def author_pair(fk_name: str, *, id_type: type = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]:
"""
Return an owner ID/signature column pair, for authenticated values.
"""
@ -167,8 +203,7 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]:
return (date_col, acc_col)
def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]:
def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship], Incomplete[Relationship]]:
"""
Self-referential one-to-many relationship pair.
Parent comes first, children come later.
@ -185,8 +220,8 @@ def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Inco
parent_kwargs = kwargs_prefix(kwargs, 'parent_')
child_kwargs = kwargs_prefix(kwargs, 'child_')
parent: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs)
child: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs)
parent = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs)
child = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs)
return parent, child
@ -232,3 +267,93 @@ def bound_fk(target: str | Column | InstrumentedAttribute, typ: _T = None, **kwa
return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs)
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
"""
Return a table's column given its name.
XXX does it belong outside any scopes?
"""
if isinstance(col, Incomplete):
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
elif isinstance(col, Column):
return col
elif isinstance(col, str):
return getattr(cls, col)
else:
raise TypeError
## Utilities for use in web apps below
class AuthSrc(metaclass=ABCMeta):
'''
AuthSrc object required for require_auth_base().
This is an abstract class and is NOT usable directly.
This is not part of the public API
'''
def required_exc(self) -> Never:
raise ValueError('required field missing')
def invalid_exc(self, msg: str = 'validation failed') -> Never:
raise ValueError(msg)
@abstractmethod
def get_session(self) -> Session:
pass
def get_user(self, getter: Callable):
return getter(self.get_token())
@abstractmethod
def get_token(self):
pass
@abstractmethod
def get_signature(self):
pass
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
'''
Inject the current user into a view, given the Authorization: Bearer header.
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
'''
col = want_column(cls, column)
validators = makelist(validators)
def get_user(token) -> DeclarativeBase:
if token is None:
return None
tok_parts = UserSigner.split_token(token)
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
try:
signer: UserSigner = user.signer()
signer.unsign(token)
return user
except Exception:
return None
def _default_invalid(msg: str = 'Validation failed'):
raise ValueError(msg)
invalid_exc = src.invalid_exc or _default_invalid
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
def decorator(func: Callable):
@wraps(func)
def wrapper(*a, **ka):
ka[dest] = get_user(src.get_token())
if not ka[dest] and required:
required_exc()
if signed:
ka[sig_dest] = src.get_signature()
for valid in validators:
if not valid(ka[dest]):
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
return func(*a, **ka)
return wrapper
return decorator
# Optional dependency: do not import into __init__.py
__all__ = (
'IdType', 'id_column', 'entity_base', 'declarative_base', 'token_signer', 'match_column', 'match_constraint',
'author_pair', 'age_pair', 'require_auth_base', 'bound_fk', 'unbound_fk', 'want_column'
)

View file

@ -1,169 +0,0 @@
"""
Utilities for SQLAlchemy
---
Copyright (c) 2025 Sakuragasaki46.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See LICENSE for the specific language governing permissions and
limitations under the License.
This software is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from functools import wraps
from typing import Any, Callable, Iterable, Never, TypeVar
import warnings
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
from sqlalchemy.types import TypeEngine
from ..snowflake import SnowflakeGen
from ..itertools import kwargs_prefix, makelist
from ..signing import HasSigner, UserSigner
from ..codecs import StringCase
from ..functools import deprecated, not_implemented
from ..iding import Siq, SiqGen, SiqType, SiqCache
from ..classtools import Incomplete, Wanted
_T = TypeVar('_T')
# SIQs are 14 bytes long. Storage is padded for alignment
# Not to be confused with SiqType.
IdType: TypeEngine = LargeBinary(16)
def create_session(url: str) -> Session:
"""
Create a session on the fly, given a database URL. Useful for
contextless environments, such as Python REPL.
Heads up: a function with the same name exists in core sqlalchemy, but behaves
completely differently!!
"""
engine = create_engine(url)
return Session(bind = engine)
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
"""
Generate a user signing function.
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
and a user secret.
"""
id_val: Column | Wanted[Column]
if isinstance(id_attr, Column):
id_val = id_attr
elif isinstance(id_attr, str):
id_val = Wanted(id_attr)
if isinstance(secret_attr, Column):
secret_val = secret_attr
elif isinstance(secret_attr, str):
secret_val = Wanted(secret_attr)
def token_signer_factory(owner: DeclarativeBase, name: str):
def my_signer(self):
return UserSigner(
owner.metadata.info['secret_key'],
id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue]
)
my_signer.__name__ = name
return my_signer
return Incomplete(Wanted(token_signer_factory))
## Utilities for use in web apps below
@deprecated('not part of the public API and not even working')
class AuthSrc(metaclass=ABCMeta):
'''
AuthSrc object required for require_auth_base().
This is an abstract class and is NOT usable directly.
This is not part of the public API
'''
def required_exc(self) -> Never:
raise ValueError('required field missing')
def invalid_exc(self, msg: str = 'validation failed') -> Never:
raise ValueError(msg)
@abstractmethod
def get_session(self) -> Session:
pass
def get_user(self, getter: Callable):
return getter(self.get_token())
@abstractmethod
def get_token(self):
pass
@abstractmethod
def get_signature(self):
pass
@deprecated('not working and too complex to use')
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
'''
Inject the current user into a view, given the Authorization: Bearer header.
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
'''
col = want_column(cls, column)
validators = makelist(validators)
def get_user(token) -> DeclarativeBase:
if token is None:
return None
tok_parts = UserSigner.split_token(token)
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
try:
signer: UserSigner = user.signer()
signer.unsign(token)
return user
except Exception:
return None
def _default_invalid(msg: str = 'Validation failed'):
raise ValueError(msg)
invalid_exc = src.invalid_exc or _default_invalid
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
def decorator(func: Callable):
@wraps(func)
def wrapper(*a, **ka):
ka[dest] = get_user(src.get_token())
if not ka[dest] and required:
required_exc()
if signed:
ka[sig_dest] = src.get_signature()
for valid in validators:
if not valid(ka[dest]):
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
return func(*a, **ka)
return wrapper
return decorator
from .asyncio import SQLAlchemy, AsyncSelectPagination, async_query
from .orm import id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, author_pair, age_pair, bound_fk, unbound_fk, want_column
# Optional dependency: do not import into __init__.py
__all__ = (
'IdType', 'id_column', 'snowflake_column', 'entity_base', 'declarative_base', 'token_signer',
'match_column', 'match_constraint', 'bool_column', 'parent_children',
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
# .asyncio
'SQLAlchemy', 'AsyncSelectPagination', 'async_query'
)

View file

@ -1,176 +0,0 @@
"""
Helpers for asynchronous use of SQLAlchemy.
NEW 0.5.0; moved to current location 0.6.0
---
Copyright (c) 2025 Sakuragasaki46.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See LICENSE for the specific language governing permissions and
limitations under the License.
This software is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
from __future__ import annotations
from functools import wraps
from sqlalchemy import Engine, Select, func, select
from sqlalchemy.orm import DeclarativeBase, lazyload
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from flask_sqlalchemy.pagination import Pagination
from suou.exceptions import NotFoundError
class SQLAlchemy:
"""
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
eligible for async environments
NEW 0.5.0
"""
base: DeclarativeBase
engine: AsyncEngine
_sessions: list[AsyncSession]
NotFound = NotFoundError
def __init__(self, model_class: DeclarativeBase):
self.base = model_class
self.engine = None
self._sessions = []
def bind(self, url: str):
self.engine = create_async_engine(url)
def _ensure_engine(self):
if self.engine is None:
raise RuntimeError('database is not connected')
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
self._ensure_engine()
## XXX is it accurate?
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
self._sessions.append(s)
return s
async def __aenter__(self) -> AsyncSession:
return await self.begin()
async def __aexit__(self, e1, e2, e3):
## XXX is it accurate?
s = self._sessions.pop()
if e1:
await s.rollback()
else:
await s.commit()
await s.close()
async def paginate(self, select: Select, *,
page: int | None = None, per_page: int | None = None,
max_per_page: int | None = None, error_out: bool = True,
count: bool = True) -> AsyncSelectPagination:
"""
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
"""
async with self as session:
return AsyncSelectPagination(
select = select,
session = session,
page = page,
per_page=per_page, max_per_page=max_per_page,
error_out=self.NotFound if error_out else None, count=count
)
async def create_all(self, *, checkfirst = True):
"""
Initialize database
"""
self._ensure_engine()
self.base.metadata.create_all(
self.engine, checkfirst=checkfirst
)
class AsyncSelectPagination(Pagination):
"""
flask_sqlalchemy.SelectPagination but asynchronous.
Pagination is not part of the public API, therefore expect that it may break
"""
async def _query_items(self) -> list:
select_q: Select = self._query_args["select"]
select = select_q.limit(self.per_page).offset(self._query_offset)
session: AsyncSession = self._query_args["session"]
out = (await session.execute(select)).scalars()
return out
async def _query_count(self) -> int:
select_q: Select = self._query_args["select"]
sub = select_q.options(lazyload("*")).order_by(None).subquery()
session: AsyncSession = self._query_args["session"]
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
return out
def __init__(self,
page: int | None = None,
per_page: int | None = None,
max_per_page: int | None = 100,
error_out: Exception | None = NotFoundError,
count: bool = True,
**kwargs):
## XXX flask-sqlalchemy says Pagination() is not public API.
## Things may break; beware.
self._query_args = kwargs
page, per_page = self._prepare_page_args(
page=page,
per_page=per_page,
max_per_page=max_per_page,
error_out=error_out,
)
self.page: int = page
"""The current page."""
self.per_page: int = per_page
"""The maximum number of items on a page."""
self.max_per_page: int | None = max_per_page
"""The maximum allowed value for ``per_page``."""
self.items = None
self.total = None
self.error_out = error_out
self.has_count = count
async def __aiter__(self):
self.items = await self._query_items()
if self.items is None:
raise RuntimeError('query returned None')
if not self.items and self.page != 1 and self.error_out:
raise self.error_out
if self.has_count:
self.total = await self._query_count()
for i in self.items:
yield i
def async_query(db: SQLAlchemy, multi: False):
"""
Wraps a query returning function into an executor coroutine.
The query function remains available as the .q or .query attribute.
"""
def decorator(func):
@wraps(func)
async def executor(*args, **kwargs):
async with db as session:
result = await session.execute(func(*args, **kwargs))
return result.scalars() if multi else result.scalar()
executor.query = executor.q = func
return executor
return decorator
# Optional dependency: do not import into __init__.py
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')

View file

@ -1,7 +1,7 @@
"""
Helpers for asynchronous use of SQLAlchemy.
NEW 0.5.0; MOVED to sqlalchemy.asyncio in 0.6.0
NEW 0.5.0
---
@ -17,16 +17,159 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
from __future__ import annotations
from functools import wraps
from .functools import deprecated
from sqlalchemy import Engine, Select, func, select
from sqlalchemy.orm import DeclarativeBase, lazyload
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from flask_sqlalchemy.pagination import Pagination
from suou.exceptions import NotFoundError
class SQLAlchemy:
"""
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
eligible for async environments
NEW 0.5.0
"""
base: DeclarativeBase
engine: AsyncEngine
_sessions: list[AsyncSession]
NotFound = NotFoundError
def __init__(self, model_class: DeclarativeBase):
self.base = model_class
self.engine = None
self._sessions = []
def bind(self, url: str):
self.engine = create_async_engine(url)
def _ensure_engine(self):
if self.engine is None:
raise RuntimeError('database is not connected')
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
self._ensure_engine()
## XXX is it accurate?
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
self._sessions.append(s)
return s
async def __aenter__(self) -> AsyncSession:
return await self.begin()
async def __aexit__(self, e1, e2, e3):
## XXX is it accurate?
s = self._sessions.pop()
if e1:
await s.rollback()
else:
await s.commit()
await s.close()
async def paginate(self, select: Select, *,
page: int | None = None, per_page: int | None = None,
max_per_page: int | None = None, error_out: bool = True,
count: bool = True) -> AsyncSelectPagination:
"""
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
"""
async with self as session:
return AsyncSelectPagination(
select = select,
session = session,
page = page,
per_page=per_page, max_per_page=max_per_page,
error_out=self.NotFound if error_out else None, count=count
)
async def create_all(self, *, checkfirst = True):
"""
Initialize database
"""
self._ensure_engine()
self.base.metadata.create_all(
self.engine, checkfirst=checkfirst
)
from .sqlalchemy.asyncio import SQLAlchemy, AsyncSelectPagination, async_query
class AsyncSelectPagination(Pagination):
"""
flask_sqlalchemy.SelectPagination but asynchronous.
Pagination is not part of the public API, therefore expect that it may break
"""
async def _query_items(self) -> list:
select_q: Select = self._query_args["select"]
select = select_q.limit(self.per_page).offset(self._query_offset)
session: AsyncSession = self._query_args["session"]
out = (await session.execute(select)).scalars()
return out
async def _query_count(self) -> int:
select_q: Select = self._query_args["select"]
sub = select_q.options(lazyload("*")).order_by(None).subquery()
session: AsyncSession = self._query_args["session"]
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
return out
def __init__(self,
page: int | None = None,
per_page: int | None = None,
max_per_page: int | None = 100,
error_out: Exception | None = NotFoundError,
count: bool = True,
**kwargs):
## XXX flask-sqlalchemy says Pagination() is not public API.
## Things may break; beware.
self._query_args = kwargs
page, per_page = self._prepare_page_args(
page=page,
per_page=per_page,
max_per_page=max_per_page,
error_out=error_out,
)
self.page: int = page
"""The current page."""
self.per_page: int = per_page
"""The maximum number of items on a page."""
self.max_per_page: int | None = max_per_page
"""The maximum allowed value for ``per_page``."""
self.items = None
self.total = None
self.error_out = error_out
self.has_count = count
async def __aiter__(self):
self.items = await self._query_items()
if self.items is None:
raise RuntimeError('query returned None')
if not self.items and self.page != 1 and self.error_out:
raise self.error_out
if self.has_count:
self.total = await self._query_count()
for i in self.items:
yield i
def async_query(db: SQLAlchemy, multi: False):
"""
Wraps a query returning function into an executor coroutine.
The query function remains available as the .q or .query attribute.
"""
def decorator(func):
@wraps(func)
async def executor(*args, **kwargs):
async with db as session:
result = await session.execute(func(*args, **kwargs))
return result.scalars() if multi else result.scalar()
executor.query = executor.q = func
return executor
return decorator
SQLAlchemy = deprecated('import from suou.sqlalchemy.asyncio instead')(SQLAlchemy)
AsyncSelectPagination = deprecated('import from suou.sqlalchemy.asyncio instead')(AsyncSelectPagination)
async_query = deprecated('import from suou.sqlalchemy.asyncio instead')(async_query)
# Optional dependency: do not import into __init__.py
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
__all__ = ('SQLAlchemy', 'async_query')