refactor: 统一认证端点并删除冗余 profile 模块
- 合并 auth 端点: /verifications/verify → /verify, /verifications/resend → /resend - 整合密码重置到 /verify 端点 (type=recovery) - 移除未使用的 /auth/users 端点 - 添加 redirect URL 白名单验证 (site_url + additional_redirect_urls) - 限流改用 Redis + IP 标识,替代内存锁 - 删除 v1/profile 死代码模块 - 更新前端 auth_api 适配新端点 - 添加 supabase site_url 和 additional_redirect_urls 配置
This commit is contained in:
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
@@ -47,7 +49,11 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
"data": metadata,
|
||||
}
|
||||
if request.redirect_to:
|
||||
payload["options"] = {"email_redirect_to": request.redirect_to}
|
||||
payload["options"] = {
|
||||
"email_redirect_to": _validate_redirect_url_or_raise(
|
||||
request.redirect_to
|
||||
)
|
||||
}
|
||||
try:
|
||||
sign_up = cast(Any, client.auth.sign_up)
|
||||
await asyncio.to_thread(sign_up, payload)
|
||||
@@ -61,9 +67,12 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
) -> SessionResponse:
|
||||
if request.type != "signup":
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": "signup",
|
||||
"type": request.type,
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
@@ -79,7 +88,16 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"type": "signup", "email": request.email}
|
||||
if request.type == "recovery":
|
||||
await self.request_password_reset(
|
||||
PasswordResetRequest(
|
||||
email=request.email,
|
||||
redirect_to=request.redirect_to,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
payload: dict[str, Any] = {"type": request.type, "email": request.email}
|
||||
try:
|
||||
resend = cast(Any, client.auth.resend)
|
||||
await asyncio.to_thread(resend, payload)
|
||||
@@ -167,7 +185,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
reset_email = cast(Any, client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {"redirect_to": request.redirect_to}
|
||||
options: dict[str, str] = {
|
||||
"redirect_to": _validate_redirect_url_or_raise(request.redirect_to)
|
||||
}
|
||||
await asyncio.to_thread(reset_email, email, options=options)
|
||||
else:
|
||||
await asyncio.to_thread(reset_email, email)
|
||||
@@ -243,11 +263,37 @@ def _map_auth_response(response: object, failure_message: str) -> SessionRespons
|
||||
)
|
||||
|
||||
|
||||
def _validate_redirect_url_or_raise(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https"}:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
if not parsed.netloc:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
|
||||
site_origin = _origin_of(config.supabase.site_url)
|
||||
allowlist = {
|
||||
site_origin,
|
||||
*(_origin_of(item) for item in config.supabase.additional_redirect_urls),
|
||||
}
|
||||
target_origin = f"{parsed.scheme}://{parsed.netloc}".lower()
|
||||
if target_origin not in allowlist:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
return url
|
||||
|
||||
|
||||
def _origin_of(url: str) -> str:
|
||||
parsed = urlparse(url.strip())
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||
return ""
|
||||
return f"{parsed.scheme}://{parsed.netloc}".lower()
|
||||
|
||||
|
||||
def _list_auth_users(client: Any) -> list[Any]:
|
||||
users: list[Any] = []
|
||||
page = 1
|
||||
max_pages = 100
|
||||
|
||||
while True:
|
||||
while page <= max_pages:
|
||||
response = client.auth.admin.list_users(page=page, per_page=100)
|
||||
batch = list(getattr(response, "users", []))
|
||||
users.extend(batch)
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from threading import Lock
|
||||
from time import monotonic
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
_BUCKETS: dict[str, deque[float]] = {}
|
||||
_LOCK = Lock()
|
||||
_LAST_SEEN: dict[str, float] = {}
|
||||
_LOCK = asyncio.Lock()
|
||||
_CLEANUP_INTERVAL = 200
|
||||
_CALL_COUNT = 0
|
||||
logger = get_logger("v1.auth.rate_limit")
|
||||
_REDIS_LIMIT_SCRIPT = """
|
||||
local current = redis.call("INCR", KEYS[1])
|
||||
if current == 1 then
|
||||
redis.call("EXPIRE", KEYS[1], ARGV[1])
|
||||
end
|
||||
return current
|
||||
"""
|
||||
|
||||
|
||||
async def enforce_rate_limit(
|
||||
@@ -17,30 +31,76 @@ async def enforce_rate_limit(
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
_enforce_rate_limit_in_memory(
|
||||
key=f"auth:rate_limit:{scope}:{identifier.lower()}",
|
||||
key = f"auth:rate_limit:{scope}:{identifier.lower()}"
|
||||
try:
|
||||
await _enforce_rate_limit_with_redis(
|
||||
key=key,
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
return
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Rate limit fallback to in-memory",
|
||||
scope=scope,
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
await _enforce_rate_limit_in_memory(
|
||||
key=key,
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _enforce_rate_limit_in_memory(
|
||||
async def _enforce_rate_limit_with_redis(
|
||||
*,
|
||||
key: str,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
client = await get_or_init_redis_client()
|
||||
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds)
|
||||
if int(current) > limit:
|
||||
raise HTTPException(status_code=429, detail="Too many requests")
|
||||
|
||||
|
||||
async def _enforce_rate_limit_in_memory(
|
||||
*,
|
||||
key: str,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
global _CALL_COUNT
|
||||
now = monotonic()
|
||||
with _LOCK:
|
||||
async with _LOCK:
|
||||
bucket = _BUCKETS.setdefault(key, deque())
|
||||
_LAST_SEEN[key] = now
|
||||
cutoff = now - float(window_seconds)
|
||||
while bucket and bucket[0] <= cutoff:
|
||||
bucket.popleft()
|
||||
if len(bucket) >= limit:
|
||||
raise HTTPException(status_code=429, detail="Too many requests")
|
||||
bucket.append(now)
|
||||
_CALL_COUNT += 1
|
||||
if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
|
||||
_cleanup_stale_buckets(now)
|
||||
|
||||
|
||||
def _cleanup_stale_buckets(now: float) -> None:
|
||||
stale_keys = [
|
||||
key
|
||||
for key, last_seen in _LAST_SEEN.items()
|
||||
if key not in _BUCKETS or (not _BUCKETS[key] and now - last_seen > 3600)
|
||||
]
|
||||
for key in stale_keys:
|
||||
_BUCKETS.pop(key, None)
|
||||
_LAST_SEEN.pop(key, None)
|
||||
|
||||
|
||||
def reset_rate_limit_state() -> None:
|
||||
with _LOCK:
|
||||
_BUCKETS.clear()
|
||||
_BUCKETS.clear()
|
||||
_LAST_SEEN.clear()
|
||||
global _CALL_COUNT
|
||||
_CALL_COUNT = 0
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
@@ -11,7 +11,6 @@ from v1.auth.dependencies import get_auth_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
@@ -44,28 +43,45 @@ async def create_verification(
|
||||
return await service.create_verification(payload)
|
||||
|
||||
|
||||
@router.post("/verifications/verify", response_model=SessionResponse)
|
||||
async def verify_verification(
|
||||
@router.post("/verify", response_model=SessionResponse)
|
||||
async def verify(
|
||||
payload: VerificationVerifyRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
) -> SessionResponse | Response:
|
||||
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
|
||||
limit = 10
|
||||
window_seconds = 600
|
||||
await enforce_rate_limit(
|
||||
scope="signup_verify",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=600,
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
return await service.verify_verification(payload)
|
||||
if payload.type == "signup":
|
||||
return await service.verify_verification(payload)
|
||||
if payload.new_password is None:
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
await service.confirm_password_reset(
|
||||
PasswordResetConfirmRequest(
|
||||
email=payload.email,
|
||||
token=payload.token,
|
||||
new_password=payload.new_password,
|
||||
)
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/verifications/resend", status_code=204)
|
||||
async def resend_verification(
|
||||
@router.post("/resend", status_code=204)
|
||||
async def resend(
|
||||
payload: VerificationResendRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
|
||||
await enforce_rate_limit(
|
||||
scope="signup_resend",
|
||||
identifier=payload.email,
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
)
|
||||
@@ -90,11 +106,12 @@ async def create_session(
|
||||
@router.post("/sessions/refresh", response_model=SessionResponse)
|
||||
async def refresh_session(
|
||||
payload: SessionRefreshRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
await enforce_rate_limit(
|
||||
scope="refresh",
|
||||
identifier=payload.refresh_token,
|
||||
identifier=_client_ip(request),
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
)
|
||||
@@ -104,11 +121,12 @@ async def refresh_session(
|
||||
@router.delete("/sessions", status_code=204)
|
||||
async def delete_session(
|
||||
payload: SessionDeleteRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="logout",
|
||||
identifier=payload.refresh_token,
|
||||
identifier=_client_ip(request),
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
)
|
||||
@@ -116,42 +134,14 @@ async def delete_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.get("/users", response_model=UserByEmailResponse)
|
||||
async def get_user_by_email(
|
||||
email: str,
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> UserByEmailResponse:
|
||||
if current_user.role != "service_role" and current_user.email != email:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return await service.get_user_by_email(email)
|
||||
|
||||
|
||||
@router.post("/password-reset", status_code=204)
|
||||
async def request_password_reset(
|
||||
payload: PasswordResetRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="password_reset_request",
|
||||
identifier=payload.email,
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.request_password_reset(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", status_code=204)
|
||||
async def confirm_password_reset(
|
||||
payload: PasswordResetConfirmRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="password_reset_confirm",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=600,
|
||||
)
|
||||
await service.confirm_password_reset(payload)
|
||||
return Response(status_code=204)
|
||||
def _client_ip(request: Request) -> str:
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
first = forwarded_for.split(",")[0].strip()
|
||||
if first:
|
||||
return first
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
host = request.client.host if request.client else ""
|
||||
return host or "unknown"
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
OtpType = Literal["signup", "recovery"]
|
||||
|
||||
|
||||
class VerificationCreateRequest(BaseModel):
|
||||
@@ -8,7 +13,7 @@ class VerificationCreateRequest(BaseModel):
|
||||
|
||||
username: str = Field(min_length=3, max_length=30)
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=6)
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
redirect_to: str | None = None
|
||||
invite_code: str | None = Field(
|
||||
default=None,
|
||||
@@ -20,16 +25,30 @@ class VerificationCreateRequest(BaseModel):
|
||||
|
||||
class VerificationResendRequest(BaseModel):
|
||||
email: EmailStr
|
||||
type: OtpType = "signup"
|
||||
redirect_to: str | None = None
|
||||
|
||||
|
||||
class VerificationVerifyRequest(BaseModel):
|
||||
type: OtpType = "signup"
|
||||
email: EmailStr
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str | None = Field(
|
||||
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_type_payload(self) -> "VerificationVerifyRequest":
|
||||
if self.type == "recovery" and self.new_password is None:
|
||||
raise ValueError("new_password is required when type is recovery")
|
||||
if self.type == "signup" and self.new_password is not None:
|
||||
raise ValueError("new_password is only allowed when type is recovery")
|
||||
return self
|
||||
|
||||
|
||||
class SessionCreateRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=6)
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
|
||||
|
||||
class SessionRefreshRequest(BaseModel):
|
||||
@@ -72,4 +91,4 @@ class PasswordResetRequest(BaseModel):
|
||||
class PasswordResetConfirmRequest(BaseModel):
|
||||
email: EmailStr
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str = Field(min_length=6)
|
||||
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
|
||||
@@ -8,7 +8,6 @@ from v1.auth.schemas import (
|
||||
SessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByEmailResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
@@ -39,9 +38,6 @@ class AuthServiceGateway(Protocol):
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -79,9 +75,6 @@ class AuthService:
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
return await self._gateway.get_user_by_email(email)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
await self._gateway.request_password_reset(request)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user