Compare commits
2 commits
eb8371757d
...
1c809a9930
| Author | SHA1 | Date | |
|---|---|---|---|
| 1c809a9930 | |||
| 9e386c4f71 |
3 changed files with 73 additions and 9 deletions
|
|
@ -3,13 +3,15 @@
|
|||
## 0.6.0
|
||||
|
||||
+ `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports.
|
||||
+ Add several new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`, `SessionWrapper`,
|
||||
`wrap=` argument to SQLAlchemy. Also removed dead batteries.
|
||||
+ Add `.waiter` module. For now, non-functional.
|
||||
+ Add those new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`. Also removed dead batteries.
|
||||
+ Add `ArgConfigSource` to `.configparse`
|
||||
|
||||
## 0.5.3
|
||||
|
||||
...
|
||||
- Added docstring to `SQLAlchemy()`.
|
||||
- More type fixes.
|
||||
|
||||
## 0.5.2
|
||||
|
||||
|
|
|
|||
|
|
@ -169,5 +169,5 @@ __all__ = (
|
|||
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
|
||||
'a_relationship', 'BitSelector', 'secret_column',
|
||||
# .asyncio
|
||||
'SQLAlchemy', 'AsyncSelectPagination', 'async_query'
|
||||
'SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper'
|
||||
)
|
||||
|
|
@ -21,7 +21,7 @@ from __future__ import annotations
|
|||
from functools import wraps
|
||||
|
||||
|
||||
from sqlalchemy import Engine, Select, func, select
|
||||
from sqlalchemy import Engine, Select, Table, func, select
|
||||
from sqlalchemy.orm import DeclarativeBase, lazyload
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
|
@ -30,29 +30,44 @@ from suou.exceptions import NotFoundError
|
|||
|
||||
class SQLAlchemy:
|
||||
"""
|
||||
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||
eligible for async environments
|
||||
Drop-in (in fact, almost) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||
eligible for async environments.
|
||||
|
||||
Notable changes:
|
||||
+ You have to create the session yourself. Easiest use case:
|
||||
|
||||
async def handler (userid):
|
||||
async with db as session:
|
||||
# do something
|
||||
user = (await session.execute(select(User).where(User.id == userid))).scalar()
|
||||
# ...
|
||||
|
||||
NEW 0.5.0
|
||||
|
||||
UPDATED 0.6.0: added wrap=True
|
||||
"""
|
||||
base: DeclarativeBase
|
||||
engine: AsyncEngine
|
||||
_sessions: list[AsyncSession]
|
||||
_wrapsessions: bool
|
||||
NotFound = NotFoundError
|
||||
|
||||
def __init__(self, model_class: DeclarativeBase):
|
||||
def __init__(self, model_class: DeclarativeBase, *, wrap = False):
|
||||
self.base = model_class
|
||||
self.engine = None
|
||||
self._wrapsessions = wrap
|
||||
self._sessions = []
|
||||
def bind(self, url: str):
|
||||
self.engine = create_async_engine(url)
|
||||
def _ensure_engine(self):
|
||||
if self.engine is None:
|
||||
raise RuntimeError('database is not connected')
|
||||
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
||||
async def begin(self, *, expire_on_commit = False, wrap = False, **kw) -> AsyncSession:
|
||||
self._ensure_engine()
|
||||
## XXX is it accurate?
|
||||
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
||||
if wrap:
|
||||
s = SessionWrapper(s)
|
||||
self._sessions.append(s)
|
||||
return s
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
|
|
@ -171,6 +186,53 @@ def async_query(db: SQLAlchemy, multi: False):
|
|||
return executor
|
||||
return decorator
|
||||
|
||||
class SessionWrapper:
|
||||
"""
|
||||
Wrap a SQLAlchemy() session (context manager) adding several QoL utilitites.
|
||||
|
||||
It can be applied to:
|
||||
+ sessions created by SQLAlchemy() - it must receive a wrap=True argument in constructor;
|
||||
+ sessions created manually - by constructing SessionWrapper(session).
|
||||
|
||||
This works in async context; DO NOT USE with regular SQLAlchemy.
|
||||
|
||||
NEW 0.6.0
|
||||
"""
|
||||
|
||||
def __init__(self, db_or_session: SQLAlchemy | AsyncSession):
|
||||
self._wrapped = db_or_session
|
||||
async def __aenter__(self):
|
||||
if isinstance(self._wrapped, SQLAlchemy):
|
||||
self._wrapped = await self._wrapped.begin()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_info):
|
||||
await self._wrapped.__aexit__(*exc_info)
|
||||
|
||||
@property
|
||||
def _session(self):
|
||||
if isinstance(self._wrapped, AsyncSession):
|
||||
return self._wrapped
|
||||
raise RuntimeError('active session is required')
|
||||
|
||||
async def get_one(self, query: Select):
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar()
|
||||
|
||||
async def get_by_id(self, table: Table, key) :
|
||||
return await self.get_one(select(table).where(table.id == key)) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
async def get_list(self, query: Select, limit: int | None = None):
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
result = await self._session.execute(query)
|
||||
return list(result.scalars())
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""
|
||||
Fall back to the wrapped session
|
||||
"""
|
||||
return getattr(self._session, key)
|
||||
|
||||
# Optional dependency: do not import into __init__.py
|
||||
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper')
|
||||
Loading…
Add table
Add a link
Reference in a new issue