diff --git a/CHANGELOG.md b/CHANGELOG.md index fffb0ea..3e60ffb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,5 @@ # Changelog -## 0.6.0 - -... - -## 0.5.3 - -... - ## 0.5.2 - Fixed poorly handled merge conflict leaving `.sqlalchemy` modulem unusable diff --git a/src/suou/__init__.py b/src/suou/__init__.py index 96db617..59b3aa0 100644 --- a/src/suou/__init__.py +++ b/src/suou/__init__.py @@ -34,7 +34,7 @@ from .validators import matches from .redact import redact_url_password from .http import WantsContentType -__version__ = "0.6.0-dev35" +__version__ = "0.5.2" __all__ = ( 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', diff --git a/src/suou/classtools.py b/src/suou/classtools.py index c58a123..c27fa61 100644 --- a/src/suou/classtools.py +++ b/src/suou/classtools.py @@ -17,7 +17,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from __future__ import annotations from abc import abstractmethod -from types import EllipsisType from typing import Any, Callable, Generic, Iterable, Mapping, TypeVar import logging @@ -25,10 +24,7 @@ _T = TypeVar('_T') logger = logging.getLogger(__name__) -class MissingType(object): - __slots__ = () - -MISSING = MissingType() +MISSING = object() def _not_missing(v) -> bool: return v and v is not MISSING @@ -47,10 +43,10 @@ class Wanted(Generic[_T]): Owner class will call .__set_name__() on the parent Incomplete instance; the __set_name__ parameters (owner class and name) will be passed down here. """ - _target: Callable | str | None | EllipsisType - def __init__(self, getter: Callable | str | None | EllipsisType): + _target: Callable | str | None | Ellipsis + def __init__(self, getter: Callable | str | None | Ellipsis): self._target = getter - def __call__(self, owner: type, name: str | None = None) -> _T | str | None: + def __call__(self, owner: type, name: str | None = None) -> _T: if self._target is None or self._target is Ellipsis: return name elif isinstance(self._target, str): @@ -71,10 +67,10 @@ class Incomplete(Generic[_T]): Missing arguments must be passed in the appropriate positions (positional or keyword) as a Wanted() object. """ - _obj: Callable[..., _T] + _obj = Callable[Any, _T] _args: Iterable _kwargs: dict - def __init__(self, obj: Callable[..., _T] | Wanted, *args, **kwargs): + def __init__(self, obj: Callable[Any, _T] | Wanted, *args, **kwargs): if isinstance(obj, Wanted): self._obj = lambda x: x self._args = (obj, ) @@ -124,7 +120,7 @@ class ValueSource(Mapping): class ValueProperty(Generic[_T]): _name: str | None _srcs: dict[str, str] - _val: Any | MissingType + _val: Any | MISSING _default: Any | None _cast: Callable | None _required: bool diff --git a/src/suou/sqlalchemy/orm.py b/src/suou/sqlalchemy.py similarity index 57% rename from src/suou/sqlalchemy/orm.py rename to src/suou/sqlalchemy.py index 6d75a4f..2f434e2 100644 --- a/src/suou/sqlalchemy/orm.py +++ b/src/suou/sqlalchemy.py @@ -1,7 +1,5 @@ """ -Utilities for SQLAlchemy; ORM - -NEW 0.6.0 +Utilities for SQLAlchemy --- @@ -16,38 +14,53 @@ This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ +from __future__ import annotations - -from binascii import Incomplete -from typing import Any, Callable +from abc import ABCMeta, abstractmethod +from functools import wraps +from typing import Callable, Iterable, Never, TypeVar import warnings -from sqlalchemy import BigInteger, Boolean, CheckConstraint, Column, Date, ForeignKey, LargeBinary, MetaData, SmallInteger, String, text -from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, declarative_base as _declarative_base, relationship -from sqlalchemy.types import TypeEngine -from suou.classtools import Wanted -from suou.codecs import StringCase -from suou.iding import Siq, SiqCache, SiqGen, SiqType -from suou.itertools import kwargs_prefix -from suou.snowflake import SnowflakeGen -from suou.sqlalchemy import IdType +from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text +from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship +from .snowflake import SnowflakeGen +from .itertools import kwargs_prefix, makelist +from .signing import HasSigner, UserSigner +from .codecs import StringCase +from .functools import deprecated, not_implemented +from .iding import Siq, SiqGen, SiqType, SiqCache +from .classtools import Incomplete, Wanted -def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]: +_T = TypeVar('_T') + +# SIQs are 14 bytes long. Storage is padded for alignment +# Not to be confused with SiqType. +IdType: type[LargeBinary] = LargeBinary(16) + +@not_implemented +def sql_escape(s: str, /, dialect: Dialect) -> str: """ - Return a table's column given its name. + Escape a value for SQL embedding, using SQLAlchemy's literal processors. + Requires a dialect argument. - XXX does it belong outside any scopes? + XXX this function is not mature yet, do not use """ - if isinstance(col, Incomplete): - raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.') - elif isinstance(col, Column): - return col - elif isinstance(col, str): - return getattr(cls, col) - else: - raise TypeError + if isinstance(s, str): + return String().literal_processor(dialect=dialect)(s) + raise TypeError('invalid data type') +def create_session(url: str) -> Session: + """ + Create a session on the fly, given a database URL. Useful for + contextless environments, such as Python REPL. + + Heads up: a function with the same name exists in core sqlalchemy, but behaves + completely differently!! + """ + engine = create_engine(url) + return Session(bind = engine) + def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs): """ Marks a column which contains a SIQ. @@ -106,6 +119,7 @@ def match_column(length: int, regex: str, /, case: StringCase = StringCase.AS_IS return Incomplete(Column, String(length), Wanted(lambda x, n: match_constraint(n, regex, #dialect=x.metadata.engine.dialect.name, constraint_name=constraint_name or f'{x.__tablename__}_{n}_valid')), *args, **kwargs) + def bool_column(value: bool = False, nullable: bool = False, **kwargs) -> Column[bool]: """ Column for a single boolean value. @@ -134,11 +148,33 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No ) Base = _declarative_base(metadata=MetaData(**metadata), **kwargs) return Base -entity_base = warnings.deprecated('use declarative_base() instead')(declarative_base) +entity_base = deprecated('use declarative_base() instead')(declarative_base) +def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]: + """ + Generate a user signing function. -def author_pair(fk_name: str, *, id_type: type | TypeEngine = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]: + Requires a master secret (taken from Base.metadata), a user id (visible in the token) + and a user secret. + """ + if isinstance(id_attr, Column): + id_val = id_attr + elif isinstance(id_attr, str): + id_val = Wanted(id_attr) + if isinstance(secret_attr, Column): + secret_val = secret_attr + elif isinstance(secret_attr, str): + secret_val = Wanted(secret_attr) + def token_signer_factory(owner: DeclarativeBase, name: str): + def my_signer(self): + return UserSigner(owner.metadata.info['secret_key'], id_val.__get__(self, owner), secret_val.__get__(self, owner)) + my_signer.__name__ = name + return my_signer + return Incomplete(Wanted(token_signer_factory)) + + +def author_pair(fk_name: str, *, id_type: type = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]: """ Return an owner ID/signature column pair, for authenticated values. """ @@ -167,8 +203,7 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]: return (date_col, acc_col) - -def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]: +def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship], Incomplete[Relationship]]: """ Self-referential one-to-many relationship pair. Parent comes first, children come later. @@ -185,8 +220,8 @@ def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Inco parent_kwargs = kwargs_prefix(kwargs, 'parent_') child_kwargs = kwargs_prefix(kwargs, 'child_') - parent: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs) - child: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs) + parent = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs) + child = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs) return parent, child @@ -232,3 +267,93 @@ def bound_fk(target: str | Column | InstrumentedAttribute, typ: _T = None, **kwa return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs) +def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]: + """ + Return a table's column given its name. + + XXX does it belong outside any scopes? + """ + if isinstance(col, Incomplete): + raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.') + elif isinstance(col, Column): + return col + elif isinstance(col, str): + return getattr(cls, col) + else: + raise TypeError + +## Utilities for use in web apps below + +class AuthSrc(metaclass=ABCMeta): + ''' + AuthSrc object required for require_auth_base(). + + This is an abstract class and is NOT usable directly. + + This is not part of the public API + ''' + def required_exc(self) -> Never: + raise ValueError('required field missing') + def invalid_exc(self, msg: str = 'validation failed') -> Never: + raise ValueError(msg) + @abstractmethod + def get_session(self) -> Session: + pass + def get_user(self, getter: Callable): + return getter(self.get_token()) + @abstractmethod + def get_token(self): + pass + @abstractmethod + def get_signature(self): + pass + + +def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', + required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): + ''' + Inject the current user into a view, given the Authorization: Bearer header. + + For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object. + ''' + col = want_column(cls, column) + validators = makelist(validators) + + def get_user(token) -> DeclarativeBase: + if token is None: + return None + tok_parts = UserSigner.split_token(token) + user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar() + try: + signer: UserSigner = user.signer() + signer.unsign(token) + return user + except Exception: + return None + + def _default_invalid(msg: str = 'Validation failed'): + raise ValueError(msg) + + invalid_exc = src.invalid_exc or _default_invalid + required_exc = src.required_exc or (lambda: _default_invalid('Login required')) + + def decorator(func: Callable): + @wraps(func) + def wrapper(*a, **ka): + ka[dest] = get_user(src.get_token()) + if not ka[dest] and required: + required_exc() + if signed: + ka[sig_dest] = src.get_signature() + for valid in validators: + if not valid(ka[dest]): + invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest])) + return func(*a, **ka) + return wrapper + return decorator + +# Optional dependency: do not import into __init__.py +__all__ = ( + 'IdType', 'id_column', 'entity_base', 'declarative_base', 'token_signer', 'match_column', 'match_constraint', + 'author_pair', 'age_pair', 'require_auth_base', 'bound_fk', 'unbound_fk', 'want_column' +) \ No newline at end of file diff --git a/src/suou/sqlalchemy/__init__.py b/src/suou/sqlalchemy/__init__.py deleted file mode 100644 index 81f61f8..0000000 --- a/src/suou/sqlalchemy/__init__.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Utilities for SQLAlchemy - ---- - -Copyright (c) 2025 Sakuragasaki46. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -See LICENSE for the specific language governing permissions and -limitations under the License. - -This software is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -""" - -from __future__ import annotations - -from abc import ABCMeta, abstractmethod -from functools import wraps -from typing import Any, Callable, Iterable, Never, TypeVar -import warnings -from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text -from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship -from sqlalchemy.types import TypeEngine - -from ..snowflake import SnowflakeGen -from ..itertools import kwargs_prefix, makelist -from ..signing import HasSigner, UserSigner -from ..codecs import StringCase -from ..functools import deprecated, not_implemented -from ..iding import Siq, SiqGen, SiqType, SiqCache -from ..classtools import Incomplete, Wanted - - - -_T = TypeVar('_T') - -# SIQs are 14 bytes long. Storage is padded for alignment -# Not to be confused with SiqType. -IdType: TypeEngine = LargeBinary(16) - -def create_session(url: str) -> Session: - """ - Create a session on the fly, given a database URL. Useful for - contextless environments, such as Python REPL. - - Heads up: a function with the same name exists in core sqlalchemy, but behaves - completely differently!! - """ - engine = create_engine(url) - return Session(bind = engine) - - - -def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]: - """ - Generate a user signing function. - - Requires a master secret (taken from Base.metadata), a user id (visible in the token) - and a user secret. - """ - id_val: Column | Wanted[Column] - if isinstance(id_attr, Column): - id_val = id_attr - elif isinstance(id_attr, str): - id_val = Wanted(id_attr) - if isinstance(secret_attr, Column): - secret_val = secret_attr - elif isinstance(secret_attr, str): - secret_val = Wanted(secret_attr) - def token_signer_factory(owner: DeclarativeBase, name: str): - def my_signer(self): - return UserSigner( - owner.metadata.info['secret_key'], - id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue] - ) - my_signer.__name__ = name - return my_signer - return Incomplete(Wanted(token_signer_factory)) - - - - - -## Utilities for use in web apps below - -@deprecated('not part of the public API and not even working') -class AuthSrc(metaclass=ABCMeta): - ''' - AuthSrc object required for require_auth_base(). - - This is an abstract class and is NOT usable directly. - - This is not part of the public API - ''' - def required_exc(self) -> Never: - raise ValueError('required field missing') - def invalid_exc(self, msg: str = 'validation failed') -> Never: - raise ValueError(msg) - @abstractmethod - def get_session(self) -> Session: - pass - def get_user(self, getter: Callable): - return getter(self.get_token()) - @abstractmethod - def get_token(self): - pass - @abstractmethod - def get_signature(self): - pass - - -@deprecated('not working and too complex to use') -def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', - required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): - ''' - Inject the current user into a view, given the Authorization: Bearer header. - - For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object. - ''' - col = want_column(cls, column) - validators = makelist(validators) - - def get_user(token) -> DeclarativeBase: - if token is None: - return None - tok_parts = UserSigner.split_token(token) - user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar() - try: - signer: UserSigner = user.signer() - signer.unsign(token) - return user - except Exception: - return None - - def _default_invalid(msg: str = 'Validation failed'): - raise ValueError(msg) - - invalid_exc = src.invalid_exc or _default_invalid - required_exc = src.required_exc or (lambda: _default_invalid('Login required')) - - def decorator(func: Callable): - @wraps(func) - def wrapper(*a, **ka): - ka[dest] = get_user(src.get_token()) - if not ka[dest] and required: - required_exc() - if signed: - ka[sig_dest] = src.get_signature() - for valid in validators: - if not valid(ka[dest]): - invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest])) - return func(*a, **ka) - return wrapper - return decorator - - -from .asyncio import SQLAlchemy, AsyncSelectPagination, async_query -from .orm import id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, author_pair, age_pair, bound_fk, unbound_fk, want_column - -# Optional dependency: do not import into __init__.py -__all__ = ( - 'IdType', 'id_column', 'snowflake_column', 'entity_base', 'declarative_base', 'token_signer', - 'match_column', 'match_constraint', 'bool_column', 'parent_children', - 'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column', - # .asyncio - 'SQLAlchemy', 'AsyncSelectPagination', 'async_query' -) \ No newline at end of file diff --git a/src/suou/sqlalchemy/asyncio.py b/src/suou/sqlalchemy/asyncio.py deleted file mode 100644 index 786d9f8..0000000 --- a/src/suou/sqlalchemy/asyncio.py +++ /dev/null @@ -1,176 +0,0 @@ - -""" -Helpers for asynchronous use of SQLAlchemy. - -NEW 0.5.0; moved to current location 0.6.0 - ---- - -Copyright (c) 2025 Sakuragasaki46. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -See LICENSE for the specific language governing permissions and -limitations under the License. - -This software is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -""" - -from __future__ import annotations -from functools import wraps - - -from sqlalchemy import Engine, Select, func, select -from sqlalchemy.orm import DeclarativeBase, lazyload -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from flask_sqlalchemy.pagination import Pagination - -from suou.exceptions import NotFoundError - -class SQLAlchemy: - """ - Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy() - eligible for async environments - - NEW 0.5.0 - """ - base: DeclarativeBase - engine: AsyncEngine - _sessions: list[AsyncSession] - NotFound = NotFoundError - - def __init__(self, model_class: DeclarativeBase): - self.base = model_class - self.engine = None - self._sessions = [] - def bind(self, url: str): - self.engine = create_async_engine(url) - def _ensure_engine(self): - if self.engine is None: - raise RuntimeError('database is not connected') - async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession: - self._ensure_engine() - ## XXX is it accurate? - s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw) - self._sessions.append(s) - return s - async def __aenter__(self) -> AsyncSession: - return await self.begin() - async def __aexit__(self, e1, e2, e3): - ## XXX is it accurate? - s = self._sessions.pop() - if e1: - await s.rollback() - else: - await s.commit() - await s.close() - async def paginate(self, select: Select, *, - page: int | None = None, per_page: int | None = None, - max_per_page: int | None = None, error_out: bool = True, - count: bool = True) -> AsyncSelectPagination: - """ - Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate(). - """ - async with self as session: - return AsyncSelectPagination( - select = select, - session = session, - page = page, - per_page=per_page, max_per_page=max_per_page, - error_out=self.NotFound if error_out else None, count=count - ) - async def create_all(self, *, checkfirst = True): - """ - Initialize database - """ - self._ensure_engine() - self.base.metadata.create_all( - self.engine, checkfirst=checkfirst - ) - - - -class AsyncSelectPagination(Pagination): - """ - flask_sqlalchemy.SelectPagination but asynchronous. - - Pagination is not part of the public API, therefore expect that it may break - """ - - async def _query_items(self) -> list: - select_q: Select = self._query_args["select"] - select = select_q.limit(self.per_page).offset(self._query_offset) - session: AsyncSession = self._query_args["session"] - out = (await session.execute(select)).scalars() - return out - - async def _query_count(self) -> int: - select_q: Select = self._query_args["select"] - sub = select_q.options(lazyload("*")).order_by(None).subquery() - session: AsyncSession = self._query_args["session"] - out = (await session.execute(select(func.count()).select_from(sub))).scalar() - return out - - def __init__(self, - page: int | None = None, - per_page: int | None = None, - max_per_page: int | None = 100, - error_out: Exception | None = NotFoundError, - count: bool = True, - **kwargs): - ## XXX flask-sqlalchemy says Pagination() is not public API. - ## Things may break; beware. - self._query_args = kwargs - page, per_page = self._prepare_page_args( - page=page, - per_page=per_page, - max_per_page=max_per_page, - error_out=error_out, - ) - - self.page: int = page - """The current page.""" - - self.per_page: int = per_page - """The maximum number of items on a page.""" - - self.max_per_page: int | None = max_per_page - """The maximum allowed value for ``per_page``.""" - - self.items = None - self.total = None - self.error_out = error_out - self.has_count = count - - async def __aiter__(self): - self.items = await self._query_items() - if self.items is None: - raise RuntimeError('query returned None') - if not self.items and self.page != 1 and self.error_out: - raise self.error_out - if self.has_count: - self.total = await self._query_count() - for i in self.items: - yield i - - -def async_query(db: SQLAlchemy, multi: False): - """ - Wraps a query returning function into an executor coroutine. - - The query function remains available as the .q or .query attribute. - """ - def decorator(func): - @wraps(func) - async def executor(*args, **kwargs): - async with db as session: - result = await session.execute(func(*args, **kwargs)) - return result.scalars() if multi else result.scalar() - executor.query = executor.q = func - return executor - return decorator - - -# Optional dependency: do not import into __init__.py -__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query') \ No newline at end of file diff --git a/src/suou/sqlalchemy_async.py b/src/suou/sqlalchemy_async.py index 47b3396..575f239 100644 --- a/src/suou/sqlalchemy_async.py +++ b/src/suou/sqlalchemy_async.py @@ -1,7 +1,7 @@ """ Helpers for asynchronous use of SQLAlchemy. -NEW 0.5.0; MOVED to sqlalchemy.asyncio in 0.6.0 +NEW 0.5.0 --- @@ -17,16 +17,159 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ from __future__ import annotations +from functools import wraps -from .functools import deprecated + +from sqlalchemy import Engine, Select, func, select +from sqlalchemy.orm import DeclarativeBase, lazyload +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from flask_sqlalchemy.pagination import Pagination + +from suou.exceptions import NotFoundError + +class SQLAlchemy: + """ + Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy() + eligible for async environments + + NEW 0.5.0 + """ + base: DeclarativeBase + engine: AsyncEngine + _sessions: list[AsyncSession] + NotFound = NotFoundError + + def __init__(self, model_class: DeclarativeBase): + self.base = model_class + self.engine = None + self._sessions = [] + def bind(self, url: str): + self.engine = create_async_engine(url) + def _ensure_engine(self): + if self.engine is None: + raise RuntimeError('database is not connected') + async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession: + self._ensure_engine() + ## XXX is it accurate? + s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw) + self._sessions.append(s) + return s + async def __aenter__(self) -> AsyncSession: + return await self.begin() + async def __aexit__(self, e1, e2, e3): + ## XXX is it accurate? + s = self._sessions.pop() + if e1: + await s.rollback() + else: + await s.commit() + await s.close() + async def paginate(self, select: Select, *, + page: int | None = None, per_page: int | None = None, + max_per_page: int | None = None, error_out: bool = True, + count: bool = True) -> AsyncSelectPagination: + """ + Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate(). + """ + async with self as session: + return AsyncSelectPagination( + select = select, + session = session, + page = page, + per_page=per_page, max_per_page=max_per_page, + error_out=self.NotFound if error_out else None, count=count + ) + async def create_all(self, *, checkfirst = True): + """ + Initialize database + """ + self._ensure_engine() + self.base.metadata.create_all( + self.engine, checkfirst=checkfirst + ) -from .sqlalchemy.asyncio import SQLAlchemy, AsyncSelectPagination, async_query +class AsyncSelectPagination(Pagination): + """ + flask_sqlalchemy.SelectPagination but asynchronous. -SQLAlchemy = deprecated('import from suou.sqlalchemy.asyncio instead')(SQLAlchemy) -AsyncSelectPagination = deprecated('import from suou.sqlalchemy.asyncio instead')(AsyncSelectPagination) -async_query = deprecated('import from suou.sqlalchemy.asyncio instead')(async_query) + Pagination is not part of the public API, therefore expect that it may break + """ + + async def _query_items(self) -> list: + select_q: Select = self._query_args["select"] + select = select_q.limit(self.per_page).offset(self._query_offset) + session: AsyncSession = self._query_args["session"] + out = (await session.execute(select)).scalars() + return out + + async def _query_count(self) -> int: + select_q: Select = self._query_args["select"] + sub = select_q.options(lazyload("*")).order_by(None).subquery() + session: AsyncSession = self._query_args["session"] + out = (await session.execute(select(func.count()).select_from(sub))).scalar() + return out + + def __init__(self, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = 100, + error_out: Exception | None = NotFoundError, + count: bool = True, + **kwargs): + ## XXX flask-sqlalchemy says Pagination() is not public API. + ## Things may break; beware. + self._query_args = kwargs + page, per_page = self._prepare_page_args( + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + ) + + self.page: int = page + """The current page.""" + + self.per_page: int = per_page + """The maximum number of items on a page.""" + + self.max_per_page: int | None = max_per_page + """The maximum allowed value for ``per_page``.""" + + self.items = None + self.total = None + self.error_out = error_out + self.has_count = count + + async def __aiter__(self): + self.items = await self._query_items() + if self.items is None: + raise RuntimeError('query returned None') + if not self.items and self.page != 1 and self.error_out: + raise self.error_out + if self.has_count: + self.total = await self._query_count() + for i in self.items: + yield i + + +def async_query(db: SQLAlchemy, multi: False): + """ + Wraps a query returning function into an executor coroutine. + + The query function remains available as the .q or .query attribute. + """ + def decorator(func): + @wraps(func) + async def executor(*args, **kwargs): + async with db as session: + result = await session.execute(func(*args, **kwargs)) + return result.scalars() if multi else result.scalar() + executor.query = executor.q = func + return executor + return decorator + # Optional dependency: do not import into __init__.py -__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query') \ No newline at end of file +__all__ = ('SQLAlchemy', 'async_query') \ No newline at end of file