add SessionWrapper

This commit is contained in:
Yusur 2025-09-04 09:29:38 +02:00
parent eb8371757d
commit 9e386c4f71
3 changed files with 71 additions and 8 deletions

View file

@ -3,8 +3,9 @@
## 0.6.0 ## 0.6.0
+ `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports. + `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports.
+ Add several new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`, `SessionWrapper`,
`wrap=` argument to SQLAlchemy. Also removed dead batteries.
+ Add `.waiter` module. For now, non-functional. + Add `.waiter` module. For now, non-functional.
+ Add those new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`. Also removed dead batteries.
+ Add `ArgConfigSource` to `.configparse` + Add `ArgConfigSource` to `.configparse`
## 0.5.3 ## 0.5.3

View file

@ -169,5 +169,5 @@ __all__ = (
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column', 'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
'a_relationship', 'BitSelector', 'secret_column', 'a_relationship', 'BitSelector', 'secret_column',
# .asyncio # .asyncio
'SQLAlchemy', 'AsyncSelectPagination', 'async_query' 'SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper'
) )

View file

@ -21,7 +21,7 @@ from __future__ import annotations
from functools import wraps from functools import wraps
from sqlalchemy import Engine, Select, func, select from sqlalchemy import Engine, Select, Table, func, select
from sqlalchemy.orm import DeclarativeBase, lazyload from sqlalchemy.orm import DeclarativeBase, lazyload
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from flask_sqlalchemy.pagination import Pagination from flask_sqlalchemy.pagination import Pagination
@ -30,29 +30,44 @@ from suou.exceptions import NotFoundError
class SQLAlchemy: class SQLAlchemy:
""" """
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy() Drop-in (in fact, almost) replacement for flask_sqlalchemy.SQLAlchemy()
eligible for async environments eligible for async environments.
Notable changes:
+ You have to create the session yourself. Easiest use case:
async def handler (userid):
async with db as session:
# do something
user = (await session.execute(select(User).where(User.id == userid))).scalar()
# ...
NEW 0.5.0 NEW 0.5.0
UPDATED 0.6.0: added wrap=True
""" """
base: DeclarativeBase base: DeclarativeBase
engine: AsyncEngine engine: AsyncEngine
_sessions: list[AsyncSession] _sessions: list[AsyncSession]
_wrapsessions: bool
NotFound = NotFoundError NotFound = NotFoundError
def __init__(self, model_class: DeclarativeBase): def __init__(self, model_class: DeclarativeBase, *, wrap = False):
self.base = model_class self.base = model_class
self.engine = None self.engine = None
self._wrapsessions = wrap
self._sessions = [] self._sessions = []
def bind(self, url: str): def bind(self, url: str):
self.engine = create_async_engine(url) self.engine = create_async_engine(url)
def _ensure_engine(self): def _ensure_engine(self):
if self.engine is None: if self.engine is None:
raise RuntimeError('database is not connected') raise RuntimeError('database is not connected')
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession: async def begin(self, *, expire_on_commit = False, wrap = False, **kw) -> AsyncSession:
self._ensure_engine() self._ensure_engine()
## XXX is it accurate? ## XXX is it accurate?
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw) s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
if wrap:
s = SessionWrapper(s)
self._sessions.append(s) self._sessions.append(s)
return s return s
async def __aenter__(self) -> AsyncSession: async def __aenter__(self) -> AsyncSession:
@ -171,6 +186,53 @@ def async_query(db: SQLAlchemy, multi: False):
return executor return executor
return decorator return decorator
class SessionWrapper:
"""
Wrap a SQLAlchemy() session (context manager) adding several QoL utilitites.
It can be applied to:
+ sessions created by SQLAlchemy() - it must receive a wrap=True argument in constructor;
+ sessions created manually - by constructing SessionWrapper(session).
This works in async context; DO NOT USE with regular SQLAlchemy.
NEW 0.6.0
"""
def __init__(self, db_or_session: SQLAlchemy | AsyncSession):
self._wrapped = db_or_session
async def __aenter__(self):
if isinstance(self._wrapped, SQLAlchemy):
self._wrapped = await self._wrapped.begin()
return self
async def __aexit__(self, *exc_info):
await self._wrapped.__aexit__(*exc_info)
@property
def _session(self):
if isinstance(self._wrapped, AsyncSession):
return self._wrapped
raise RuntimeError('active session is required')
async def get_one(self, query: Select):
result = await self._session.execute(query)
return result.scalar()
async def get_by_id(self, table: Table, key) :
return await self.get_one(select(table).where(table.id == key)) # pyright: ignore[reportAttributeAccessIssue]
async def get_list(self, query: Select, limit: int | None = None):
if limit:
query = query.limit(limit)
result = await self._session.execute(query)
return list(result.scalars())
def __getattr__(self, key):
"""
Fall back to the wrapped session
"""
return getattr(self._session, key)
# Optional dependency: do not import into __init__.py # Optional dependency: do not import into __init__.py
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query') __all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper')