115 lines
2.8 KiB
Python
115 lines
2.8 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections import deque
|
|
from time import monotonic
|
|
|
|
from core.http.errors import ApiProblemError
|
|
|
|
from core.logging import get_logger
|
|
from services.base.redis import get_or_init_redis_client
|
|
|
|
logger = get_logger("v1.auth.rate_limit")

# --- In-memory fallback state (used when the Redis path raises) -------------
# Rate-limit key -> monotonic timestamps of the requests accepted for it.
_BUCKETS: dict[str, deque[float]] = dict()
# Rate-limit key -> monotonic time of the most recent call for that key
# (recorded before the limit check, so rejected calls update it too).
_LAST_SEEN: dict[str, float] = dict()
# Guards the two dictionaries above.
_LOCK = asyncio.Lock()

# A stale-key sweep runs once every _CLEANUP_INTERVAL accepted in-memory calls;
# _CALL_COUNT is the running tally of those calls.
_CLEANUP_INTERVAL = 200
_CALL_COUNT = 0
|
|
# Atomic fixed-window counter, executed with EVAL (one key, one argument).
#   KEYS[1]  counter key
#   ARGV[1]  window length in seconds, applied as the key's TTL
# Returns the counter value after it has been incremented for this request.
#
# The TTL is applied when the counter is created, and again whenever the key
# is found without one (TTL == -1). Without that second check, a key that ever
# exists with no expiry (e.g. written by another client, or after a PERSIST)
# would never reset, and the identifier would stay rate-limited forever.
_REDIS_LIMIT_SCRIPT = """
local current = redis.call("INCR", KEYS[1])
if current == 1 or redis.call("TTL", KEYS[1]) == -1 then
    redis.call("EXPIRE", KEYS[1], ARGV[1])
end
return current
"""
|
|
|
|
|
|
async def enforce_rate_limit(
    *,
    scope: str,
    identifier: str,
    limit: int,
    window_seconds: int,
) -> None:
    """Raise a 429 problem if ``identifier`` has exceeded ``limit`` in ``scope``.

    The counter key is ``auth:rate_limit:<scope>:<identifier lowercased>``, so
    identifiers that differ only in letter case share one limit.

    Redis is consulted first. If the Redis path fails with anything other than
    ``ApiProblemError``, a warning is logged and the process-local in-memory
    limiter makes the decision instead. Note that the two backends count
    differently: Redis uses a fixed window (TTL set when the counter starts),
    whereas the in-memory fallback uses a sliding window of timestamps.

    Args:
        scope: Namespace for this limit; becomes part of the counter key.
        identifier: Value being limited; lowercased before use.
        limit: Maximum number of requests allowed per window.
        window_seconds: Window length in seconds.

    Raises:
        ApiProblemError: status 429, code ``AUTH_TOO_MANY_REQUESTS``, from
            whichever backend made the decision.
    """
    rate_key = f"auth:rate_limit:{scope}:{identifier.lower()}"

    try:
        await _enforce_rate_limit_with_redis(
            key=rate_key,
            limit=limit,
            window_seconds=window_seconds,
        )
    except ApiProblemError:
        # A rejection decided by Redis is final; only other failures fall back.
        raise
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "Rate limit fallback to in-memory",
            scope=scope,
            error_type=type(exc).__name__,
        )
    else:
        return

    await _enforce_rate_limit_in_memory(
        key=rate_key,
        limit=limit,
        window_seconds=window_seconds,
    )
|
|
|
|
|
|
async def _enforce_rate_limit_with_redis(
    *,
    key: str,
    limit: int,
    window_seconds: int,
) -> None:
    """Count this request in Redis and reject it once the quota is used up.

    Runs ``_REDIS_LIMIT_SCRIPT`` on ``key`` with ``window_seconds`` as its only
    argument. The script returns the counter after this request was added, so
    up to ``limit`` requests per window are let through.

    Raises:
        ApiProblemError: status 429, code ``AUTH_TOO_MANY_REQUESTS``, when the
            counter is above ``limit``.
        Exception: errors from obtaining the client or evaluating the script
            propagate unchanged to the caller.
    """
    redis_client = await get_or_init_redis_client()
    hit_count = await redis_client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds)  # type: ignore[await]

    if int(hit_count) <= limit:
        return

    raise ApiProblemError(
        status_code=429,
        code="AUTH_TOO_MANY_REQUESTS",
        detail="Too many requests",
    )
