from __future__ import annotations import asyncio from collections import deque from time import monotonic from core.http.errors import ApiProblemError from core.logging import get_logger from services.base.redis import get_or_init_redis_client _BUCKETS: dict[str, deque[float]] = {} _LAST_SEEN: dict[str, float] = {} _LOCK = asyncio.Lock() _CLEANUP_INTERVAL = 200 _CALL_COUNT = 0 logger = get_logger("v1.auth.rate_limit") _REDIS_LIMIT_SCRIPT = """ local current = redis.call("INCR", KEYS[1]) if current == 1 then redis.call("EXPIRE", KEYS[1], ARGV[1]) end return current """ async def enforce_rate_limit( *, scope: str, identifier: str, limit: int, window_seconds: int, ) -> None: key = f"auth:rate_limit:{scope}:{identifier.lower()}" try: await _enforce_rate_limit_with_redis( key=key, limit=limit, window_seconds=window_seconds, ) return except ApiProblemError: raise except Exception as exc: # noqa: BLE001 logger.warning( "Rate limit fallback to in-memory", scope=scope, error_type=type(exc).__name__, ) await _enforce_rate_limit_in_memory( key=key, limit=limit, window_seconds=window_seconds, ) async def _enforce_rate_limit_with_redis( *, key: str, limit: int, window_seconds: int, ) -> None: client = await get_or_init_redis_client() current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds) # type: ignore[await] if int(current) > limit: raise ApiProblemError( status_code=429, code="AUTH_TOO_MANY_REQUESTS", detail="Too many requests", ) async def _enforce_rate_limit_in_memory( *, key: str, limit: int, window_seconds: int, ) -> None: global _CALL_COUNT now = monotonic() async with _LOCK: bucket = _BUCKETS.setdefault(key, deque()) _LAST_SEEN[key] = now cutoff = now - float(window_seconds) while bucket and bucket[0] <= cutoff: bucket.popleft() if len(bucket) >= limit: raise ApiProblemError( status_code=429, code="AUTH_TOO_MANY_REQUESTS", detail="Too many requests", ) bucket.append(now) _CALL_COUNT += 1 if _CALL_COUNT % _CLEANUP_INTERVAL == 0: _cleanup_stale_buckets(now) def _cleanup_stale_buckets(now: float) -> None: stale_keys = [ key for key, last_seen in _LAST_SEEN.items() if key not in _BUCKETS or (not _BUCKETS[key] and now - last_seen > 3600) ] for key in stale_keys: _BUCKETS.pop(key, None) _LAST_SEEN.pop(key, None) def reset_rate_limit_state() -> None: _BUCKETS.clear() _LAST_SEEN.clear() global _CALL_COUNT _CALL_COUNT = 0