Files
eryao/backend/src/v1/auth/rate_limit.py
T

115 lines
2.8 KiB
Python

from __future__ import annotations
import asyncio
from collections import deque
from time import monotonic
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
# --- In-memory fallback state (used when the Redis path raises) -------------
# Rate-limit key -> monotonic timestamps of the requests accepted for that key.
_BUCKETS: dict[str, deque[float]] = {}
# Rate-limit key -> monotonic time at which the key was last checked.
_LAST_SEEN: dict[str, float] = {}
# Held while the in-memory limiter reads and mutates the state above.
_LOCK = asyncio.Lock()
# Stale-bucket cleanup runs whenever _CALL_COUNT reaches a multiple of this.
_CLEANUP_INTERVAL = 200
# Number of requests accepted by the in-memory limiter since start or reset.
_CALL_COUNT = 0
logger = get_logger("v1.auth.rate_limit")
# Lua script evaluated by the Redis-backed limiter. It increments the counter
# stored at KEYS[1]; when the increment yields 1 it sets the key to expire
# after ARGV[1] seconds. It returns the counter value after the increment.
_REDIS_LIMIT_SCRIPT = """
local current = redis.call("INCR", KEYS[1])
if current == 1 then
redis.call("EXPIRE", KEYS[1], ARGV[1])
end
return current
"""
async def enforce_rate_limit(
    *,
    scope: str,
    identifier: str,
    limit: int,
    window_seconds: int,
) -> None:
    """Apply the rate limit for ``identifier`` within ``scope``.

    The Redis-backed limiter is tried first. If it fails with anything other
    than ``ApiProblemError``, a warning is logged and the in-process limiter
    is used for this call instead.

    Args:
        scope: Name of the limited operation; becomes part of the key.
        identifier: Caller identity; lower-cased before it is put in the key.
        limit: Maximum number of requests allowed per window.
        window_seconds: Length of the window in seconds.

    Raises:
        ApiProblemError: Raised by either limiter (status 429) when the
            limit has been exceeded.
    """
    limiter_args = {
        "key": f"auth:rate_limit:{scope}:{identifier.lower()}",
        "limit": limit,
        "window_seconds": window_seconds,
    }

    try:
        await _enforce_rate_limit_with_redis(**limiter_args)
    except ApiProblemError:
        # A rate-limit rejection is a real answer, not a Redis failure.
        raise
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "Rate limit fallback to in-memory",
            scope=scope,
            error_type=type(exc).__name__,
        )
    else:
        # Redis handled the check and the request is within the limit.
        return

    await _enforce_rate_limit_in_memory(**limiter_args)
async def _enforce_rate_limit_with_redis(
    *,
    key: str,
    limit: int,
    window_seconds: int,
) -> None:
    """Count this request in Redis and reject it once ``limit`` is exceeded.

    Evaluates ``_REDIS_LIMIT_SCRIPT`` with ``key`` as its single key and
    ``window_seconds`` as its argument, then compares the returned counter
    with ``limit``.

    Args:
        key: Redis key holding the counter for this scope/identifier.
        limit: Highest counter value that is still accepted.
        window_seconds: Expiry, in seconds, passed to the script.

    Raises:
        ApiProblemError: With status 429 when the counter is above ``limit``.
        Exception: Whatever the Redis client raises is not caught here.
    """
    redis = await get_or_init_redis_client()
    raw_count = await redis.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds)  # type: ignore[await]
    hit_count = int(raw_count)

    if hit_count <= limit:
        return

    raise ApiProblemError(
        status_code=429,
        code="AUTH_TOO_MANY_REQUESTS",
        detail="Too many requests",
    )
async def _enforce_rate_limit_in_memory(
    *,
    key: str,
    limit: int,
    window_seconds: int,
) -> None:
    """Sliding-window limiter backed by process-local dictionaries.

    Keeps, per key, the monotonic timestamps of accepted requests. Timestamps
    at or before ``now - window_seconds`` are discarded before the remaining
    count is compared with ``limit``.

    Args:
        key: Rate-limit key for this scope/identifier.
        limit: Maximum number of accepted requests inside one window.
        window_seconds: Length of the sliding window in seconds.

    Raises:
        ApiProblemError: With status 429 when the window already holds
            ``limit`` or more requests.
    """
    global _CALL_COUNT

    timestamp = monotonic()
    async with _LOCK:
        hits = _BUCKETS.get(key)
        if hits is None:
            hits = _BUCKETS[key] = deque()
        # Recorded even when the request ends up rejected.
        _LAST_SEEN[key] = timestamp

        # Drop everything that has slid out of the window.
        window_start = timestamp - float(window_seconds)
        while hits and hits[0] <= window_start:
            hits.popleft()

        if len(hits) < limit:
            hits.append(timestamp)
            # Only accepted requests advance the cleanup counter.
            _CALL_COUNT += 1
            if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
                _cleanup_stale_buckets(timestamp)
            return

        raise ApiProblemError(
            status_code=429,
            code="AUTH_TOO_MANY_REQUESTS",
            detail="Too many requests",
        )
def _cleanup_stale_buckets(now: float, max_idle_seconds: float = 3600.0) -> None:
    """Evict in-memory rate-limit state for keys that have gone idle.

    A key is evicted when it has no bucket at all, or when it was last
    checked more than ``max_idle_seconds`` ago. Eviction deliberately does
    not require the bucket to be empty: expired timestamps are only pruned
    when the same key is checked again, so a key that is never seen again
    keeps a non-empty bucket and would otherwise never be removed.

    Args:
        now: Current ``time.monotonic()`` reading, on the same clock as the
            values stored in ``_LAST_SEEN``.
        max_idle_seconds: Idle time after which a key is forgotten. Should
            be at least as long as the longest rate-limit window in use;
            a key idle for longer than this loses its recorded requests.
    """
    stale_keys = [
        key
        for key, last_seen in _LAST_SEEN.items()
        if key not in _BUCKETS or now - last_seen > max_idle_seconds
    ]
    # Removal happens after the scan so the dict is not mutated mid-iteration.
    for key in stale_keys:
        _BUCKETS.pop(key, None)
        _LAST_SEEN.pop(key, None)
def reset_rate_limit_state() -> None:
    """Forget all in-memory rate-limit state.

    Empties the bucket and last-seen dictionaries in place and sets the
    accepted-call counter back to zero.
    """
    global _CALL_COUNT

    for state in (_BUCKETS, _LAST_SEEN):
        state.clear()
    _CALL_COUNT = 0