# Source code for flask_limiter.wrappers
from __future__ import annotations
import dataclasses
import typing
import weakref
from collections.abc import Iterator
from flask import request
from flask.wrappers import Response
from limits import RateLimitItem, parse_many
from limits.strategies import RateLimiter
from limits.util import WindowStats
from .typing import Callable
if typing.TYPE_CHECKING:
from .extension import Limiter
[docs]
class RequestLimit:
    """
    Provides details of a rate limit within the context of a request
    """

    #: The instance of the rate limit
    limit: RateLimitItem
    #: The full key for the request against which the rate limit is tested
    key: str
    #: Whether the limit was breached within the context of this request
    breached: bool
    #: Whether the limit is a shared limit
    shared: bool

    def __init__(
        self,
        extension: Limiter,
        limit: RateLimitItem,
        request_args: list[str],
        breached: bool,
        shared: bool,
    ) -> None:
        """
        :param extension: the owning extension. Only a weak proxy to it is
         stored, so this object does not keep the extension alive.
        :param limit: the rate limit item being described
        :param request_args: identifiers passed to ``limit.key_for`` to build
         :attr:`key`, and to the limiter's ``get_window_stats``
        :param breached: whether the limit was breached for this request
        :param shared: whether the limit is a shared limit
        """
        # Weak proxy: avoid a strong reference back to the extension.
        self.extension: weakref.ProxyType[Limiter] = weakref.proxy(extension)
        self.limit = limit
        self.request_args = request_args
        self.key = limit.key_for(*request_args)
        self.breached = breached
        self.shared = shared
        # Window statistics are fetched lazily on first access to ``window``.
        self._window: WindowStats | None = None

    @property
    def limiter(self) -> RateLimiter:
        """The rate limiter strategy held by the owning extension."""
        return typing.cast(RateLimiter, self.extension.limiter)

    @property
    def window(self) -> WindowStats:
        """
        Window statistics for this limit.

        They are fetched from the limiter on first access and cached on the
        instance. Later reads return the cached value and do not reflect
        hits made afterwards.
        """
        # Compare against None explicitly: "not set yet" is the only case
        # that should trigger a fetch.
        if self._window is None:
            self._window = self.limiter.get_window_stats(
                self.limit, *self.request_args
            )
        return self._window

    @property
    def reset_at(self) -> int:
        """Timestamp at which the rate limit will be reset"""
        # Adding 1 before truncating presumably keeps the reported time from
        # being earlier than the actual reset; verify against the limiter.
        return int(self.window[0] + 1)

    @property
    def remaining(self) -> int:
        """Quantity remaining for this rate limit"""
        return self.window[1]
@dataclasses.dataclass(eq=True, unsafe_hash=True)
class Limit:
    """
    simple wrapper to encapsulate limits and their context
    """

    #: The rate limit item being wrapped
    limit: RateLimitItem
    #: Zero-argument callable returning a key string
    key_func: Callable[[], str]
    #: Static scope string, a callable mapping the endpoint name to a scope,
    #: or None (resolved by :attr:`scope`)
    _scope: str | Callable[[str], str] | None
    #: Whether the request method is appended to the scope in :meth:`scope_for`
    per_method: bool = False
    #: Methods the limit applies to, normalised to lower-case names in a tuple.
    #: None means the limit is not restricted by method.
    methods: tuple[str, ...] | None = None
    # The next four fields are not used by this class's own methods; they are
    # stored for consumers of the limit.
    error_message: str | None = None
    #: Callable whose result is returned by :attr:`is_exempt`
    exempt_when: Callable[[], bool] | None = None
    override_defaults: bool | None = False
    deduct_when: Callable[[Response], bool] | None = None
    on_breach: Callable[[RequestLimit], Response | None] | None = None
    #: Fixed integer cost, or a zero-argument callable returning the cost
    _cost: Callable[[], int] | int = 1
    #: When True, :meth:`scope_for` omits the endpoint from the scope
    shared: bool = False

    def __post_init__(self) -> None:
        # Normalise to a tuple of lower-case names. ``method_exempt``
        # lower-cases the request method, so matching is case-insensitive.
        # ``unsafe_hash=True`` hashes every field, so ``methods`` must always
        # be a tuple. Converting any non-None value (including an empty list,
        # which a truthiness check would skip) keeps the instance hashable.
        if self.methods is not None:
            self.methods = tuple(method.lower() for method in self.methods)

    @property
    def is_exempt(self) -> bool:
        """Check if the limit is exempt.

        Returns the result of ``exempt_when()`` when it is set, otherwise False.
        """
        if self.exempt_when:
            return self.exempt_when()
        return False

    @property
    def scope(self) -> str | None:
        """Resolve the scope of this limit.

        If ``_scope`` is callable, it is called with the current request's
        endpoint (or ``""`` when the request has no endpoint). This reads
        ``flask.request``, so a request context is needed in that case.
        Otherwise ``_scope`` is returned as is.
        """
        return (
            self._scope(request.endpoint or "")
            if callable(self._scope)
            else self._scope
        )

    @property
    def cost(self) -> int:
        """The cost of this limit: the fixed int, or the callable's result."""
        if isinstance(self._cost, int):
            return self._cost
        return self._cost()

    @property
    def method_exempt(self) -> bool:
        """Check if the limit is not applicable for this method.

        Always False when ``methods`` is None. Otherwise True when the current
        request's method (lower-cased) is not listed.
        """
        return self.methods is not None and request.method.lower() not in self.methods

    def scope_for(self, endpoint: str, method: str | None) -> str:
        """
        Derive final bucket (scope) for this limit given the endpoint
        and request method. If the limit is shared between multiple
        routes, the scope does not include the endpoint.

        :param endpoint: endpoint name used as (or prefixed to) the scope
        :param method: request method; required when ``per_method`` is set,
         and appended upper-cased
        :return: ``"<endpoint>"``, ``"<endpoint>:<scope>"`` or ``"<scope>"``
         (for shared limits), with ``":<METHOD>"`` appended when
         ``per_method`` is set
        """
        limit_scope = self.scope
        if limit_scope:
            if self.shared:
                scope = limit_scope
            else:
                scope = f"{endpoint}:{limit_scope}"
        else:
            scope = endpoint
        if self.per_method:
            # Callers must supply the method for per-method limits.
            assert method
            scope += f":{method.upper()}"
        return scope
@dataclasses.dataclass(eq=True, unsafe_hash=True)
class LimitGroup:
    """
    represents a group of related limits either from a string or a callable
    that returns one
    """

    #: Rate limit string, or a zero-argument callable returning one.
    #: It is parsed with ``parse_many`` each time the group is iterated.
    limit_provider: Callable[[], str] | str
    #: Passed through as ``key_func`` of every generated :class:`Limit`
    key_function: Callable[[], str]
    scope: str | Callable[[str], str] | None = None
    methods: tuple[str, ...] | None = None
    error_message: str | None = None
    exempt_when: Callable[[], bool] | None = None
    override_defaults: bool | None = False
    deduct_when: Callable[[Response], bool] | None = None
    on_breach: Callable[[RequestLimit], Response | None] | None = None
    per_method: bool = False
    #: Cost for each generated limit. None means the default cost of 1;
    #: any other value, including 0, is passed through unchanged.
    cost: Callable[[], int] | int | None = None
    shared: bool = False

    def __iter__(self) -> Iterator[Limit]:
        """Yield one :class:`Limit` per rate limit item in the provider.

        The provider is resolved each time iteration starts, so a callable
        provider may return different limits over time. A falsy provider
        result (e.g. an empty string) yields no limits.
        """
        limit_str = (
            self.limit_provider()
            if callable(self.limit_provider)
            else self.limit_provider
        )
        limit_items = parse_many(limit_str) if limit_str else []
        # Substitute the default only when no cost was given. ``self.cost or 1``
        # would silently turn an explicit cost of 0 into 1.
        cost = 1 if self.cost is None else self.cost
        for limit in limit_items:
            # Keyword arguments keep the mapping explicit and immune to a
            # reordering of Limit's fields.
            yield Limit(
                limit=limit,
                key_func=self.key_function,
                _scope=self.scope,
                per_method=self.per_method,
                methods=self.methods,
                error_message=self.error_message,
                exempt_when=self.exempt_when,
                override_defaults=self.override_defaults,
                deduct_when=self.deduct_when,
                on_breach=self.on_breach,
                _cost=cost,
                shared=self.shared,
            )