fixed context manager

This commit is contained in:
Yusur 2025-08-04 14:39:24 +02:00
parent 6846c763f2
commit daa9f6de0c

View file

@ -31,21 +31,31 @@ class SQLAlchemy:
""" """
base: DeclarativeBase base: DeclarativeBase
engine: Engine engine: Engine
_sessions: list[AsyncSession]
NotFound = NotFoundError NotFound = NotFoundError
def __init__(self, model_class: DeclarativeBase): def __init__(self, model_class: DeclarativeBase):
self.base = model_class self.base = model_class
self.engine = None self.engine = None
self._sessions = []
def bind(self, url: str): def bind(self, url: str):
self.engine = create_async_engine(url) self.engine = create_async_engine(url)
async def begin(self) -> AsyncSession: async def begin(self) -> AsyncSession:
if self.engine is None: if self.engine is None:
raise RuntimeError('database is not connected') raise RuntimeError('database is not connected')
return self.engine.begin() ## XXX is it accurate?
async def __aenter__(self): s = self.engine.begin()
return self.begin() self._sessions.append(s)
return s
async def __aenter__(self) -> AsyncSession:
return await self.begin()
async def __aexit__(self, e1, e2, e3): async def __aexit__(self, e1, e2, e3):
return await self.engine.__aexit__(e1, e2, e3) ## XXX is it accurate?
s = self._sessions.pop()
if e1:
await s.rollback()
else:
await s.commit()
async def paginate(self, select: Select, *, async def paginate(self, select: Select, *,
page: int | None = None, per_page: int | None = None, page: int | None = None, per_page: int | None = None,
max_per_page: int | None = None, error_out: bool = True, max_per_page: int | None = None, error_out: bool = True,