|
|
|
|
|
|
# Idle time (seconds) after which an in-memory key is considered abandoned.
_STALE_AFTER_SECONDS = 3600.0

# Longest ``window_seconds`` the in-memory limiter has been called with. A key
# idle for longer than this holds only timestamps outside every window, so
# evicting it cannot loosen any limit. Deliberately never lowered.
_LONGEST_WINDOW_SECONDS = 0.0


async def _enforce_rate_limit_in_memory(
    *,
    key: str,
    limit: int,
    window_seconds: int,
) -> None:
    """Sliding-window limiter kept in process memory (fallback for Redis).

    Each key keeps a deque of ``time.monotonic()`` timestamps of its accepted
    requests. On every call, timestamps at or before ``now - window_seconds``
    are discarded. The request is then rejected if ``limit`` timestamps remain;
    otherwise it is recorded. After every ``_CLEANUP_INTERVAL`` accepted calls,
    idle keys are swept so the state dictionaries do not grow without bound.

    Raises:
        ApiProblemError: status 429, code ``AUTH_TOO_MANY_REQUESTS``, when the
            key already has ``limit`` requests inside the window. Rejected
            requests are not recorded.
    """
    global _CALL_COUNT, _LONGEST_WINDOW_SECONDS

    now = monotonic()
    window = float(window_seconds)

    async with _LOCK:
        # Remembered so the sweep knows how long a key must stay untouched.
        _LONGEST_WINDOW_SECONDS = max(_LONGEST_WINDOW_SECONDS, window)

        bucket = _BUCKETS.setdefault(key, deque())
        _LAST_SEEN[key] = now

        cutoff = now - window
        while bucket and bucket[0] <= cutoff:
            bucket.popleft()

        if len(bucket) >= limit:
            raise ApiProblemError(
                status_code=429,
                code="AUTH_TOO_MANY_REQUESTS",
                detail="Too many requests",
            )

        bucket.append(now)

        _CALL_COUNT += 1
        if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
            _cleanup_stale_buckets(now)


def _cleanup_stale_buckets(now: float) -> None:
    """Drop in-memory state for keys that can no longer affect a decision.

    Called under ``_LOCK`` from ``_enforce_rate_limit_in_memory``. A key is
    removed when:

    * it has a ``_LAST_SEEN`` entry but no bucket;
    * its bucket is empty and it has been idle for more than
      ``_STALE_AFTER_SECONDS``; or
    * it has been idle for longer than both ``_STALE_AFTER_SECONDS`` and
      ``_LONGEST_WINDOW_SECONDS``. Every timestamp in its bucket is then
      outside any window and would be pruned on the key's next call anyway.

    The idle rule is needed because the empty-bucket rule alone almost never
    fires. A bucket is pruned only when its own key is called again, and (for
    ``limit >= 1``) that call always leaves it non-empty: it either records
    the request or rejects because the bucket is full.
    """
    horizon = max(_STALE_AFTER_SECONDS, _LONGEST_WINDOW_SECONDS)

    stale_keys = []
    for key, last_seen in _LAST_SEEN.items():
        bucket = _BUCKETS.get(key)
        idle = now - last_seen
        if (
            bucket is None
            or idle > horizon
            or (not bucket and idle > _STALE_AFTER_SECONDS)
        ):
            stale_keys.append(key)

    # Mutate only after iterating; the dicts cannot change size mid-iteration.
    for key in stale_keys:
        _BUCKETS.pop(key, None)
        _LAST_SEEN.pop(key, None)
|
|
|
|
|
|
def reset_rate_limit_state() -> None:
    """Forget all in-memory rate-limit state.

    Restarts the cleanup call counter and empties the bucket and last-seen
    dictionaries in place, so existing references to them observe the reset.
    Counters stored in Redis are not touched.
    """
    global _CALL_COUNT

    _CALL_COUNT = 0
    for state in (_BUCKETS, _LAST_SEEN):
        state.clear()
|