improve decorator typing

This commit is contained in:
Yusur 2025-10-11 10:22:49 +02:00
parent 21021875c8
commit fca91bdc54
5 changed files with 27 additions and 19 deletions

View file

@ -19,7 +19,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from __future__ import annotations from __future__ import annotations
from functools import wraps from functools import wraps
from typing import Callable from typing import Callable, TypeVar
_T = TypeVar('_T')
_U = TypeVar('_U')
BRICKS = '@abcdefghijklmnopqrstuvwxyz+?-\'/' BRICKS = '@abcdefghijklmnopqrstuvwxyz+?-\'/'
@ -122,7 +125,7 @@ def dei_args(**renames):
Dear conservatives, this does not influence the ability to call the wrapped function Dear conservatives, this does not influence the ability to call the wrapped function
with the original parameter names. with the original parameter names.
""" """
def decorator(func: Callable): def decorator(func: Callable[_T, _U]) -> Callable[_T, _U]:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
for alias_name, actual_name in renames.items(): for alias_name, actual_name in renames.items():

View file

@ -19,7 +19,7 @@ import math
from threading import RLock from threading import RLock
import time import time
from types import CoroutineType, NoneType from types import CoroutineType, NoneType
from typing import Callable, Iterable, Mapping, TypeVar from typing import Any, Callable, Iterable, Mapping, Never, TypeVar
import warnings import warnings
from functools import update_wrapper, wraps, lru_cache from functools import update_wrapper, wraps, lru_cache
@ -70,7 +70,7 @@ def not_implemented(msg: Callable | str | None = None):
""" """
A more elegant way to say a method is not implemented, but may get in the future. A more elegant way to say a method is not implemented, but may get in the future.
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable[_T, Any]) -> Callable[_T, Never]:
da_msg = msg if isinstance(msg, str) else 'method {name}() is not implemented'.format(name=func.__name__) da_msg = msg if isinstance(msg, str) else 'method {name}() is not implemented'.format(name=func.__name__)
@wraps(func) @wraps(func)
def wrapper(*a, **k): def wrapper(*a, **k):
@ -288,7 +288,7 @@ def timed_cache(ttl: int, maxsize: int = 128, typed: bool = False, *, async_: bo
NEW 0.5.0 NEW 0.5.0
""" """
def decorator(func): def decorator(func: Callable[_T, _U]) -> Callable[_T, _U]:
start_time = None start_time = None
if async_: if async_:
@ -318,7 +318,7 @@ def timed_cache(ttl: int, maxsize: int = 128, typed: bool = False, *, async_: bo
return wrapper return wrapper
return decorator return decorator
def none_pass(func: Callable, *args, **kwargs) -> Callable: def none_pass(func: Callable[_T, _U], *args, **kwargs) -> Callable[_T, _U]:
""" """
Wrap callable so that gets called only on not None values. Wrap callable so that gets called only on not None values.

View file

@ -35,7 +35,7 @@ def lucky(validators: Iterable[Callable[[_U], bool]] = ()):
NEW 0.7.0 NEW 0.7.0
""" """
def decorator(func: Callable[..., _U]): def decorator(func: Callable[_T, _U]) -> Callable[_T, _U]:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs) -> _U: def wrapper(*args, **kwargs) -> _U:
try: try:
@ -102,7 +102,7 @@ def rng_overload(prev_func: RngCallable[..., _U] | int | None, /, *, weight: int
if isinstance(prev_func, int) and weight == 1: if isinstance(prev_func, int) and weight == 1:
weight, prev_func = prev_func, None weight, prev_func = prev_func, None
def decorator(func: Callable[_T, _U]): def decorator(func: Callable[_T, _U]) -> RngCallable[_T, _U]:
nonlocal prev_func nonlocal prev_func
if prev_func is None: if prev_func is None:
prev_func = RngCallable(func, weight=weight) prev_func = RngCallable(func, weight=weight)

View file

@ -1,5 +1,5 @@
""" """
Utilities for SQLAlchemy Utilities for SQLAlchemy.
--- ---
@ -33,12 +33,16 @@ from ..iding import Siq, SiqGen, SiqType, SiqCache
from ..classtools import Incomplete, Wanted from ..classtools import Incomplete, Wanted
_T = TypeVar('_T') _T = TypeVar('_T')
_U = TypeVar('_U')
# SIQs are 14 bytes long. Storage is padded for alignment
# Not to be confused with SiqType.
IdType: TypeEngine = LargeBinary(16) IdType: TypeEngine = LargeBinary(16)
"""
Database type for SIQ.
SIQs are 14 bytes long. Storage is padded for alignment
Not to be confused with SiqType.
"""
def create_session(url: str) -> Session: def create_session(url: str) -> Session:
""" """
@ -52,7 +56,6 @@ def create_session(url: str) -> Session:
return Session(bind = engine) return Session(bind = engine)
def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]: def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete[UserSigner]:
""" """
Generate a user signing function. Generate a user signing function.
@ -80,9 +83,6 @@ def token_signer(id_attr: Column | str, secret_attr: Column | str) -> Incomplete
return Incomplete(Wanted(token_signer_factory)) return Incomplete(Wanted(token_signer_factory))
## (in)Utilities for use in web apps below ## (in)Utilities for use in web apps below
@deprecated('not part of the public API and not even working') @deprecated('not part of the public API and not even working')
@ -93,6 +93,8 @@ class AuthSrc(metaclass=ABCMeta):
This is an abstract class and is NOT usable directly. This is an abstract class and is NOT usable directly.
This is not part of the public API This is not part of the public API
DEPRECATED
''' '''
def required_exc(self) -> Never: def required_exc(self) -> Never:
raise ValueError('required field missing') raise ValueError('required field missing')
@ -140,7 +142,7 @@ def require_auth_base(cls: type[DeclarativeBase], *, src: AuthSrc, column: str |
invalid_exc = src.invalid_exc or _default_invalid invalid_exc = src.invalid_exc or _default_invalid
required_exc = src.required_exc or (lambda: _default_invalid('Login required')) required_exc = src.required_exc or (lambda: _default_invalid('Login required'))
def decorator(func: Callable): def decorator(func: Callable[_T, _U]) -> Callable[_T, _U]:
@wraps(func) @wraps(func)
def wrapper(*a, **ka): def wrapper(*a, **ka):
ka[dest] = get_user(src.get_token()) ka[dest] = get_user(src.get_token())

View file

@ -21,14 +21,17 @@ from __future__ import annotations
from functools import wraps from functools import wraps
from contextvars import ContextVar, Token from contextvars import ContextVar, Token
from typing import Callable, TypeVar
from sqlalchemy import Select, Table, func, select from sqlalchemy import Select, Table, func, select
from sqlalchemy.orm import DeclarativeBase, lazyload from sqlalchemy.orm import DeclarativeBase, lazyload
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from flask_sqlalchemy.pagination import Pagination from flask_sqlalchemy.pagination import Pagination
from suou.classtools import MISSING
from suou.exceptions import NotFoundError from suou.exceptions import NotFoundError
_T = TypeVar('_T')
_U = TypeVar('_U')
class SQLAlchemy: class SQLAlchemy:
""" """
Drop-in (in fact, almost) replacement for flask_sqlalchemy.SQLAlchemy() Drop-in (in fact, almost) replacement for flask_sqlalchemy.SQLAlchemy()
@ -186,7 +189,7 @@ def async_query(db: SQLAlchemy, multi: False):
The query function remains available as the .q or .query attribute. The query function remains available as the .q or .query attribute.
""" """
def decorator(func): def decorator(func: Callable[_T, _U]) -> Callable[_T, _U]:
@wraps(func) @wraps(func)
async def executor(*args, **kwargs): async def executor(*args, **kwargs):
async with db as session: async with db as session: