Compare commits

...

2 commits

Author SHA1 Message Date
1c809a9930 changelog for 0.5.3 2025-09-04 09:49:31 +02:00
9e386c4f71 add SessionWrapper 2025-09-04 09:29:38 +02:00
3 changed files with 73 additions and 9 deletions

View file

@ -3,13 +3,15 @@
## 0.6.0
+ `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports.
+ Add several new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`, `SessionWrapper`,
`wrap=` argument to SQLAlchemy. Also removed dead batteries.
+ Add `.waiter` module. For now, non-functional.
+ Add those new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`. Also removed dead batteries.
+ Add `ArgConfigSource` to `.configparse`
## 0.5.3
...
- Added docstring to `SQLAlchemy()`.
- More type fixes.
## 0.5.2

View file

@ -169,5 +169,5 @@ __all__ = (
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
'a_relationship', 'BitSelector', 'secret_column',
# .asyncio
'SQLAlchemy', 'AsyncSelectPagination', 'async_query'
'SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper'
)

View file

@ -21,7 +21,7 @@ from __future__ import annotations
from functools import wraps
from sqlalchemy import Engine, Select, func, select
from sqlalchemy import Engine, Select, Table, func, select
from sqlalchemy.orm import DeclarativeBase, lazyload
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from flask_sqlalchemy.pagination import Pagination
@ -30,29 +30,44 @@ from suou.exceptions import NotFoundError
class SQLAlchemy:
"""
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
eligible for async environments
Drop-in (in fact, almost) replacement for flask_sqlalchemy.SQLAlchemy()
eligible for async environments.
Notable changes:
+ You have to create the session yourself. Easiest use case:
async def handler (userid):
async with db as session:
# do something
user = (await session.execute(select(User).where(User.id == userid))).scalar()
# ...
NEW 0.5.0
UPDATED 0.6.0: added wrap=True
"""
base: DeclarativeBase
engine: AsyncEngine
_sessions: list[AsyncSession]
_wrapsessions: bool
NotFound = NotFoundError
def __init__(self, model_class: DeclarativeBase):
def __init__(self, model_class: DeclarativeBase, *, wrap = False):
self.base = model_class
self.engine = None
self._wrapsessions = wrap
self._sessions = []
def bind(self, url: str):
self.engine = create_async_engine(url)
def _ensure_engine(self):
if self.engine is None:
raise RuntimeError('database is not connected')
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
async def begin(self, *, expire_on_commit = False, wrap = False, **kw) -> AsyncSession:
self._ensure_engine()
## XXX is it accurate?
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
if wrap:
s = SessionWrapper(s)
self._sessions.append(s)
return s
async def __aenter__(self) -> AsyncSession:
@ -171,6 +186,53 @@ def async_query(db: SQLAlchemy, multi: False):
return executor
return decorator
class SessionWrapper:
"""
Wrap a SQLAlchemy() session (context manager) adding several QoL utilitites.
It can be applied to:
+ sessions created by SQLAlchemy() - it must receive a wrap=True argument in constructor;
+ sessions created manually - by constructing SessionWrapper(session).
This works in async context; DO NOT USE with regular SQLAlchemy.
NEW 0.6.0
"""
def __init__(self, db_or_session: SQLAlchemy | AsyncSession):
self._wrapped = db_or_session
async def __aenter__(self):
if isinstance(self._wrapped, SQLAlchemy):
self._wrapped = await self._wrapped.begin()
return self
async def __aexit__(self, *exc_info):
await self._wrapped.__aexit__(*exc_info)
@property
def _session(self):
if isinstance(self._wrapped, AsyncSession):
return self._wrapped
raise RuntimeError('active session is required')
async def get_one(self, query: Select):
result = await self._session.execute(query)
return result.scalar()
async def get_by_id(self, table: Table, key) :
return await self.get_one(select(table).where(table.id == key)) # pyright: ignore[reportAttributeAccessIssue]
async def get_list(self, query: Select, limit: int | None = None):
if limit:
query = query.limit(limit)
result = await self._session.execute(query)
return list(result.scalars())
def __getattr__(self, key):
"""
Fall back to the wrapped session
"""
return getattr(self._session, key)
# Optional dependency: do not import into __init__.py
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper')