From 9e386c4f71677770fae778110f31ca0250cea2d4 Mon Sep 17 00:00:00 2001 From: Yusur Princeps Date: Thu, 4 Sep 2025 09:29:38 +0200 Subject: [PATCH] add SessionWrapper --- CHANGELOG.md | 3 +- src/suou/sqlalchemy/__init__.py | 2 +- src/suou/sqlalchemy/asyncio.py | 74 ++++++++++++++++++++++++++++++--- 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a31ce90..c70d776 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,8 +3,9 @@ ## 0.6.0 + `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports. ++ Add several new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`, `SessionWrapper`, + `wrap=` argument to SQLAlchemy. Also removed dead batteries. + Add `.waiter` module. For now, non-functional. -+ Add those new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`. Also removed dead batteries. + Add `ArgConfigSource` to `.configparse` ## 0.5.3 diff --git a/src/suou/sqlalchemy/__init__.py b/src/suou/sqlalchemy/__init__.py index 2603d0b..b207438 100644 --- a/src/suou/sqlalchemy/__init__.py +++ b/src/suou/sqlalchemy/__init__.py @@ -169,5 +169,5 @@ __all__ = ( 'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column', 'a_relationship', 'BitSelector', 'secret_column', # .asyncio - 'SQLAlchemy', 'AsyncSelectPagination', 'async_query' + 'SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper' ) \ No newline at end of file diff --git a/src/suou/sqlalchemy/asyncio.py b/src/suou/sqlalchemy/asyncio.py index 786d9f8..e513652 100644 --- a/src/suou/sqlalchemy/asyncio.py +++ b/src/suou/sqlalchemy/asyncio.py @@ -21,7 +21,7 @@ from __future__ import annotations from functools import wraps -from sqlalchemy import Engine, Select, func, select +from sqlalchemy import Engine, Select, Table, func, select from sqlalchemy.orm import DeclarativeBase, lazyload from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from flask_sqlalchemy.pagination import Pagination @@ -30,29 +30,44 @@ from suou.exceptions import NotFoundError class SQLAlchemy: """ - Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy() - eligible for async environments + Drop-in (in fact, almost) replacement for flask_sqlalchemy.SQLAlchemy() + eligible for async environments. + + Notable changes: + + You have to create the session yourself. Easiest use case: + + async def handler (userid): + async with db as session: + # do something + user = (await session.execute(select(User).where(User.id == userid))).scalar() + # ... NEW 0.5.0 + + UPDATED 0.6.0: added wrap=True """ base: DeclarativeBase engine: AsyncEngine _sessions: list[AsyncSession] + _wrapsessions: bool NotFound = NotFoundError - def __init__(self, model_class: DeclarativeBase): + def __init__(self, model_class: DeclarativeBase, *, wrap = False): self.base = model_class self.engine = None + self._wrapsessions = wrap self._sessions = [] def bind(self, url: str): self.engine = create_async_engine(url) def _ensure_engine(self): if self.engine is None: raise RuntimeError('database is not connected') - async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession: + async def begin(self, *, expire_on_commit = False, wrap = False, **kw) -> AsyncSession: self._ensure_engine() ## XXX is it accurate? s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw) + if wrap: + s = SessionWrapper(s) self._sessions.append(s) return s async def __aenter__(self) -> AsyncSession: @@ -171,6 +186,53 @@ def async_query(db: SQLAlchemy, multi: False): return executor return decorator +class SessionWrapper: + """ + Wrap a SQLAlchemy() session (context manager) adding several QoL utilitites. + + It can be applied to: + + sessions created by SQLAlchemy() - it must receive a wrap=True argument in constructor; + + sessions created manually - by constructing SessionWrapper(session). + + This works in async context; DO NOT USE with regular SQLAlchemy. + + NEW 0.6.0 + """ + + def __init__(self, db_or_session: SQLAlchemy | AsyncSession): + self._wrapped = db_or_session + async def __aenter__(self): + if isinstance(self._wrapped, SQLAlchemy): + self._wrapped = await self._wrapped.begin() + return self + + async def __aexit__(self, *exc_info): + await self._wrapped.__aexit__(*exc_info) + + @property + def _session(self): + if isinstance(self._wrapped, AsyncSession): + return self._wrapped + raise RuntimeError('active session is required') + + async def get_one(self, query: Select): + result = await self._session.execute(query) + return result.scalar() + + async def get_by_id(self, table: Table, key) : + return await self.get_one(select(table).where(table.id == key)) # pyright: ignore[reportAttributeAccessIssue] + + async def get_list(self, query: Select, limit: int | None = None): + if limit: + query = query.limit(limit) + result = await self._session.execute(query) + return list(result.scalars()) + + def __getattr__(self, key): + """ + Fall back to the wrapped session + """ + return getattr(self._session, key) # Optional dependency: do not import into __init__.py -__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query') \ No newline at end of file +__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper') \ No newline at end of file