fixed context manager
This commit is contained in:
parent
6846c763f2
commit
daa9f6de0c
1 changed files with 14 additions and 4 deletions
|
|
@ -31,21 +31,31 @@ class SQLAlchemy:
|
|||
"""
|
||||
base: DeclarativeBase
|
||||
engine: Engine
|
||||
_sessions: list[AsyncSession]
|
||||
NotFound = NotFoundError
|
||||
|
||||
def __init__(self, model_class: DeclarativeBase):
|
||||
self.base = model_class
|
||||
self.engine = None
|
||||
self._sessions = []
|
||||
def bind(self, url: str):
|
||||
self.engine = create_async_engine(url)
|
||||
async def begin(self) -> AsyncSession:
|
||||
if self.engine is None:
|
||||
raise RuntimeError('database is not connected')
|
||||
return self.engine.begin()
|
||||
async def __aenter__(self):
|
||||
return self.begin()
|
||||
## XXX is it accurate?
|
||||
s = self.engine.begin()
|
||||
self._sessions.append(s)
|
||||
return s
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
return await self.begin()
|
||||
async def __aexit__(self, e1, e2, e3):
|
||||
return await self.engine.__aexit__(e1, e2, e3)
|
||||
## XXX is it accurate?
|
||||
s = self._sessions.pop()
|
||||
if e1:
|
||||
await s.rollback()
|
||||
else:
|
||||
await s.commit()
|
||||
async def paginate(self, select: Select, *,
|
||||
page: int | None = None, per_page: int | None = None,
|
||||
max_per_page: int | None = None, error_out: bool = True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue