Compare commits
No commits in common. "1c809a9930837e05e23efcdda30bd56271a665f5" and "eb8371757dcd87b1d3675963726352d16f274323" have entirely different histories.
1c809a9930
...
eb8371757d
3 changed files with 9 additions and 73 deletions
|
|
@ -3,15 +3,13 @@
|
||||||
## 0.6.0
|
## 0.6.0
|
||||||
|
|
||||||
+ `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports.
|
+ `.sqlalchemy` has been made a subpackage and split; `sqlalchemy_async` has been deprecated. Update your imports.
|
||||||
+ Add several new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`, `SessionWrapper`,
|
|
||||||
`wrap=` argument to SQLAlchemy. Also removed dead batteries.
|
|
||||||
+ Add `.waiter` module. For now, non-functional.
|
+ Add `.waiter` module. For now, non-functional.
|
||||||
|
+ Add those new utilities to `.sqlalchemy`: `BitSelector`, `secret_column`, `a_relationship`. Also removed dead batteries.
|
||||||
+ Add `ArgConfigSource` to `.configparse`
|
+ Add `ArgConfigSource` to `.configparse`
|
||||||
|
|
||||||
## 0.5.3
|
## 0.5.3
|
||||||
|
|
||||||
- Added docstring to `SQLAlchemy()`.
|
...
|
||||||
- More type fixes.
|
|
||||||
|
|
||||||
## 0.5.2
|
## 0.5.2
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -169,5 +169,5 @@ __all__ = (
|
||||||
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
|
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
|
||||||
'a_relationship', 'BitSelector', 'secret_column',
|
'a_relationship', 'BitSelector', 'secret_column',
|
||||||
# .asyncio
|
# .asyncio
|
||||||
'SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper'
|
'SQLAlchemy', 'AsyncSelectPagination', 'async_query'
|
||||||
)
|
)
|
||||||
|
|
@ -21,7 +21,7 @@ from __future__ import annotations
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
from sqlalchemy import Engine, Select, Table, func, select
|
from sqlalchemy import Engine, Select, func, select
|
||||||
from sqlalchemy.orm import DeclarativeBase, lazyload
|
from sqlalchemy.orm import DeclarativeBase, lazyload
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||||
from flask_sqlalchemy.pagination import Pagination
|
from flask_sqlalchemy.pagination import Pagination
|
||||||
|
|
@ -30,44 +30,29 @@ from suou.exceptions import NotFoundError
|
||||||
|
|
||||||
class SQLAlchemy:
|
class SQLAlchemy:
|
||||||
"""
|
"""
|
||||||
Drop-in (in fact, almost) replacement for flask_sqlalchemy.SQLAlchemy()
|
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||||
eligible for async environments.
|
eligible for async environments
|
||||||
|
|
||||||
Notable changes:
|
|
||||||
+ You have to create the session yourself. Easiest use case:
|
|
||||||
|
|
||||||
async def handler (userid):
|
|
||||||
async with db as session:
|
|
||||||
# do something
|
|
||||||
user = (await session.execute(select(User).where(User.id == userid))).scalar()
|
|
||||||
# ...
|
|
||||||
|
|
||||||
NEW 0.5.0
|
NEW 0.5.0
|
||||||
|
|
||||||
UPDATED 0.6.0: added wrap=True
|
|
||||||
"""
|
"""
|
||||||
base: DeclarativeBase
|
base: DeclarativeBase
|
||||||
engine: AsyncEngine
|
engine: AsyncEngine
|
||||||
_sessions: list[AsyncSession]
|
_sessions: list[AsyncSession]
|
||||||
_wrapsessions: bool
|
|
||||||
NotFound = NotFoundError
|
NotFound = NotFoundError
|
||||||
|
|
||||||
def __init__(self, model_class: DeclarativeBase, *, wrap = False):
|
def __init__(self, model_class: DeclarativeBase):
|
||||||
self.base = model_class
|
self.base = model_class
|
||||||
self.engine = None
|
self.engine = None
|
||||||
self._wrapsessions = wrap
|
|
||||||
self._sessions = []
|
self._sessions = []
|
||||||
def bind(self, url: str):
|
def bind(self, url: str):
|
||||||
self.engine = create_async_engine(url)
|
self.engine = create_async_engine(url)
|
||||||
def _ensure_engine(self):
|
def _ensure_engine(self):
|
||||||
if self.engine is None:
|
if self.engine is None:
|
||||||
raise RuntimeError('database is not connected')
|
raise RuntimeError('database is not connected')
|
||||||
async def begin(self, *, expire_on_commit = False, wrap = False, **kw) -> AsyncSession:
|
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
||||||
self._ensure_engine()
|
self._ensure_engine()
|
||||||
## XXX is it accurate?
|
## XXX is it accurate?
|
||||||
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
||||||
if wrap:
|
|
||||||
s = SessionWrapper(s)
|
|
||||||
self._sessions.append(s)
|
self._sessions.append(s)
|
||||||
return s
|
return s
|
||||||
async def __aenter__(self) -> AsyncSession:
|
async def __aenter__(self) -> AsyncSession:
|
||||||
|
|
@ -186,53 +171,6 @@ def async_query(db: SQLAlchemy, multi: False):
|
||||||
return executor
|
return executor
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
class SessionWrapper:
|
|
||||||
"""
|
|
||||||
Wrap a SQLAlchemy() session (context manager) adding several QoL utilitites.
|
|
||||||
|
|
||||||
It can be applied to:
|
|
||||||
+ sessions created by SQLAlchemy() - it must receive a wrap=True argument in constructor;
|
|
||||||
+ sessions created manually - by constructing SessionWrapper(session).
|
|
||||||
|
|
||||||
This works in async context; DO NOT USE with regular SQLAlchemy.
|
|
||||||
|
|
||||||
NEW 0.6.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db_or_session: SQLAlchemy | AsyncSession):
|
|
||||||
self._wrapped = db_or_session
|
|
||||||
async def __aenter__(self):
|
|
||||||
if isinstance(self._wrapped, SQLAlchemy):
|
|
||||||
self._wrapped = await self._wrapped.begin()
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, *exc_info):
|
|
||||||
await self._wrapped.__aexit__(*exc_info)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _session(self):
|
|
||||||
if isinstance(self._wrapped, AsyncSession):
|
|
||||||
return self._wrapped
|
|
||||||
raise RuntimeError('active session is required')
|
|
||||||
|
|
||||||
async def get_one(self, query: Select):
|
|
||||||
result = await self._session.execute(query)
|
|
||||||
return result.scalar()
|
|
||||||
|
|
||||||
async def get_by_id(self, table: Table, key) :
|
|
||||||
return await self.get_one(select(table).where(table.id == key)) # pyright: ignore[reportAttributeAccessIssue]
|
|
||||||
|
|
||||||
async def get_list(self, query: Select, limit: int | None = None):
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
result = await self._session.execute(query)
|
|
||||||
return list(result.scalars())
|
|
||||||
|
|
||||||
def __getattr__(self, key):
|
|
||||||
"""
|
|
||||||
Fall back to the wrapped session
|
|
||||||
"""
|
|
||||||
return getattr(self._session, key)
|
|
||||||
|
|
||||||
# Optional dependency: do not import into __init__.py
|
# Optional dependency: do not import into __init__.py
|
||||||
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query', 'SessionWrapper')
|
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||||
Loading…
Add table
Add a link
Reference in a new issue