split sqlalchemy modules
This commit is contained in:
parent
f7807ff05a
commit
97194b2b85
4 changed files with 379 additions and 307 deletions
169
src/suou/sqlalchemy/__init__.py
Normal file
169
src/suou/sqlalchemy/__init__.py
Normal file
|
|
@ -0,0 +1,169 @@
|
||||||
|
"""
|
||||||
|
Utilities for SQLAlchemy
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Copyright (c) 2025 Sakuragasaki46.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
See LICENSE for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
This software is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable, Iterable, Never, TypeVar
|
||||||
|
import warnings
|
||||||
|
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
|
||||||
|
from sqlalchemy.types import TypeEngine
|
||||||
|
|
||||||
|
from ..snowflake import SnowflakeGen
|
||||||
|
from ..itertools import kwargs_prefix, makelist
|
||||||
|
from ..signing import HasSigner, UserSigner
|
||||||
|
from ..codecs import StringCase
|
||||||
|
from ..functools import deprecated, not_implemented
|
||||||
|
from ..iding import Siq, SiqGen, SiqType, SiqCache
|
||||||
|
from ..classtools import Incomplete, Wanted
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
# SIQs are 14 bytes long. Storage is padded for alignment
|
||||||
|
# Not to be confused with SiqType.
|
||||||
|
IdType: TypeEngine = LargeBinary(16)
|
||||||
|
|
||||||
|
def create_session(url: str) -> Session:
|
||||||
|
"""
|
||||||
|
Create a session on the fly, given a database URL. Useful for
|
||||||
|
contextless environments, such as Python REPL.
|
||||||
|
|
||||||
|
Heads up: a function with the same name exists in core sqlalchemy, but behaves
|
||||||
|
completely differently!!
|
||||||
|
"""
|
||||||
|
engine = create_engine(url)
|
||||||
|
return Session(bind = engine)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
|
||||||
|
"""
|
||||||
|
Generate a user signing function.
|
||||||
|
|
||||||
|
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
|
||||||
|
and a user secret.
|
||||||
|
"""
|
||||||
|
id_val: Column | Wanted[Column]
|
||||||
|
if isinstance(id_attr, Column):
|
||||||
|
id_val = id_attr
|
||||||
|
elif isinstance(id_attr, str):
|
||||||
|
id_val = Wanted(id_attr)
|
||||||
|
if isinstance(secret_attr, Column):
|
||||||
|
secret_val = secret_attr
|
||||||
|
elif isinstance(secret_attr, str):
|
||||||
|
secret_val = Wanted(secret_attr)
|
||||||
|
def token_signer_factory(owner: DeclarativeBase, name: str):
|
||||||
|
def my_signer(self):
|
||||||
|
return UserSigner(
|
||||||
|
owner.metadata.info['secret_key'],
|
||||||
|
id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
)
|
||||||
|
my_signer.__name__ = name
|
||||||
|
return my_signer
|
||||||
|
return Incomplete(Wanted(token_signer_factory))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Utilities for use in web apps below
|
||||||
|
|
||||||
|
@deprecated('not part of the public API and not even working')
|
||||||
|
class AuthSrc(metaclass=ABCMeta):
|
||||||
|
'''
|
||||||
|
AuthSrc object required for require_auth_base().
|
||||||
|
|
||||||
|
This is an abstract class and is NOT usable directly.
|
||||||
|
|
||||||
|
This is not part of the public API
|
||||||
|
'''
|
||||||
|
def required_exc(self) -> Never:
|
||||||
|
raise ValueError('required field missing')
|
||||||
|
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
||||||
|
raise ValueError(msg)
|
||||||
|
@abstractmethod
|
||||||
|
def get_session(self) -> Session:
|
||||||
|
pass
|
||||||
|
def get_user(self, getter: Callable):
|
||||||
|
return getter(self.get_token())
|
||||||
|
@abstractmethod
|
||||||
|
def get_token(self):
|
||||||
|
pass
|
||||||
|
@abstractmethod
|
||||||
|
def get_signature(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated('not working and too complex to use')
|
||||||
|
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
||||||
|
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
|
||||||
|
'''
|
||||||
|
Inject the current user into a view, given the Authorization: Bearer header.
|
||||||
|
|
||||||
|
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
|
||||||
|
'''
|
||||||
|
col = want_column(cls, column)
|
||||||
|
validators = makelist(validators)
|
||||||
|
|
||||||
|
def get_user(token) -> DeclarativeBase:
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
|
tok_parts = UserSigner.split_token(token)
|
||||||
|
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
|
||||||
|
try:
|
||||||
|
signer: UserSigner = user.signer()
|
||||||
|
signer.unsign(token)
|
||||||
|
return user
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _default_invalid(msg: str = 'Validation failed'):
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
invalid_exc = src.invalid_exc or _default_invalid
|
||||||
|
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
|
||||||
|
|
||||||
|
def decorator(func: Callable):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*a, **ka):
|
||||||
|
ka[dest] = get_user(src.get_token())
|
||||||
|
if not ka[dest] and required:
|
||||||
|
required_exc()
|
||||||
|
if signed:
|
||||||
|
ka[sig_dest] = src.get_signature()
|
||||||
|
for valid in validators:
|
||||||
|
if not valid(ka[dest]):
|
||||||
|
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
||||||
|
return func(*a, **ka)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
from .asyncio import SQLAlchemy, AsyncSelectPagination, async_query
|
||||||
|
from .orm import id_column, snowflake_column, match_column, match_constraint, bool_column, declarative_base, author_pair, age_pair, bound_fk, unbound_fk, want_column
|
||||||
|
|
||||||
|
# Optional dependency: do not import into __init__.py
|
||||||
|
__all__ = (
|
||||||
|
'IdType', 'id_column', 'snowflake_column', 'entity_base', 'declarative_base', 'token_signer',
|
||||||
|
'match_column', 'match_constraint', 'bool_column', 'parent_children',
|
||||||
|
'author_pair', 'age_pair', 'bound_fk', 'unbound_fk', 'want_column',
|
||||||
|
# .asyncio
|
||||||
|
'SQLAlchemy', 'AsyncSelectPagination', 'async_query'
|
||||||
|
)
|
||||||
176
src/suou/sqlalchemy/asyncio.py
Normal file
176
src/suou/sqlalchemy/asyncio.py
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
|
||||||
|
"""
|
||||||
|
Helpers for asynchronous use of SQLAlchemy.
|
||||||
|
|
||||||
|
NEW 0.5.0; moved to current location 0.6.0
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Copyright (c) 2025 Sakuragasaki46.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
See LICENSE for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
This software is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy import Engine, Select, func, select
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, lazyload
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||||
|
from flask_sqlalchemy.pagination import Pagination
|
||||||
|
|
||||||
|
from suou.exceptions import NotFoundError
|
||||||
|
|
||||||
|
class SQLAlchemy:
|
||||||
|
"""
|
||||||
|
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
||||||
|
eligible for async environments
|
||||||
|
|
||||||
|
NEW 0.5.0
|
||||||
|
"""
|
||||||
|
base: DeclarativeBase
|
||||||
|
engine: AsyncEngine
|
||||||
|
_sessions: list[AsyncSession]
|
||||||
|
NotFound = NotFoundError
|
||||||
|
|
||||||
|
def __init__(self, model_class: DeclarativeBase):
|
||||||
|
self.base = model_class
|
||||||
|
self.engine = None
|
||||||
|
self._sessions = []
|
||||||
|
def bind(self, url: str):
|
||||||
|
self.engine = create_async_engine(url)
|
||||||
|
def _ensure_engine(self):
|
||||||
|
if self.engine is None:
|
||||||
|
raise RuntimeError('database is not connected')
|
||||||
|
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
||||||
|
self._ensure_engine()
|
||||||
|
## XXX is it accurate?
|
||||||
|
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
||||||
|
self._sessions.append(s)
|
||||||
|
return s
|
||||||
|
async def __aenter__(self) -> AsyncSession:
|
||||||
|
return await self.begin()
|
||||||
|
async def __aexit__(self, e1, e2, e3):
|
||||||
|
## XXX is it accurate?
|
||||||
|
s = self._sessions.pop()
|
||||||
|
if e1:
|
||||||
|
await s.rollback()
|
||||||
|
else:
|
||||||
|
await s.commit()
|
||||||
|
await s.close()
|
||||||
|
async def paginate(self, select: Select, *,
|
||||||
|
page: int | None = None, per_page: int | None = None,
|
||||||
|
max_per_page: int | None = None, error_out: bool = True,
|
||||||
|
count: bool = True) -> AsyncSelectPagination:
|
||||||
|
"""
|
||||||
|
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
|
||||||
|
"""
|
||||||
|
async with self as session:
|
||||||
|
return AsyncSelectPagination(
|
||||||
|
select = select,
|
||||||
|
session = session,
|
||||||
|
page = page,
|
||||||
|
per_page=per_page, max_per_page=max_per_page,
|
||||||
|
error_out=self.NotFound if error_out else None, count=count
|
||||||
|
)
|
||||||
|
async def create_all(self, *, checkfirst = True):
|
||||||
|
"""
|
||||||
|
Initialize database
|
||||||
|
"""
|
||||||
|
self._ensure_engine()
|
||||||
|
self.base.metadata.create_all(
|
||||||
|
self.engine, checkfirst=checkfirst
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncSelectPagination(Pagination):
|
||||||
|
"""
|
||||||
|
flask_sqlalchemy.SelectPagination but asynchronous.
|
||||||
|
|
||||||
|
Pagination is not part of the public API, therefore expect that it may break
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _query_items(self) -> list:
|
||||||
|
select_q: Select = self._query_args["select"]
|
||||||
|
select = select_q.limit(self.per_page).offset(self._query_offset)
|
||||||
|
session: AsyncSession = self._query_args["session"]
|
||||||
|
out = (await session.execute(select)).scalars()
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _query_count(self) -> int:
|
||||||
|
select_q: Select = self._query_args["select"]
|
||||||
|
sub = select_q.options(lazyload("*")).order_by(None).subquery()
|
||||||
|
session: AsyncSession = self._query_args["session"]
|
||||||
|
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
page: int | None = None,
|
||||||
|
per_page: int | None = None,
|
||||||
|
max_per_page: int | None = 100,
|
||||||
|
error_out: Exception | None = NotFoundError,
|
||||||
|
count: bool = True,
|
||||||
|
**kwargs):
|
||||||
|
## XXX flask-sqlalchemy says Pagination() is not public API.
|
||||||
|
## Things may break; beware.
|
||||||
|
self._query_args = kwargs
|
||||||
|
page, per_page = self._prepare_page_args(
|
||||||
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
|
max_per_page=max_per_page,
|
||||||
|
error_out=error_out,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.page: int = page
|
||||||
|
"""The current page."""
|
||||||
|
|
||||||
|
self.per_page: int = per_page
|
||||||
|
"""The maximum number of items on a page."""
|
||||||
|
|
||||||
|
self.max_per_page: int | None = max_per_page
|
||||||
|
"""The maximum allowed value for ``per_page``."""
|
||||||
|
|
||||||
|
self.items = None
|
||||||
|
self.total = None
|
||||||
|
self.error_out = error_out
|
||||||
|
self.has_count = count
|
||||||
|
|
||||||
|
async def __aiter__(self):
|
||||||
|
self.items = await self._query_items()
|
||||||
|
if self.items is None:
|
||||||
|
raise RuntimeError('query returned None')
|
||||||
|
if not self.items and self.page != 1 and self.error_out:
|
||||||
|
raise self.error_out
|
||||||
|
if self.has_count:
|
||||||
|
self.total = await self._query_count()
|
||||||
|
for i in self.items:
|
||||||
|
yield i
|
||||||
|
|
||||||
|
|
||||||
|
def async_query(db: SQLAlchemy, multi: False):
|
||||||
|
"""
|
||||||
|
Wraps a query returning function into an executor coroutine.
|
||||||
|
|
||||||
|
The query function remains available as the .q or .query attribute.
|
||||||
|
"""
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def executor(*args, **kwargs):
|
||||||
|
async with db as session:
|
||||||
|
result = await session.execute(func(*args, **kwargs))
|
||||||
|
return result.scalars() if multi else result.scalar()
|
||||||
|
executor.query = executor.q = func
|
||||||
|
return executor
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# Optional dependency: do not import into __init__.py
|
||||||
|
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Utilities for SQLAlchemy
|
Utilities for SQLAlchemy; ORM
|
||||||
|
|
||||||
|
NEW 0.6.0
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@ -14,54 +16,38 @@ This software is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABCMeta, abstractmethod
|
|
||||||
from functools import wraps
|
from binascii import Incomplete
|
||||||
from typing import Any, Callable, Iterable, Never, TypeVar
|
from typing import Any, Callable
|
||||||
import warnings
|
import warnings
|
||||||
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Date, Dialect, ForeignKey, LargeBinary, Column, MetaData, SmallInteger, String, create_engine, select, text
|
from sqlalchemy import BigInteger, Boolean, CheckConstraint, Column, Date, ForeignKey, LargeBinary, MetaData, SmallInteger, String, text
|
||||||
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, Session, declarative_base as _declarative_base, relationship
|
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, Relationship, declarative_base as _declarative_base, relationship
|
||||||
from sqlalchemy.types import TypeEngine
|
from sqlalchemy.types import TypeEngine
|
||||||
|
from suou.classtools import Wanted
|
||||||
|
from suou.codecs import StringCase
|
||||||
|
from suou.iding import Siq, SiqCache, SiqGen, SiqType
|
||||||
|
from suou.itertools import kwargs_prefix
|
||||||
|
from suou.snowflake import SnowflakeGen
|
||||||
|
from suou.sqlalchemy import IdType
|
||||||
|
|
||||||
from .snowflake import SnowflakeGen
|
|
||||||
from .itertools import kwargs_prefix, makelist
|
|
||||||
from .signing import HasSigner, UserSigner
|
|
||||||
from .codecs import StringCase
|
|
||||||
from .functools import deprecated, not_implemented
|
|
||||||
from .iding import Siq, SiqGen, SiqType, SiqCache
|
|
||||||
from .classtools import Incomplete, Wanted
|
|
||||||
|
|
||||||
_T = TypeVar('_T')
|
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
||||||
|
|
||||||
# SIQs are 14 bytes long. Storage is padded for alignment
|
|
||||||
# Not to be confused with SiqType.
|
|
||||||
IdType: TypeEngine = LargeBinary(16)
|
|
||||||
|
|
||||||
@not_implemented
|
|
||||||
def sql_escape(s: str, /, dialect: Dialect) -> str:
|
|
||||||
"""
|
"""
|
||||||
Escape a value for SQL embedding, using SQLAlchemy's literal processors.
|
Return a table's column given its name.
|
||||||
Requires a dialect argument.
|
|
||||||
|
|
||||||
XXX this function is not mature yet, do not use
|
XXX does it belong outside any scopes?
|
||||||
"""
|
"""
|
||||||
if isinstance(s, str):
|
if isinstance(col, Incomplete):
|
||||||
return String().literal_processor(dialect=dialect)(s)
|
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
|
||||||
raise TypeError('invalid data type')
|
elif isinstance(col, Column):
|
||||||
|
return col
|
||||||
|
elif isinstance(col, str):
|
||||||
|
return getattr(cls, col)
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
|
||||||
def create_session(url: str) -> Session:
|
|
||||||
"""
|
|
||||||
Create a session on the fly, given a database URL. Useful for
|
|
||||||
contextless environments, such as Python REPL.
|
|
||||||
|
|
||||||
Heads up: a function with the same name exists in core sqlalchemy, but behaves
|
|
||||||
completely differently!!
|
|
||||||
"""
|
|
||||||
engine = create_engine(url)
|
|
||||||
return Session(bind = engine)
|
|
||||||
|
|
||||||
def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs):
|
def id_column(typ: SiqType, *, primary_key: bool = True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Marks a column which contains a SIQ.
|
Marks a column which contains a SIQ.
|
||||||
|
|
@ -120,7 +106,6 @@ def match_column(length: int, regex: str, /, case: StringCase = StringCase.AS_IS
|
||||||
return Incomplete(Column, String(length), Wanted(lambda x, n: match_constraint(n, regex, #dialect=x.metadata.engine.dialect.name,
|
return Incomplete(Column, String(length), Wanted(lambda x, n: match_constraint(n, regex, #dialect=x.metadata.engine.dialect.name,
|
||||||
constraint_name=constraint_name or f'{x.__tablename__}_{n}_valid')), *args, **kwargs)
|
constraint_name=constraint_name or f'{x.__tablename__}_{n}_valid')), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def bool_column(value: bool = False, nullable: bool = False, **kwargs) -> Column[bool]:
|
def bool_column(value: bool = False, nullable: bool = False, **kwargs) -> Column[bool]:
|
||||||
"""
|
"""
|
||||||
Column for a single boolean value.
|
Column for a single boolean value.
|
||||||
|
|
@ -149,35 +134,9 @@ def declarative_base(domain_name: str, master_secret: bytes, metadata: dict | No
|
||||||
)
|
)
|
||||||
Base = _declarative_base(metadata=MetaData(**metadata), **kwargs)
|
Base = _declarative_base(metadata=MetaData(**metadata), **kwargs)
|
||||||
return Base
|
return Base
|
||||||
entity_base = deprecated('use declarative_base() instead')(declarative_base)
|
entity_base = warnings.deprecated('use declarative_base() instead')(declarative_base)
|
||||||
|
|
||||||
|
|
||||||
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
|
|
||||||
"""
|
|
||||||
Generate a user signing function.
|
|
||||||
|
|
||||||
Requires a master secret (taken from Base.metadata), a user id (visible in the token)
|
|
||||||
and a user secret.
|
|
||||||
"""
|
|
||||||
id_val: Column | Wanted[Column]
|
|
||||||
if isinstance(id_attr, Column):
|
|
||||||
id_val = id_attr
|
|
||||||
elif isinstance(id_attr, str):
|
|
||||||
id_val = Wanted(id_attr)
|
|
||||||
if isinstance(secret_attr, Column):
|
|
||||||
secret_val = secret_attr
|
|
||||||
elif isinstance(secret_attr, str):
|
|
||||||
secret_val = Wanted(secret_attr)
|
|
||||||
def token_signer_factory(owner: DeclarativeBase, name: str):
|
|
||||||
def my_signer(self):
|
|
||||||
return UserSigner(
|
|
||||||
owner.metadata.info['secret_key'],
|
|
||||||
id_val.__get__(self, owner), secret_val.__get__(self, owner) # pyright: ignore[reportAttributeAccessIssue]
|
|
||||||
)
|
|
||||||
my_signer.__name__ = name
|
|
||||||
return my_signer
|
|
||||||
return Incomplete(Wanted(token_signer_factory))
|
|
||||||
|
|
||||||
|
|
||||||
def author_pair(fk_name: str, *, id_type: type | TypeEngine = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]:
|
def author_pair(fk_name: str, *, id_type: type | TypeEngine = IdType, sig_type: type | None = None, nullable: bool = False, sig_length: int | None = 2048, **ka) -> tuple[Column, Column]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -208,6 +167,7 @@ def age_pair(*, nullable: bool = False, **ka) -> tuple[Column, Column]:
|
||||||
return (date_col, acc_col)
|
return (date_col, acc_col)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]:
|
def parent_children(keyword: str, /, *, lazy='selectin', **kwargs) -> tuple[Incomplete[Relationship[Any]], Incomplete[Relationship[Any]]]:
|
||||||
"""
|
"""
|
||||||
Self-referential one-to-many relationship pair.
|
Self-referential one-to-many relationship pair.
|
||||||
|
|
@ -272,93 +232,3 @@ def bound_fk(target: str | Column | InstrumentedAttribute, typ: _T = None, **kwa
|
||||||
|
|
||||||
return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs)
|
return Column(typ, ForeignKey(target_name, ondelete='CASCADE'), nullable=False, **kwargs)
|
||||||
|
|
||||||
def want_column(cls: type[DeclarativeBase], col: Column[_T] | str) -> Column[_T]:
|
|
||||||
"""
|
|
||||||
Return a table's column given its name.
|
|
||||||
|
|
||||||
XXX does it belong outside any scopes?
|
|
||||||
"""
|
|
||||||
if isinstance(col, Incomplete):
|
|
||||||
raise TypeError('attempt to pass an uninstanced column. Pass the column name as a string instead.')
|
|
||||||
elif isinstance(col, Column):
|
|
||||||
return col
|
|
||||||
elif isinstance(col, str):
|
|
||||||
return getattr(cls, col)
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
## Utilities for use in web apps below
|
|
||||||
|
|
||||||
class AuthSrc(metaclass=ABCMeta):
|
|
||||||
'''
|
|
||||||
AuthSrc object required for require_auth_base().
|
|
||||||
|
|
||||||
This is an abstract class and is NOT usable directly.
|
|
||||||
|
|
||||||
This is not part of the public API
|
|
||||||
'''
|
|
||||||
def required_exc(self) -> Never:
|
|
||||||
raise ValueError('required field missing')
|
|
||||||
def invalid_exc(self, msg: str = 'validation failed') -> Never:
|
|
||||||
raise ValueError(msg)
|
|
||||||
@abstractmethod
|
|
||||||
def get_session(self) -> Session:
|
|
||||||
pass
|
|
||||||
def get_user(self, getter: Callable):
|
|
||||||
return getter(self.get_token())
|
|
||||||
@abstractmethod
|
|
||||||
def get_token(self):
|
|
||||||
pass
|
|
||||||
@abstractmethod
|
|
||||||
def get_signature(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str | Column[_T] = 'id', dest: str = 'user',
|
|
||||||
required: bool = False, signed: bool = False, sig_dest: str = 'signature', validators: Callable | Iterable[Callable] | None = None):
|
|
||||||
'''
|
|
||||||
Inject the current user into a view, given the Authorization: Bearer header.
|
|
||||||
|
|
||||||
For portability reasons, this is a partial, two-component function, requiring a AuthSrc() object.
|
|
||||||
'''
|
|
||||||
col = want_column(cls, column)
|
|
||||||
validators = makelist(validators)
|
|
||||||
|
|
||||||
def get_user(token) -> DeclarativeBase:
|
|
||||||
if token is None:
|
|
||||||
return None
|
|
||||||
tok_parts = UserSigner.split_token(token)
|
|
||||||
user: HasSigner = src.get_session().execute(select(cls).where(col == tok_parts[0])).scalar()
|
|
||||||
try:
|
|
||||||
signer: UserSigner = user.signer()
|
|
||||||
signer.unsign(token)
|
|
||||||
return user
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _default_invalid(msg: str = 'Validation failed'):
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
invalid_exc = src.invalid_exc or _default_invalid
|
|
||||||
required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
|
|
||||||
|
|
||||||
def decorator(func: Callable):
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*a, **ka):
|
|
||||||
ka[dest] = get_user(src.get_token())
|
|
||||||
if not ka[dest] and required:
|
|
||||||
required_exc()
|
|
||||||
if signed:
|
|
||||||
ka[sig_dest] = src.get_signature()
|
|
||||||
for valid in validators:
|
|
||||||
if not valid(ka[dest]):
|
|
||||||
invalid_exc(getattr(valid, 'message', 'validation failed').format(user=ka[dest]))
|
|
||||||
return func(*a, **ka)
|
|
||||||
return wrapper
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
# Optional dependency: do not import into __init__.py
|
|
||||||
__all__ = (
|
|
||||||
'IdType', 'id_column', 'entity_base', 'declarative_base', 'token_signer', 'match_column', 'match_constraint',
|
|
||||||
'author_pair', 'age_pair', 'require_auth_base', 'bound_fk', 'unbound_fk', 'want_column'
|
|
||||||
)
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Helpers for asynchronous use of SQLAlchemy.
|
Helpers for asynchronous use of SQLAlchemy.
|
||||||
|
|
||||||
NEW 0.5.0
|
NEW 0.5.0; MOVED to sqlalchemy.asyncio in 0.6.0
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@ -17,159 +17,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
|
from .functools import deprecated
|
||||||
from sqlalchemy import Engine, Select, func, select
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, lazyload
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
||||||
from flask_sqlalchemy.pagination import Pagination
|
|
||||||
|
|
||||||
from suou.exceptions import NotFoundError
|
|
||||||
|
|
||||||
class SQLAlchemy:
|
|
||||||
"""
|
|
||||||
Drop-in (?) replacement for flask_sqlalchemy.SQLAlchemy()
|
|
||||||
eligible for async environments
|
|
||||||
|
|
||||||
NEW 0.5.0
|
|
||||||
"""
|
|
||||||
base: DeclarativeBase
|
|
||||||
engine: AsyncEngine
|
|
||||||
_sessions: list[AsyncSession]
|
|
||||||
NotFound = NotFoundError
|
|
||||||
|
|
||||||
def __init__(self, model_class: DeclarativeBase):
|
|
||||||
self.base = model_class
|
|
||||||
self.engine = None
|
|
||||||
self._sessions = []
|
|
||||||
def bind(self, url: str):
|
|
||||||
self.engine = create_async_engine(url)
|
|
||||||
def _ensure_engine(self):
|
|
||||||
if self.engine is None:
|
|
||||||
raise RuntimeError('database is not connected')
|
|
||||||
async def begin(self, *, expire_on_commit = False, **kw) -> AsyncSession:
|
|
||||||
self._ensure_engine()
|
|
||||||
## XXX is it accurate?
|
|
||||||
s = AsyncSession(self.engine, expire_on_commit=expire_on_commit, **kw)
|
|
||||||
self._sessions.append(s)
|
|
||||||
return s
|
|
||||||
async def __aenter__(self) -> AsyncSession:
|
|
||||||
return await self.begin()
|
|
||||||
async def __aexit__(self, e1, e2, e3):
|
|
||||||
## XXX is it accurate?
|
|
||||||
s = self._sessions.pop()
|
|
||||||
if e1:
|
|
||||||
await s.rollback()
|
|
||||||
else:
|
|
||||||
await s.commit()
|
|
||||||
await s.close()
|
|
||||||
async def paginate(self, select: Select, *,
|
|
||||||
page: int | None = None, per_page: int | None = None,
|
|
||||||
max_per_page: int | None = None, error_out: bool = True,
|
|
||||||
count: bool = True) -> AsyncSelectPagination:
|
|
||||||
"""
|
|
||||||
Return a pagination. Analogous to flask_sqlalchemy.SQLAlchemy.paginate().
|
|
||||||
"""
|
|
||||||
async with self as session:
|
|
||||||
return AsyncSelectPagination(
|
|
||||||
select = select,
|
|
||||||
session = session,
|
|
||||||
page = page,
|
|
||||||
per_page=per_page, max_per_page=max_per_page,
|
|
||||||
error_out=self.NotFound if error_out else None, count=count
|
|
||||||
)
|
|
||||||
async def create_all(self, *, checkfirst = True):
|
|
||||||
"""
|
|
||||||
Initialize database
|
|
||||||
"""
|
|
||||||
self._ensure_engine()
|
|
||||||
self.base.metadata.create_all(
|
|
||||||
self.engine, checkfirst=checkfirst
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncSelectPagination(Pagination):
|
from .sqlalchemy.asyncio import SQLAlchemy, AsyncSelectPagination, async_query
|
||||||
"""
|
|
||||||
flask_sqlalchemy.SelectPagination but asynchronous.
|
|
||||||
|
|
||||||
Pagination is not part of the public API, therefore expect that it may break
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def _query_items(self) -> list:
|
|
||||||
select_q: Select = self._query_args["select"]
|
|
||||||
select = select_q.limit(self.per_page).offset(self._query_offset)
|
|
||||||
session: AsyncSession = self._query_args["session"]
|
|
||||||
out = (await session.execute(select)).scalars()
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def _query_count(self) -> int:
|
|
||||||
select_q: Select = self._query_args["select"]
|
|
||||||
sub = select_q.options(lazyload("*")).order_by(None).subquery()
|
|
||||||
session: AsyncSession = self._query_args["session"]
|
|
||||||
out = (await session.execute(select(func.count()).select_from(sub))).scalar()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
page: int | None = None,
|
|
||||||
per_page: int | None = None,
|
|
||||||
max_per_page: int | None = 100,
|
|
||||||
error_out: Exception | None = NotFoundError,
|
|
||||||
count: bool = True,
|
|
||||||
**kwargs):
|
|
||||||
## XXX flask-sqlalchemy says Pagination() is not public API.
|
|
||||||
## Things may break; beware.
|
|
||||||
self._query_args = kwargs
|
|
||||||
page, per_page = self._prepare_page_args(
|
|
||||||
page=page,
|
|
||||||
per_page=per_page,
|
|
||||||
max_per_page=max_per_page,
|
|
||||||
error_out=error_out,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.page: int = page
|
|
||||||
"""The current page."""
|
|
||||||
|
|
||||||
self.per_page: int = per_page
|
|
||||||
"""The maximum number of items on a page."""
|
|
||||||
|
|
||||||
self.max_per_page: int | None = max_per_page
|
|
||||||
"""The maximum allowed value for ``per_page``."""
|
|
||||||
|
|
||||||
self.items = None
|
|
||||||
self.total = None
|
|
||||||
self.error_out = error_out
|
|
||||||
self.has_count = count
|
|
||||||
|
|
||||||
async def __aiter__(self):
|
|
||||||
self.items = await self._query_items()
|
|
||||||
if self.items is None:
|
|
||||||
raise RuntimeError('query returned None')
|
|
||||||
if not self.items and self.page != 1 and self.error_out:
|
|
||||||
raise self.error_out
|
|
||||||
if self.has_count:
|
|
||||||
self.total = await self._query_count()
|
|
||||||
for i in self.items:
|
|
||||||
yield i
|
|
||||||
|
|
||||||
|
|
||||||
def async_query(db: SQLAlchemy, multi: False):
|
|
||||||
"""
|
|
||||||
Wraps a query returning function into an executor coroutine.
|
|
||||||
|
|
||||||
The query function remains available as the .q or .query attribute.
|
|
||||||
"""
|
|
||||||
def decorator(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def executor(*args, **kwargs):
|
|
||||||
async with db as session:
|
|
||||||
result = await session.execute(func(*args, **kwargs))
|
|
||||||
return result.scalars() if multi else result.scalar()
|
|
||||||
executor.query = executor.q = func
|
|
||||||
return executor
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
SQLAlchemy = deprecated('import from suou.sqlalchemy.asyncio instead')(SQLAlchemy)
|
||||||
|
AsyncSelectPagination = deprecated('import from suou.sqlalchemy.asyncio instead')(AsyncSelectPagination)
|
||||||
|
async_query = deprecated('import from suou.sqlalchemy.asyncio instead')(async_query)
|
||||||
|
|
||||||
# Optional dependency: do not import into __init__.py
|
# Optional dependency: do not import into __init__.py
|
||||||
__all__ = ('SQLAlchemy', 'async_query')
|
__all__ = ('SQLAlchemy', 'AsyncSelectPagination', 'async_query')
|
||||||
Loading…
Add table
Add a link
Reference in a new issue