diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e60ffb..fffb0ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 0.6.0 + +... + +## 0.5.3 + +... + ## 0.5.2 - Fixed poorly handled merge conflict leaving `.sqlalchemy` modulem unusable diff --git a/src/suou/__init__.py b/src/suou/__init__.py index 59b3aa0..96db617 100644 --- a/src/suou/__init__.py +++ b/src/suou/__init__.py @@ -34,7 +34,7 @@ from .validators import matches from .redact import redact_url_password from .http import WantsContentType -__version__ = "0.5.2" +__version__ = "0.6.0-dev35" __all__ = ( 'ConfigOptions', 'ConfigParserConfigSource', 'ConfigSource', 'ConfigValue', diff --git a/src/suou/classtools.py b/src/suou/classtools.py index c27fa61..c58a123 100644 --- a/src/suou/classtools.py +++ b/src/suou/classtools.py @@ -17,6 +17,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from __future__ import annotations from abc import abstractmethod +from types import EllipsisType from typing import Any, Callable, Generic, Iterable, Mapping, TypeVar import logging @@ -24,7 +25,10 @@ _T = TypeVar('_T') logger = logging.getLogger(__name__) -MISSING = object() +class MissingType(object): + __slots__ = () + +MISSING = MissingType() def _not_missing(v) -> bool: return v and v is not MISSING @@ -43,10 +47,10 @@ class Wanted(Generic[_T]): Owner class will call .__set_name__() on the parent Incomplete instance; the __set_name__ parameters (owner class and name) will be passed down here. """ - _target: Callable | str | None | Ellipsis - def __init__(self, getter: Callable | str | None | Ellipsis): + _target: Callable | str | None | EllipsisType + def __init__(self, getter: Callable | str | None | EllipsisType): self._target = getter - def __call__(self, owner: type, name: str | None = None) -> _T: + def __call__(self, owner: type, name: str | None = None) -> _T | str | None: if self._target is None or self._target is Ellipsis: return name elif isinstance(self._target, str): @@ -67,10 +71,10 @@ class Incomplete(Generic[_T]): Missing arguments must be passed in the appropriate positions (positional or keyword) as a Wanted() object. """ - _obj = Callable[Any, _T] + _obj: Callable[..., _T] _args: Iterable _kwargs: dict - def __init__(self, obj: Callable[Any, _T] | Wanted, *args, **kwargs): + def __init__(self, obj: Callable[..., _T] | Wanted, *args, **kwargs): if isinstance(obj, Wanted): self._obj = lambda x: x self._args = (obj, ) @@ -120,7 +124,7 @@ class ValueSource(Mapping): class ValueProperty(Generic[_T]): _name: str | None _srcs: dict[str, str] - _val: Any | MISSING + _val: Any | MissingType _default: Any | None _cast: Callable | None _required: bool diff --git a/src/suou/sqlalchemy/__init__.py b/src/suou/sqlalchemy/__init__.py new file mode 100644 index 0000000..81f61f8 --- /dev/null +++ b/src/suou/sqlalchemy/__init__.py @@ -0,0 +1,169 @@ +""" +Utilities for SQLAlchemy + +--- + +Copyright (c) 2025 Sakuragasaki46. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +See LICENSE for the specific language governing permissions and +limitations under the License. + +This software is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from functools import wraps +from typing import Any, Callable, Iterable, Never, TypeVar +import warnings +from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text +from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship +from sqlalchemy.types import TypeEngine + +from ..snowflake import SnowflakeGen +from ..itertools import kwargs_prefix, makelist +from ..signing import HasSigner, UserSigner +from ..codecs import StringCase +from ..functools import deprecated, not_implemented +from ..iding import Siq, SiqGen, SiqType, SiqCache +from ..classtools import Incomplete, Wanted + + + +_T = TypeVar('_T') + +# SIQs are 14 bytes long. Storage is padded for alignment +# Not to be confused with SiqType. +IdType: TypeEngine = LargeBinary(16) + +def create_session(url: str) -> Session: + """ + Create a session on the fly, given a database URL. Useful for + contextless environments, such as Python REPL. + + Heads up: a function with the same name exists in core sqlalchemy, but behaves + completely differently!! + """ + engine = create_engine(url) + return Session(bind = engine) + + + +def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]: + """ + Generate a user signing function. + + Requires a master secret (taken from Base.metadata), a user id (visible in the token) + and a user secret. + """ + id_val: Column | Wanted[Column] + if isinstance(id_attr, Column): + id_val = id_attr + elif isinstance(id_attr, str): + id_val = Wanted(id_attr) + if isinstance(secret_attr, Column): + secret_val = secret_attr + elif isinstance(secret_attr, str): + secret_val = Wanted(secret_attr) + def token_signer_factory(owner: DeclarativeBase, name: str): + def my_signer(self): + return UserSigner( + owner.metadata.info['secret_key'], + id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue] + ) + my_signer.__name__ = name + return my_signer + return Incomplete(Wanted(token_signer_factory)) + + + + + +## Utilities for use in web apps below + +@deprecated('not part of the public API and not even working') +class AuthSrc(metaclass=ABCMeta): + ''' + AuthSrc object required for require_auth_base(). + + This is an abstract class and is NOT usable directly. + + This is not part of the public API + ''' + def required_exc(self) -> Never: + raise ValueError('required field missing') + def invalid_exc(self, msg: str = 'validation failed') -> Never: + raise ValueError(msg) + @abstractmethod + def get_session(self) -> Session: + pass + def get_user(self, getter: Callable): + return getter(self.get_token()) + @abstractmethod + def get_token(self): + pass + @abstractmethod + def get_signature(self): + pass + + +@deprecated('not working and too complex to use') +def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', + required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): + ''' + Inject the current user into a view, given the Authorization: Bearer header. + + For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object. + ''' + col = want_column(cls, column) + validators = makelist(validators) + + def get_user(token) -> DeclarativeBase: + if token is None: + return None + tok_parts = UserSigner.split_token(token) + user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar() + try: + signer: UserSigner = user.signer() + signer.unsign(token) + return user + except Exception: + return None + + def _default_invalid(msg: str = 'Validation failed'): + raise ValueError(msg) + + invalid_exc = src.invalid_exc or _default_invalid + required_exc = src.required_exc or (lambda: _default_invalid('Login required')) + + def decorator(func: Callable): + @wraps(func) + def wrapper(*a, **ka): + ka[dest] = get_user(src.get_token()) + if not ka[dest] and required: + required_exc() + if signed: + ka[sig_dest] = src.get_signature() + for valid in validators: + if not valid(ka[dest]): + invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest])) + return func(*a, **ka) + return wrapper + return decorator + + +from .asyncio import SQLAlchemy, AsyncSelectPagination, async_query +from .orm import id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, author_pair, age_pair, bound_fk, unbound_fk, want_column + +# Optional dependency: do not import into __init__.py +__all__ = ( + 'IdType', 'id_column', 'snowflake_column', 'entity_base', 'declarative_base', 'token_signer', + 'match_column', 'match_constraint', 'bool_column', 'parent_children', + 'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column', + # .asyncio + 'SQLAlchemy', 'AsyncSelectPagination', 'async_query' +) \ No newline at end of file diff --git a/src/suou/sqlalchemy/asyncio.py b/src/suou/sqlalchemy/asyncio.py new file mode 100644 index 0000000..786d9f8 --- /dev/null +++ b/src/suou/sqlalchemy/asyncio.py @@ -0,0 +1,176 @@ + +""" +Helpers for asynchronous use of SQLAlchemy. + +NEW 0.5.0; moved to current location 0.6.0 + +--- + +Copyright (c) 2025 Sakuragasaki46. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +See LICENSE for the specific language governing permissions and +limitations under the License. + +This software is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +""" + +from __future__ import annotations +from functools import wraps + + +from sqlalchemy import Engine, Select, func, select +from sqlalchemy.orm import DeclarativeBase, lazyload +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from flask_sqlalchemy.pagination import Pagination + +from suou.exceptions import NotFoundError + +class SQLAlchemy: + """ + Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy() + eligible for async environments + + NEW 0.5.0 + """ + base: DeclarativeBase + engine: AsyncEngine + _sessions: list[AsyncSession] + NotFound = NotFoundError + + def __init__(self, model_class: DeclarativeBase): + self.base = model_class + self.engine = None + self._sessions = [] + def bind(self, url: str): + self.engine = create_async_engine(url) + def _ensure_engine(self): + if self.engine is None: + raise RuntimeError('database is not connected') + async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession: + self._ensure_engine() + ## XXX is it accurate? + s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw) + self._sessions.append(s) + return s + async def __aenter__(self) -> AsyncSession: + return await self.begin() + async def __aexit__(self, e1, e2, e3): + ## XXX is it accurate? + s = self._sessions.pop() + if e1: + await s.rollback() + else: + await s.commit() + await s.close() + async def paginate(self, select: Select, *, + page: int | None = None, per_page: int | None = None, + max_per_page: int | None = None, error_out: bool = True, + count: bool = True) -> AsyncSelectPagination: + """ + Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate(). + """ + async with self as session: + return AsyncSelectPagination( + select = select, + session = session, + page = page, + per_page=per_page, max_per_page=max_per_page, + error_out=self.NotFound if error_out else None, count=count + ) + async def create_all(self, *, checkfirst = True): + """ + Initialize database + """ + self._ensure_engine() + self.base.metadata.create_all( + self.engine, checkfirst=checkfirst + ) + + + +class AsyncSelectPagination(Pagination): + """ + flask_sqlalchemy.SelectPagination but asynchronous. + + Pagination is not part of the public API, therefore expect that it may break + """ + + async def _query_items(self) -> list: + select_q: Select = self._query_args["select"] + select = select_q.limit(self.per_page).offset(self._query_offset) + session: AsyncSession = self._query_args["session"] + out = (await session.execute(select)).scalars() + return out + + async def _query_count(self) -> int: + select_q: Select = self._query_args["select"] + sub = select_q.options(lazyload("*")).order_by(None).subquery() + session: AsyncSession = self._query_args["session"] + out = (await session.execute(select(func.count()).select_from(sub))).scalar() + return out + + def __init__(self, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = 100, + error_out: Exception | None = NotFoundError, + count: bool = True, + **kwargs): + ## XXX flask-sqlalchemy says Pagination() is not public API. + ## Things may break; beware. + self._query_args = kwargs + page, per_page = self._prepare_page_args( + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + ) + + self.page: int = page + """The current page.""" + + self.per_page: int = per_page + """The maximum number of items on a page.""" + + self.max_per_page: int | None = max_per_page + """The maximum allowed value for ``per_page``.""" + + self.items = None + self.total = None + self.error_out = error_out + self.has_count = count + + async def __aiter__(self): + self.items = await self._query_items() + if self.items is None: + raise RuntimeError('query returned None') + if not self.items and self.page != 1 and self.error_out: + raise self.error_out + if self.has_count: + self.total = await self._query_count() + for i in self.items: + yield i + + +def async_query(db: SQLAlchemy, multi: False): + """ + Wraps a query returning function into an executor coroutine. + + The query function remains available as the .q or .query attribute. + """ + def decorator(func): + @wraps(func) + async def executor(*args, **kwargs): + async with db as session: + result = await session.execute(func(*args, **kwargs)) + return result.scalars() if multi else result.scalar() + executor.query = executor.q = func + return executor + return decorator + + +# Optional dependency: do not import into __init__.py +__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query') \ No newline at end of file diff --git a/src/suou/sqlalchemy.py b/src/suou/sqlalchemy/orm.py similarity index 57% rename from src/suou/sqlalchemy.py rename to src/suou/sqlalchemy/orm.py index 2f434e2..6d75a4f 100644 --- a/src/suou/sqlalchemy.py +++ b/src/suou/sqlalchemy/orm.py @@ -1,5 +1,7 @@ """ -Utilities for SQLAlchemy +Utilities for SQLAlchemy; ORM + +NEW 0.6.0 --- @@ -14,53 +16,38 @@ This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ -from __future__ import annotations -from abc import ABCMeta, abstractmethod -from functools import wraps -from typing import Callable, Iterable, Never, TypeVar + +from binascii import Incomplete +from typing import Any, Callable import warnings -from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text -from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship +from sqlalchemy import BigInteger, Boolean, CheckConstraint, Column, Date, ForeignKey, LargeBinary, MetaData, SmallInteger, String, text +from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, declarative_base as _declarative_base, relationship +from sqlalchemy.types import TypeEngine +from suou.classtools import Wanted +from suou.codecs import StringCase +from suou.iding import Siq, SiqCache, SiqGen, SiqType +from suou.itertools import kwargs_prefix +from suou.snowflake import SnowflakeGen +from suou.sqlalchemy import IdType -from .snowflake import SnowflakeGen -from .itertools import kwargs_prefix, makelist -from .signing import HasSigner, UserSigner -from .codecs import StringCase -from .functools import deprecated, not_implemented -from .iding import Siq, SiqGen, SiqType, SiqCache -from .classtools import Incomplete, Wanted -_T = TypeVar('_T') - -# SIQs are 14 bytes long. Storage is padded for alignment -# Not to be confused with SiqType. -IdType: type[LargeBinary] = LargeBinary(16) - -@not_implemented -def sql_escape(s: str, /, dialect: Dialect) -> str: +def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]: """ - Escape a value for SQL embedding, using SQLAlchemy's literal processors. - Requires a dialect argument. + Return a table's column given its name. - XXX this function is not mature yet, do not use + XXX does it belong outside any scopes? """ - if isinstance(s, str): - return String().literal_processor(dialect=dialect)(s) - raise TypeError('invalid data type') + if isinstance(col, Incomplete): + raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.') + elif isinstance(col, Column): + return col + elif isinstance(col, str): + return getattr(cls, col) + else: + raise TypeError -def create_session(url: str) -> Session: - """ - Create a session on the fly, given a database URL. Useful for - contextless environments, such as Python REPL. - - Heads up: a function with the same name exists in core sqlalchemy, but behaves - completely differently!! - """ - engine = create_engine(url) - return Session(bind = engine) - def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs): """ Marks a column which contains a SIQ. @@ -119,7 +106,6 @@ def match_column(length: int, regex: str, /, case: StringCase = StringCase.AS_IS return Incomplete(Column, String(length), Wanted(lambda x, n: match_constraint(n, regex, #dialect=x.metadata.engine.dialect.name, constraint_name=constraint_name or f'{x.__tablename__}_{n}_valid')), *args, **kwargs) - def bool_column(value: bool = False, nullable: bool = False, **kwargs) -> Column[bool]: """ Column for a single boolean value. @@ -148,33 +134,11 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No ) Base = _declarative_base(metadata=MetaData(**metadata), **kwargs) return Base -entity_base = deprecated('use declarative_base() instead')(declarative_base) +entity_base = warnings.deprecated('use declarative_base() instead')(declarative_base) -def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]: - """ - Generate a user signing function. - Requires a master secret (taken from Base.metadata), a user id (visible in the token) - and a user secret. - """ - if isinstance(id_attr, Column): - id_val = id_attr - elif isinstance(id_attr, str): - id_val = Wanted(id_attr) - if isinstance(secret_attr, Column): - secret_val = secret_attr - elif isinstance(secret_attr, str): - secret_val = Wanted(secret_attr) - def token_signer_factory(owner: DeclarativeBase, name: str): - def my_signer(self): - return UserSigner(owner.metadata.info['secret_key'], id_val.__get__(self, owner), secret_val.__get__(self, owner)) - my_signer.__name__ = name - return my_signer - return Incomplete(Wanted(token_signer_factory)) - - -def author_pair(fk_name: str, *, id_type: type = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]: +def author_pair(fk_name: str, *, id_type: type | TypeEngine = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]: """ Return an owner ID/signature column pair, for authenticated values. """ @@ -203,7 +167,8 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]: return (date_col, acc_col) -def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship], Incomplete[Relationship]]: + +def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]: """ Self-referential one-to-many relationship pair. Parent comes first, children come later. @@ -220,8 +185,8 @@ def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Inco parent_kwargs = kwargs_prefix(kwargs, 'parent_') child_kwargs = kwargs_prefix(kwargs, 'child_') - parent = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs) - child = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs) + parent: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'child_{keyword}s', lazy=lazy, **parent_kwargs) + child: Incomplete[Relationship[Any]] = Incomplete(relationship, Wanted(lambda o, n: o.__name__), back_populates=f'parent_{keyword}', lazy=lazy, **child_kwargs) return parent, child @@ -267,93 +232,3 @@ def bound_fk(target: str | Column | InstrumentedAttribute, typ: _T = None, **kwa return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs) -def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]: - """ - Return a table's column given its name. - - XXX does it belong outside any scopes? - """ - if isinstance(col, Incomplete): - raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.') - elif isinstance(col, Column): - return col - elif isinstance(col, str): - return getattr(cls, col) - else: - raise TypeError - -## Utilities for use in web apps below - -class AuthSrc(metaclass=ABCMeta): - ''' - AuthSrc object required for require_auth_base(). - - This is an abstract class and is NOT usable directly. - - This is not part of the public API - ''' - def required_exc(self) -> Never: - raise ValueError('required field missing') - def invalid_exc(self, msg: str = 'validation failed') -> Never: - raise ValueError(msg) - @abstractmethod - def get_session(self) -> Session: - pass - def get_user(self, getter: Callable): - return getter(self.get_token()) - @abstractmethod - def get_token(self): - pass - @abstractmethod - def get_signature(self): - pass - - -def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user', - required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None): - ''' - Inject the current user into a view, given the Authorization: Bearer header. - - For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object. - ''' - col = want_column(cls, column) - validators = makelist(validators) - - def get_user(token) -> DeclarativeBase: - if token is None: - return None - tok_parts = UserSigner.split_token(token) - user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar() - try: - signer: UserSigner = user.signer() - signer.unsign(token) - return user - except Exception: - return None - - def _default_invalid(msg: str = 'Validation failed'): - raise ValueError(msg) - - invalid_exc = src.invalid_exc or _default_invalid - required_exc = src.required_exc or (lambda: _default_invalid('Login required')) - - def decorator(func: Callable): - @wraps(func) - def wrapper(*a, **ka): - ka[dest] = get_user(src.get_token()) - if not ka[dest] and required: - required_exc() - if signed: - ka[sig_dest] = src.get_signature() - for valid in validators: - if not valid(ka[dest]): - invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest])) - return func(*a, **ka) - return wrapper - return decorator - -# Optional dependency: do not import into __init__.py -__all__ = ( - 'IdType', 'id_column', 'entity_base', 'declarative_base', 'token_signer', 'match_column', 'match_constraint', - 'author_pair', 'age_pair', 'require_auth_base', 'bound_fk', 'unbound_fk', 'want_column' -) \ No newline at end of file diff --git a/src/suou/sqlalchemy_async.py b/src/suou/sqlalchemy_async.py index 575f239..47b3396 100644 --- a/src/suou/sqlalchemy_async.py +++ b/src/suou/sqlalchemy_async.py @@ -1,7 +1,7 @@ """ Helpers for asynchronous use of SQLAlchemy. -NEW 0.5.0 +NEW 0.5.0; MOVED to sqlalchemy.asyncio in 0.6.0 --- @@ -17,159 +17,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. """ from __future__ import annotations -from functools import wraps - -from sqlalchemy import Engine, Select, func, select -from sqlalchemy.orm import DeclarativeBase, lazyload -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from flask_sqlalchemy.pagination import Pagination - -from suou.exceptions import NotFoundError - -class SQLAlchemy: - """ - Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy() - eligible for async environments - - NEW 0.5.0 - """ - base: DeclarativeBase - engine: AsyncEngine - _sessions: list[AsyncSession] - NotFound = NotFoundError - - def __init__(self, model_class: DeclarativeBase): - self.base = model_class - self.engine = None - self._sessions = [] - def bind(self, url: str): - self.engine = create_async_engine(url) - def _ensure_engine(self): - if self.engine is None: - raise RuntimeError('database is not connected') - async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession: - self._ensure_engine() - ## XXX is it accurate? - s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw) - self._sessions.append(s) - return s - async def __aenter__(self) -> AsyncSession: - return await self.begin() - async def __aexit__(self, e1, e2, e3): - ## XXX is it accurate? - s = self._sessions.pop() - if e1: - await s.rollback() - else: - await s.commit() - await s.close() - async def paginate(self, select: Select, *, - page: int | None = None, per_page: int | None = None, - max_per_page: int | None = None, error_out: bool = True, - count: bool = True) -> AsyncSelectPagination: - """ - Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate(). - """ - async with self as session: - return AsyncSelectPagination( - select = select, - session = session, - page = page, - per_page=per_page, max_per_page=max_per_page, - error_out=self.NotFound if error_out else None, count=count - ) - async def create_all(self, *, checkfirst = True): - """ - Initialize database - """ - self._ensure_engine() - self.base.metadata.create_all( - self.engine, checkfirst=checkfirst - ) +from .functools import deprecated -class AsyncSelectPagination(Pagination): - """ - flask_sqlalchemy.SelectPagination but asynchronous. +from .sqlalchemy.asyncio import SQLAlchemy, AsyncSelectPagination, async_query - Pagination is not part of the public API, therefore expect that it may break - """ - - async def _query_items(self) -> list: - select_q: Select = self._query_args["select"] - select = select_q.limit(self.per_page).offset(self._query_offset) - session: AsyncSession = self._query_args["session"] - out = (await session.execute(select)).scalars() - return out - - async def _query_count(self) -> int: - select_q: Select = self._query_args["select"] - sub = select_q.options(lazyload("*")).order_by(None).subquery() - session: AsyncSession = self._query_args["session"] - out = (await session.execute(select(func.count()).select_from(sub))).scalar() - return out - - def __init__(self, - page: int | None = None, - per_page: int | None = None, - max_per_page: int | None = 100, - error_out: Exception | None = NotFoundError, - count: bool = True, - **kwargs): - ## XXX flask-sqlalchemy says Pagination() is not public API. - ## Things may break; beware. - self._query_args = kwargs - page, per_page = self._prepare_page_args( - page=page, - per_page=per_page, - max_per_page=max_per_page, - error_out=error_out, - ) - - self.page: int = page - """The current page.""" - - self.per_page: int = per_page - """The maximum number of items on a page.""" - - self.max_per_page: int | None = max_per_page - """The maximum allowed value for ``per_page``.""" - - self.items = None - self.total = None - self.error_out = error_out - self.has_count = count - - async def __aiter__(self): - self.items = await self._query_items() - if self.items is None: - raise RuntimeError('query returned None') - if not self.items and self.page != 1 and self.error_out: - raise self.error_out - if self.has_count: - self.total = await self._query_count() - for i in self.items: - yield i - - -def async_query(db: SQLAlchemy, multi: False): - """ - Wraps a query returning function into an executor coroutine. - - The query function remains available as the .q or .query attribute. - """ - def decorator(func): - @wraps(func) - async def executor(*args, **kwargs): - async with db as session: - result = await session.execute(func(*args, **kwargs)) - return result.scalars() if multi else result.scalar() - executor.query = executor.q = func - return executor - return decorator - +SQLAlchemy = deprecated('import from suou.sqlalchemy.asyncio instead')(SQLAlchemy) +AsyncSelectPagination = deprecated('import from suou.sqlalchemy.asyncio instead')(AsyncSelectPagination) +async_query = deprecated('import from suou.sqlalchemy.asyncio instead')(async_query) # Optional dependency: do not import into __init__.py -__all__ = ('SQLAlchemy', 'async_query') \ No newline at end of file +__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query') \ No newline at end of file