refactor: 统一认证端点并删除冗余 profile 模块

- 合并 auth 端点: /verifications/verify → /verify, /verifications/resend → /resend
- 整合密码重置到 /verify 端点 (type=recovery)
- 移除未使用的 /auth/users 端点
- 添加 redirect URL 白名单验证 (site_url + additional_redirect_urls)
- 限流改用 Redis + IP 标识,替代内存锁
- 删除 v1/profile 死代码模块
- 更新前端 auth_api 适配新端点
- 添加 supabase site_url 和 additional_redirect_urls 配置
This commit is contained in:
zl-q
2026-03-07 14:55:00 +08:00
parent 1f6cb1a48f
commit ec33bb0cee
25 changed files with 421 additions and 1614 deletions
+51 -5
View File
@@ -3,10 +3,12 @@ from __future__ import annotations
import asyncio
from collections.abc import Mapping
from typing import Any, cast
from urllib.parse import urlparse
from fastapi import HTTPException
from supabase import AuthError
from core.config.settings import config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
@@ -47,7 +49,11 @@ class SupabaseAuthGateway(AuthServiceGateway):
"data": metadata,
}
if request.redirect_to:
payload["options"] = {"email_redirect_to": request.redirect_to}
payload["options"] = {
"email_redirect_to": _validate_redirect_url_or_raise(
request.redirect_to
)
}
try:
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
@@ -61,9 +67,12 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
if request.type != "signup":
raise HTTPException(status_code=422, detail="Invalid request")
client = self._get_client()
payload: dict[str, Any] = {
"type": "signup",
"type": request.type,
"email": request.email,
"token": request.token,
}
@@ -79,7 +88,16 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {"type": "signup", "email": request.email}
if request.type == "recovery":
await self.request_password_reset(
PasswordResetRequest(
email=request.email,
redirect_to=request.redirect_to,
)
)
return
payload: dict[str, Any] = {"type": request.type, "email": request.email}
try:
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
@@ -167,7 +185,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
options: dict[str, str] = {
"redirect_to": _validate_redirect_url_or_raise(request.redirect_to)
}
await asyncio.to_thread(reset_email, email, options=options)
else:
await asyncio.to_thread(reset_email, email)
@@ -243,11 +263,37 @@ def _map_auth_response(response: object, failure_message: str) -> SessionRespons
)
def _validate_redirect_url_or_raise(url: str) -> str:
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"}:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
if not parsed.netloc:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
site_origin = _origin_of(config.supabase.site_url)
allowlist = {
site_origin,
*(_origin_of(item) for item in config.supabase.additional_redirect_urls),
}
target_origin = f"{parsed.scheme}://{parsed.netloc}".lower()
if target_origin not in allowlist:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
return url
def _origin_of(url: str) -> str:
parsed = urlparse(url.strip())
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
return ""
return f"{parsed.scheme}://{parsed.netloc}".lower()
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
max_pages = 100
while True:
while page <= max_pages:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = list(getattr(response, "users", []))
users.extend(batch)
+68 -8
View File
@@ -1,13 +1,27 @@
from __future__ import annotations
import asyncio
from collections import deque
from threading import Lock
from time import monotonic
from fastapi import HTTPException
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
_BUCKETS: dict[str, deque[float]] = {}
_LOCK = Lock()
_LAST_SEEN: dict[str, float] = {}
_LOCK = asyncio.Lock()
_CLEANUP_INTERVAL = 200
_CALL_COUNT = 0
logger = get_logger("v1.auth.rate_limit")
_REDIS_LIMIT_SCRIPT = """
local current = redis.call("INCR", KEYS[1])
if current == 1 then
redis.call("EXPIRE", KEYS[1], ARGV[1])
end
return current
"""
async def enforce_rate_limit(
@@ -17,30 +31,76 @@ async def enforce_rate_limit(
limit: int,
window_seconds: int,
) -> None:
_enforce_rate_limit_in_memory(
key=f"auth:rate_limit:{scope}:{identifier.lower()}",
key = f"auth:rate_limit:{scope}:{identifier.lower()}"
try:
await _enforce_rate_limit_with_redis(
key=key,
limit=limit,
window_seconds=window_seconds,
)
return
except HTTPException:
raise
except Exception as exc: # noqa: BLE001
logger.warning(
"Rate limit fallback to in-memory",
scope=scope,
error_type=type(exc).__name__,
)
await _enforce_rate_limit_in_memory(
key=key,
limit=limit,
window_seconds=window_seconds,
)
def _enforce_rate_limit_in_memory(
async def _enforce_rate_limit_with_redis(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
client = await get_or_init_redis_client()
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds)
if int(current) > limit:
raise HTTPException(status_code=429, detail="Too many requests")
async def _enforce_rate_limit_in_memory(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
global _CALL_COUNT
now = monotonic()
with _LOCK:
async with _LOCK:
bucket = _BUCKETS.setdefault(key, deque())
_LAST_SEEN[key] = now
cutoff = now - float(window_seconds)
while bucket and bucket[0] <= cutoff:
bucket.popleft()
if len(bucket) >= limit:
raise HTTPException(status_code=429, detail="Too many requests")
bucket.append(now)
_CALL_COUNT += 1
if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
_cleanup_stale_buckets(now)
def _cleanup_stale_buckets(now: float) -> None:
stale_keys = [
key
for key, last_seen in _LAST_SEEN.items()
if key not in _BUCKETS or (not _BUCKETS[key] and now - last_seen > 3600)
]
for key in stale_keys:
_BUCKETS.pop(key, None)
_LAST_SEEN.pop(key, None)
def reset_rate_limit_state() -> None:
with _LOCK:
_BUCKETS.clear()
_BUCKETS.clear()
_LAST_SEEN.clear()
global _CALL_COUNT
_CALL_COUNT = 0
+45 -55
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Response
from fastapi import APIRouter, Depends, Request, Response
from fastapi import HTTPException
from core.auth.models import CurrentUser
@@ -11,7 +11,6 @@ from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
@@ -44,28 +43,45 @@ async def create_verification(
return await service.create_verification(payload)
@router.post("/verifications/verify", response_model=SessionResponse)
async def verify_verification(
@router.post("/verify", response_model=SessionResponse)
async def verify(
payload: VerificationVerifyRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
) -> SessionResponse | Response:
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
limit = 10
window_seconds = 600
await enforce_rate_limit(
scope="signup_verify",
identifier=payload.email,
limit=10,
window_seconds=600,
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=limit,
window_seconds=window_seconds,
)
return await service.verify_verification(payload)
if payload.type == "signup":
return await service.verify_verification(payload)
if payload.new_password is None:
raise HTTPException(status_code=422, detail="Invalid request")
await service.confirm_password_reset(
PasswordResetConfirmRequest(
email=payload.email,
token=payload.token,
new_password=payload.new_password,
)
)
return Response(status_code=204)
@router.post("/verifications/resend", status_code=204)
async def resend_verification(
@router.post("/resend", status_code=204)
async def resend(
payload: VerificationResendRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
await enforce_rate_limit(
scope="signup_resend",
identifier=payload.email,
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=5,
window_seconds=60,
)
@@ -90,11 +106,12 @@ async def create_session(
@router.post("/sessions/refresh", response_model=SessionResponse)
async def refresh_session(
payload: SessionRefreshRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
await enforce_rate_limit(
scope="refresh",
identifier=payload.refresh_token,
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
@@ -104,11 +121,12 @@ async def refresh_session(
@router.delete("/sessions", status_code=204)
async def delete_session(
payload: SessionDeleteRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="logout",
identifier=payload.refresh_token,
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
@@ -116,42 +134,14 @@ async def delete_session(
return Response(status_code=204)
@router.get("/users", response_model=UserByEmailResponse)
async def get_user_by_email(
email: str,
current_user: Annotated[CurrentUser, Depends(get_current_user)],
service: AuthService = Depends(get_auth_service),
) -> UserByEmailResponse:
if current_user.role != "service_role" and current_user.email != email:
raise HTTPException(status_code=403, detail="Forbidden")
return await service.get_user_by_email(email)
@router.post("/password-reset", status_code=204)
async def request_password_reset(
payload: PasswordResetRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_request",
identifier=payload.email,
limit=5,
window_seconds=60,
)
await service.request_password_reset(payload)
return Response(status_code=204)
@router.post("/password-reset/confirm", status_code=204)
async def confirm_password_reset(
payload: PasswordResetConfirmRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_confirm",
identifier=payload.email,
limit=10,
window_seconds=600,
)
await service.confirm_password_reset(payload)
return Response(status_code=204)
def _client_ip(request: Request) -> str:
forwarded_for = request.headers.get("x-forwarded-for", "")
if forwarded_for:
first = forwarded_for.split(",")[0].strip()
if first:
return first
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
host = request.client.host if request.client else ""
return host or "unknown"
+23 -4
View File
@@ -1,6 +1,11 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, EmailStr, Field
from typing import Literal
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
SUPABASE_PASSWORD_MIN_LENGTH = 6
OtpType = Literal["signup", "recovery"]
class VerificationCreateRequest(BaseModel):
@@ -8,7 +13,7 @@ class VerificationCreateRequest(BaseModel):
username: str = Field(min_length=3, max_length=30)
email: EmailStr
password: str = Field(min_length=6)
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
redirect_to: str | None = None
invite_code: str | None = Field(
default=None,
@@ -20,16 +25,30 @@ class VerificationCreateRequest(BaseModel):
class VerificationResendRequest(BaseModel):
email: EmailStr
type: OtpType = "signup"
redirect_to: str | None = None
class VerificationVerifyRequest(BaseModel):
type: OtpType = "signup"
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str | None = Field(
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
)
@model_validator(mode="after")
def validate_type_payload(self) -> "VerificationVerifyRequest":
if self.type == "recovery" and self.new_password is None:
raise ValueError("new_password is required when type is recovery")
if self.type == "signup" and self.new_password is not None:
raise ValueError("new_password is only allowed when type is recovery")
return self
class SessionCreateRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=6)
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
class SessionRefreshRequest(BaseModel):
@@ -72,4 +91,4 @@ class PasswordResetRequest(BaseModel):
class PasswordResetConfirmRequest(BaseModel):
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str = Field(min_length=6)
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
-7
View File
@@ -8,7 +8,6 @@ from v1.auth.schemas import (
SessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
@@ -39,9 +38,6 @@ class AuthServiceGateway(Protocol):
async def delete_session(self, refresh_token: str | None) -> None:
raise NotImplementedError
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
@@ -79,9 +75,6 @@ class AuthService:
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
return await self._gateway.get_user_by_email(email)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
await self._gateway.request_password_reset(request)