refactor: 统一认证端点并删除冗余 profile 模块

- 合并 auth 端点: /verifications/verify → /verify, /verifications/resend → /resend
- 整合密码重置到 /verify 端点 (type=recovery)
- 移除未使用的 /auth/users 端点
- 添加 redirect URL 白名单验证 (site_url + additional_redirect_urls)
- 限流改用 Redis + IP 标识,替代内存锁
- 删除 v1/profile 死代码模块
- 更新前端 auth_api 适配新端点
- 添加 supabase site_url 和 additional_redirect_urls 配置
This commit is contained in:
zl-q
2026-03-07 14:55:00 +08:00
parent 1f6cb1a48f
commit ec33bb0cee
25 changed files with 421 additions and 1614 deletions
+13
View File
@@ -119,10 +119,23 @@ class SupabaseSettings(BaseModel):
public_scheme: str = "http"
public_host: str = "localhost"
kong_http_port: int = 8000
site_url: str = "http://localhost:3000"
additional_redirect_urls: list[str] = Field(default_factory=list)
anon_key: str = "CHANGE_ME"
service_role_key: str = "CHANGE_ME"
jwt_secret: str | None = None
@field_validator("additional_redirect_urls", mode="before")
@classmethod
def normalize_redirect_urls(cls, value: object) -> list[str]:
if value is None:
return []
if isinstance(value, str):
return [item.strip() for item in value.split(",") if item.strip()]
if isinstance(value, list):
return [str(item).strip() for item in value if str(item).strip()]
return []
@computed_field
@property
def public_url(self) -> str:
+51 -5
View File
@@ -3,10 +3,12 @@ from __future__ import annotations
import asyncio
from collections.abc import Mapping
from typing import Any, cast
from urllib.parse import urlparse
from fastapi import HTTPException
from supabase import AuthError
from core.config.settings import config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
@@ -47,7 +49,11 @@ class SupabaseAuthGateway(AuthServiceGateway):
"data": metadata,
}
if request.redirect_to:
payload["options"] = {"email_redirect_to": request.redirect_to}
payload["options"] = {
"email_redirect_to": _validate_redirect_url_or_raise(
request.redirect_to
)
}
try:
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
@@ -61,9 +67,12 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
if request.type != "signup":
raise HTTPException(status_code=422, detail="Invalid request")
client = self._get_client()
payload: dict[str, Any] = {
"type": "signup",
"type": request.type,
"email": request.email,
"token": request.token,
}
@@ -79,7 +88,16 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {"type": "signup", "email": request.email}
if request.type == "recovery":
await self.request_password_reset(
PasswordResetRequest(
email=request.email,
redirect_to=request.redirect_to,
)
)
return
payload: dict[str, Any] = {"type": request.type, "email": request.email}
try:
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
@@ -167,7 +185,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
options: dict[str, str] = {
"redirect_to": _validate_redirect_url_or_raise(request.redirect_to)
}
await asyncio.to_thread(reset_email, email, options=options)
else:
await asyncio.to_thread(reset_email, email)
@@ -243,11 +263,37 @@ def _map_auth_response(response: object, failure_message: str) -> SessionRespons
)
def _validate_redirect_url_or_raise(url: str) -> str:
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"}:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
if not parsed.netloc:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
site_origin = _origin_of(config.supabase.site_url)
allowlist = {
site_origin,
*(_origin_of(item) for item in config.supabase.additional_redirect_urls),
}
target_origin = f"{parsed.scheme}://{parsed.netloc}".lower()
if target_origin not in allowlist:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
return url
def _origin_of(url: str) -> str:
parsed = urlparse(url.strip())
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
return ""
return f"{parsed.scheme}://{parsed.netloc}".lower()
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
max_pages = 100
while True:
while page <= max_pages:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = list(getattr(response, "users", []))
users.extend(batch)
+68 -8
View File
@@ -1,13 +1,27 @@
from __future__ import annotations
import asyncio
from collections import deque
from threading import Lock
from time import monotonic
from fastapi import HTTPException
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
_BUCKETS: dict[str, deque[float]] = {}
_LOCK = Lock()
_LAST_SEEN: dict[str, float] = {}
_LOCK = asyncio.Lock()
_CLEANUP_INTERVAL = 200
_CALL_COUNT = 0
logger = get_logger("v1.auth.rate_limit")
_REDIS_LIMIT_SCRIPT = """
local current = redis.call("INCR", KEYS[1])
if current == 1 then
redis.call("EXPIRE", KEYS[1], ARGV[1])
end
return current
"""
async def enforce_rate_limit(
@@ -17,30 +31,76 @@ async def enforce_rate_limit(
limit: int,
window_seconds: int,
) -> None:
_enforce_rate_limit_in_memory(
key=f"auth:rate_limit:{scope}:{identifier.lower()}",
key = f"auth:rate_limit:{scope}:{identifier.lower()}"
try:
await _enforce_rate_limit_with_redis(
key=key,
limit=limit,
window_seconds=window_seconds,
)
return
except HTTPException:
raise
except Exception as exc: # noqa: BLE001
logger.warning(
"Rate limit fallback to in-memory",
scope=scope,
error_type=type(exc).__name__,
)
await _enforce_rate_limit_in_memory(
key=key,
limit=limit,
window_seconds=window_seconds,
)
def _enforce_rate_limit_in_memory(
async def _enforce_rate_limit_with_redis(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
client = await get_or_init_redis_client()
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds)
if int(current) > limit:
raise HTTPException(status_code=429, detail="Too many requests")
async def _enforce_rate_limit_in_memory(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
global _CALL_COUNT
now = monotonic()
with _LOCK:
async with _LOCK:
bucket = _BUCKETS.setdefault(key, deque())
_LAST_SEEN[key] = now
cutoff = now - float(window_seconds)
while bucket and bucket[0] <= cutoff:
bucket.popleft()
if len(bucket) >= limit:
raise HTTPException(status_code=429, detail="Too many requests")
bucket.append(now)
_CALL_COUNT += 1
if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
_cleanup_stale_buckets(now)
def _cleanup_stale_buckets(now: float) -> None:
stale_keys = [
key
for key, last_seen in _LAST_SEEN.items()
if key not in _BUCKETS or (not _BUCKETS[key] and now - last_seen > 3600)
]
for key in stale_keys:
_BUCKETS.pop(key, None)
_LAST_SEEN.pop(key, None)
def reset_rate_limit_state() -> None:
with _LOCK:
_BUCKETS.clear()
_BUCKETS.clear()
_LAST_SEEN.clear()
global _CALL_COUNT
_CALL_COUNT = 0
+45 -55
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Response
from fastapi import APIRouter, Depends, Request, Response
from fastapi import HTTPException
from core.auth.models import CurrentUser
@@ -11,7 +11,6 @@ from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
@@ -44,28 +43,45 @@ async def create_verification(
return await service.create_verification(payload)
@router.post("/verifications/verify", response_model=SessionResponse)
async def verify_verification(
@router.post("/verify", response_model=SessionResponse)
async def verify(
payload: VerificationVerifyRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
) -> SessionResponse | Response:
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
limit = 10
window_seconds = 600
await enforce_rate_limit(
scope="signup_verify",
identifier=payload.email,
limit=10,
window_seconds=600,
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=limit,
window_seconds=window_seconds,
)
return await service.verify_verification(payload)
if payload.type == "signup":
return await service.verify_verification(payload)
if payload.new_password is None:
raise HTTPException(status_code=422, detail="Invalid request")
await service.confirm_password_reset(
PasswordResetConfirmRequest(
email=payload.email,
token=payload.token,
new_password=payload.new_password,
)
)
return Response(status_code=204)
@router.post("/verifications/resend", status_code=204)
async def resend_verification(
@router.post("/resend", status_code=204)
async def resend(
payload: VerificationResendRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
await enforce_rate_limit(
scope="signup_resend",
identifier=payload.email,
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=5,
window_seconds=60,
)
@@ -90,11 +106,12 @@ async def create_session(
@router.post("/sessions/refresh", response_model=SessionResponse)
async def refresh_session(
payload: SessionRefreshRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
await enforce_rate_limit(
scope="refresh",
identifier=payload.refresh_token,
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
@@ -104,11 +121,12 @@ async def refresh_session(
@router.delete("/sessions", status_code=204)
async def delete_session(
payload: SessionDeleteRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="logout",
identifier=payload.refresh_token,
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
@@ -116,42 +134,14 @@ async def delete_session(
return Response(status_code=204)
@router.get("/users", response_model=UserByEmailResponse)
async def get_user_by_email(
email: str,
current_user: Annotated[CurrentUser, Depends(get_current_user)],
service: AuthService = Depends(get_auth_service),
) -> UserByEmailResponse:
if current_user.role != "service_role" and current_user.email != email:
raise HTTPException(status_code=403, detail="Forbidden")
return await service.get_user_by_email(email)
@router.post("/password-reset", status_code=204)
async def request_password_reset(
payload: PasswordResetRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_request",
identifier=payload.email,
limit=5,
window_seconds=60,
)
await service.request_password_reset(payload)
return Response(status_code=204)
@router.post("/password-reset/confirm", status_code=204)
async def confirm_password_reset(
payload: PasswordResetConfirmRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_confirm",
identifier=payload.email,
limit=10,
window_seconds=600,
)
await service.confirm_password_reset(payload)
return Response(status_code=204)
def _client_ip(request: Request) -> str:
forwarded_for = request.headers.get("x-forwarded-for", "")
if forwarded_for:
first = forwarded_for.split(",")[0].strip()
if first:
return first
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
host = request.client.host if request.client else ""
return host or "unknown"
+23 -4
View File
@@ -1,6 +1,11 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, EmailStr, Field
from typing import Literal
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
SUPABASE_PASSWORD_MIN_LENGTH = 6
OtpType = Literal["signup", "recovery"]
class VerificationCreateRequest(BaseModel):
@@ -8,7 +13,7 @@ class VerificationCreateRequest(BaseModel):
username: str = Field(min_length=3, max_length=30)
email: EmailStr
password: str = Field(min_length=6)
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
redirect_to: str | None = None
invite_code: str | None = Field(
default=None,
@@ -20,16 +25,30 @@ class VerificationCreateRequest(BaseModel):
class VerificationResendRequest(BaseModel):
email: EmailStr
type: OtpType = "signup"
redirect_to: str | None = None
class VerificationVerifyRequest(BaseModel):
type: OtpType = "signup"
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str | None = Field(
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
)
@model_validator(mode="after")
def validate_type_payload(self) -> "VerificationVerifyRequest":
if self.type == "recovery" and self.new_password is None:
raise ValueError("new_password is required when type is recovery")
if self.type == "signup" and self.new_password is not None:
raise ValueError("new_password is only allowed when type is recovery")
return self
class SessionCreateRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=6)
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
class SessionRefreshRequest(BaseModel):
@@ -72,4 +91,4 @@ class PasswordResetRequest(BaseModel):
class PasswordResetConfirmRequest(BaseModel):
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str = Field(min_length=6)
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
-7
View File
@@ -8,7 +8,6 @@ from v1.auth.schemas import (
SessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
@@ -39,9 +38,6 @@ class AuthServiceGateway(Protocol):
async def delete_session(self, refresh_token: str | None) -> None:
raise NotImplementedError
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
@@ -79,9 +75,6 @@ class AuthService:
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
return await self._gateway.get_user_by_email(email)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
await self._gateway.request_password_reset(request)
-1
View File
@@ -1 +0,0 @@
from __future__ import annotations
-95
View File
@@ -1,95 +0,0 @@
from __future__ import annotations
from typing import Annotated
from uuid import UUID
import jwt
from fastapi import Depends, Header, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from core.config.settings import config
from core.db import get_db
from core.logging import get_logger
from core.auth.models import CurrentUser
from v1.profile.repository import SQLAlchemyProfileRepository
from v1.profile.service import ProfileService
logger = get_logger("v1.profile.dependencies")
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
if not authorization:
logger.warning("JWT validation failed: missing authorization header")
raise HTTPException(status_code=401, detail="Unauthorized")
scheme, _, token = authorization.partition(" ")
if scheme.lower() != "bearer" or not token:
logger.warning("JWT validation failed: invalid authorization scheme")
raise HTTPException(status_code=401, detail="Unauthorized")
secret = config.supabase.jwt_secret
if not secret:
logger.error("JWT validation failed: secret not configured")
raise HTTPException(status_code=503, detail="JWT secret not configured")
supabase_url = config.supabase.public_url.rstrip("/")
expected_issuer = f"{supabase_url}/auth/v1"
try:
payload = jwt.decode(
token,
secret,
algorithms=["HS256"],
audience="authenticated",
issuer=expected_issuer,
options={
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"require": ["sub", "aud", "iss", "exp"],
},
)
except jwt.ExpiredSignatureError:
logger.warning("JWT validation failed: token expired")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidAudienceError:
logger.warning("JWT validation failed: invalid audience")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidIssuerError:
logger.warning("JWT validation failed: invalid issuer")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.InvalidSignatureError:
logger.warning("JWT validation failed: invalid signature")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.DecodeError:
logger.warning("JWT validation failed: malformed token")
raise HTTPException(status_code=401, detail="Unauthorized")
except jwt.PyJWTError as exc:
logger.warning(
"JWT validation failed: unknown error", error_type=type(exc).__name__
)
raise HTTPException(status_code=401, detail="Unauthorized") from exc
subject = payload.get("sub")
if not isinstance(subject, str) or not subject:
logger.warning("JWT validation failed: missing or invalid subject claim")
raise HTTPException(status_code=401, detail="Unauthorized")
try:
user_id = UUID(subject)
except ValueError:
logger.warning("JWT validation failed: invalid UUID in subject")
raise HTTPException(status_code=401, detail="Unauthorized")
logger.debug("JWT validation successful", user_id=str(user_id))
email = payload.get("email") if isinstance(payload.get("email"), str) else None
role = payload.get("role") if isinstance(payload.get("role"), str) else None
return CurrentUser(id=user_id, email=email, role=role)
def get_profile_service(
session: Annotated[AsyncSession, Depends(get_db)],
user: Annotated[CurrentUser, Depends(get_current_user)],
) -> ProfileService:
repository = SQLAlchemyProfileRepository(session)
return ProfileService(repository=repository, session=session, current_user=user)
-81
View File
@@ -1,81 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
from core.logging import get_logger
from models.profile import Profile
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.profile.repository")
class ProfileRepository(Protocol):
"""Protocol defining the profile repository interface."""
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
"""Get profile by user ID."""
...
async def get_by_username(self, username: str) -> Profile | None:
"""Get profile by username."""
...
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
"""Update profile by user ID. Returns updated profile or None if not found."""
...
class SQLAlchemyProfileRepository(BaseRepository[Profile]):
"""SQLAlchemy implementation of ProfileRepository.
Note: This repository only performs CRUD operations.
- No commit (only flush) - service layer handles transactions
- No auth logic - service layer handles authorization
- No HTTP exceptions - returns None or raises SQLAlchemyError
"""
def __init__(self, session: AsyncSession) -> None:
super().__init__(session, Profile)
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
try:
return await self.get_by_id(user_id)
except SQLAlchemyError:
logger.exception("Profile lookup failed", user_id=str(user_id))
raise
async def get_by_username(self, username: str) -> Profile | None:
try:
stmt = (
select(Profile)
.where(Profile.username == username)
.where(Profile.deleted_at.is_(None))
.order_by(Profile.created_at.asc())
.limit(1)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
except SQLAlchemyError:
logger.exception("Profile lookup failed", username=username)
raise
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
if not update_data:
return await self.get_by_user_id(user_id)
try:
return await self.update_by_id(user_id, update_data)
except SQLAlchemyError:
logger.exception("Profile update failed", user_id=str(user_id))
raise
-36
View File
@@ -1,36 +0,0 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Path
from v1.profile.dependencies import get_profile_service
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
from v1.profile.service import ProfileService
router = APIRouter(prefix="/profile", tags=["profile"])
@router.get("/me", response_model=ProfileResponse)
async def get_me(
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.get_me()
@router.patch("/me", response_model=ProfileResponse)
async def update_me(
payload: ProfileUpdateRequest,
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.update_me(payload)
@router.get("/{username}", response_model=ProfileResponse)
async def get_by_username(
username: Annotated[
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
],
service: Annotated[ProfileService, Depends(get_profile_service)],
) -> ProfileResponse:
return await service.get_by_username(username)
-41
View File
@@ -1,41 +0,0 @@
from __future__ import annotations
from pydantic import (
AnyHttpUrl,
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
class ProfileResponse(BaseModel):
id: str
username: str
avatar_url: str | None = None
bio: str | None = None
class ProfileUpdateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
username: str | None = Field(default=None, min_length=3, max_length=30)
avatar_url: str | None = Field(default=None)
bio: str | None = Field(default=None, max_length=200)
@field_validator("avatar_url", mode="before")
@classmethod
def validate_avatar_url(cls, v: str | None) -> str | None:
if v is None:
return None
parsed = AnyHttpUrl(v)
if parsed.scheme not in ("http", "https"):
raise ValueError("avatar_url must use http or https scheme")
return str(parsed)
@model_validator(mode="after")
def require_one_field(self) -> "ProfileUpdateRequest":
if self.username is None and self.avatar_url is None and self.bio is None:
raise ValueError("At least one field must be provided")
return self
-103
View File
@@ -1,103 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from v1.profile.repository import ProfileRepository
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.profile.service")
class ProfileService(BaseService):
"""Profile service handling business logic and transactions.
Responsibilities:
- Authorization checks
- Transaction boundary (commit/rollback)
- Converting ORM models to response schemas
"""
_repository: ProfileRepository
_session: AsyncSession
def __init__(
self,
repository: ProfileRepository,
session: AsyncSession,
current_user: CurrentUser | None,
) -> None:
super().__init__(current_user=current_user)
self._repository = repository
self._session = session
async def get_me(self) -> ProfileResponse:
user_id = self.require_user_id()
try:
profile = await self._repository.get_by_user_id(user_id)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
user_id = self.require_user_id()
update_data: dict[str, str | None] = {
key: value
for key, value in {
"username": update.username,
"avatar_url": update.avatar_url,
"bio": update.bio,
}.items()
if value is not None
}
if not update_data:
raise HTTPException(status_code=400, detail="No fields to update")
try:
profile = await self._repository.update_by_user_id(user_id, update_data)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
async def get_by_username(self, username: str) -> ProfileResponse:
try:
profile = await self._repository.get_by_username(username)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="Profile store unavailable")
if profile is None:
raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse(
id=str(profile.id),
username=profile.username,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
+1 -1
View File
@@ -119,7 +119,7 @@ def test_auth_flow_e2e() -> None:
assert verification.status == 202
verify = request_context.post(
"/api/v1/auth/verifications/verify",
"/api/v1/auth/verify",
data=json.dumps(
{
"email": "user@example.com",
+82 -112
View File
@@ -138,7 +138,7 @@ def test_signup_verify_returns_token_response() -> None:
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications/verify",
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "123456"},
)
assert response.status_code == 200
@@ -166,8 +166,8 @@ def test_signup_resend_returns_generic_message() -> None:
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications/resend",
json={"email": "user@example.com"},
"/api/v1/auth/resend",
json={"type": "recovery", "email": "user@example.com"},
)
assert response.status_code == 204
assert response.content == b""
@@ -191,7 +191,7 @@ def test_signup_verify_invalid_token_returns_problem_details() -> None:
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications/verify",
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "000000"},
)
assert response.status_code == 401
@@ -230,7 +230,7 @@ def test_signup_start_existing_email_returns_problem_details() -> None:
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Content"
assert body["title"] == "Unprocessable Entity"
assert body["status"] == 422
assert body["detail"] == "Invalid signup request"
finally:
@@ -254,13 +254,13 @@ def test_signup_verify_rate_limited_after_too_many_attempts() -> None:
try:
for _ in range(10):
ok = client.post(
"/api/v1/auth/verifications/verify",
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "123456"},
)
assert ok.status_code == 200
blocked = client.post(
"/api/v1/auth/verifications/verify",
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "123456"},
)
assert blocked.status_code == 429
@@ -286,13 +286,13 @@ def test_signup_resend_rate_limited_after_too_many_attempts() -> None:
try:
for _ in range(5):
ok = client.post(
"/api/v1/auth/verifications/resend",
"/api/v1/auth/resend",
json={"email": "user@example.com"},
)
assert ok.status_code == 204
blocked = client.post(
"/api/v1/auth/verifications/resend",
"/api/v1/auth/resend",
json={"email": "user@example.com"},
)
assert blocked.status_code == 429
@@ -493,6 +493,37 @@ def test_refresh_rate_limited_after_too_many_attempts() -> None:
app.dependency_overrides = {}
def test_refresh_rate_limit_not_bypassed_by_changing_refresh_token() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for index in range(10):
blocked = client.post(
"/api/v1/auth/sessions/refresh",
json={"refresh_token": f"invalid-{index}"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/sessions/refresh",
json={"refresh_token": "invalid-extra"},
)
assert blocked.status_code == 429
finally:
app.dependency_overrides = {}
def test_logout_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
@@ -529,6 +560,39 @@ def test_logout_rate_limited_after_too_many_attempts() -> None:
app.dependency_overrides = {}
def test_logout_rate_limit_not_bypassed_by_changing_refresh_token() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for index in range(10):
ok = client.request(
"DELETE",
"/api/v1/auth/sessions",
json={"refresh_token": f"refresh-{index}"},
)
assert ok.status_code == 204
blocked = client.request(
"DELETE",
"/api/v1/auth/sessions",
json={"refresh_token": "refresh-extra"},
)
assert blocked.status_code == 429
finally:
app.dependency_overrides = {}
def test_signup_start_validation_error_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
@@ -548,7 +612,7 @@ def test_signup_start_validation_error_returns_problem_details() -> None:
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Content"
assert body["title"] == "Unprocessable Entity"
assert body["status"] == 422
assert body["detail"] == "Invalid request"
finally:
@@ -577,110 +641,13 @@ def test_signup_start_missing_username_returns_problem_details() -> None:
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Content"
assert body["title"] == "Unprocessable Entity"
assert body["status"] == 422
assert body["detail"] == "Invalid request"
finally:
app.dependency_overrides = {}
def test_get_user_by_email_returns_user() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="user@example.com",
)
client = TestClient(app)
try:
response = client.get(
"/api/v1/auth/users",
params={"email": "user@example.com"},
)
assert response.status_code == 200
body = response.json()
assert body["email"] == "user@example.com"
assert body["id"] == "user-1"
finally:
app.dependency_overrides = {}
def test_get_user_by_email_not_found_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="missing@example.com",
)
client = TestClient(app)
try:
response = client.get(
"/api/v1/auth/users",
params={"email": "missing@example.com"},
)
assert response.status_code == 404
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Not Found"
assert body["status"] == 404
assert body["detail"] == "User not found"
finally:
app.dependency_overrides = {}
def test_get_user_by_email_forbidden_when_querying_other_user() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="self@example.com",
)
client = TestClient(app)
try:
response = client.get(
"/api/v1/auth/users",
params={"email": "target@example.com"},
)
assert response.status_code == 403
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Forbidden"
assert body["status"] == 403
assert body["detail"] == "Forbidden"
finally:
app.dependency_overrides = {}
def test_password_reset_request_returns_204() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
@@ -697,7 +664,7 @@ def test_password_reset_request_returns_204() -> None:
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset",
"/api/v1/auth/resend",
json={"email": "user@example.com"},
)
assert response.status_code == 204
@@ -721,8 +688,9 @@ def test_password_reset_confirm_returns_204() -> None:
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
"/api/v1/auth/verify",
json={
"type": "recovery",
"email": "user@example.com",
"token": "123456",
"new_password": "newpassword123",
@@ -749,8 +717,9 @@ def test_password_reset_confirm_invalid_token_returns_401() -> None:
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
"/api/v1/auth/verify",
json={
"type": "recovery",
"email": "user@example.com",
"token": "000000",
"new_password": "newpassword123",
@@ -781,8 +750,9 @@ def test_password_reset_confirm_weak_password_returns_422() -> None:
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
"/api/v1/auth/verify",
json={
"type": "recovery",
"email": "user@example.com",
"token": "123456",
"new_password": "123",
@@ -14,6 +14,11 @@ def test_social_prefixed_supabase_env_populates_settings(
monkeypatch.setenv("SOCIAL_SUPABASE__ANON_KEY", "anon-key")
monkeypatch.setenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "service-key")
monkeypatch.setenv("SOCIAL_SUPABASE__JWT_SECRET", "jwt-secret")
monkeypatch.setenv("SOCIAL_SUPABASE__SITE_URL", "https://app.example.com")
monkeypatch.setenv(
"SOCIAL_SUPABASE__ADDITIONAL_REDIRECT_URLS",
'["https://a.example.com", "https://b.example.com/path"]',
)
monkeypatch.setenv("SOCIAL_DATABASE__HOST", "db")
monkeypatch.setenv("SOCIAL_DATABASE__PORT", "5432")
monkeypatch.setenv("SOCIAL_DATABASE__NAME", "app")
@@ -26,10 +31,16 @@ def test_social_prefixed_supabase_env_populates_settings(
assert settings.supabase.anon_key == "anon-key"
assert settings.supabase.service_role_key == "service-key"
assert settings.supabase.jwt_secret == "jwt-secret"
assert settings.supabase.site_url == "https://app.example.com"
assert settings.supabase.additional_redirect_urls == [
"https://a.example.com",
"https://b.example.com/path",
]
supabase_settings = settings.model_dump()["supabase"]
assert supabase_settings["public_url"] == "https://public.example:8443"
assert supabase_settings["anon_key"] == "anon-key"
assert supabase_settings["service_role_key"] == "service-key"
assert supabase_settings["jwt_secret"] == "jwt-secret"
assert supabase_settings["site_url"] == "https://app.example.com"
assert settings.database_url == "postgresql+asyncpg://user:pass@db:5432/app"
@@ -7,7 +7,11 @@ import pytest
from fastapi import HTTPException
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
VerificationResendRequest,
)
class TestSupabaseAuthGateway:
@@ -56,6 +60,22 @@ class TestSupabaseAuthGateway:
options={"redirect_to": "http://localhost:3000/reset-password"},
)
@pytest.mark.asyncio
async def test_request_password_reset_rejects_redirect_outside_allowlist(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
request = PasswordResetRequest(
email="test@example.com",
redirect_to="https://evil.example/reset",
)
with pytest.raises(HTTPException) as exc_info:
await sut.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid redirect URL"
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
@@ -165,3 +185,24 @@ class TestSupabaseAuthGateway:
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"
@pytest.mark.asyncio
async def test_recovery_resend_calls_reset_password_email(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
mock_client.auth.reset_password_email = mock_reset_email
await sut.resend_verification(
VerificationResendRequest(
type="recovery",
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
)
mock_reset_email.assert_called_once_with(
"test@example.com",
options={"redirect_to": "http://localhost:3000/reset-password"},
)
@@ -33,6 +33,25 @@ def test_signup_verify_requires_six_digit_token() -> None:
VerificationVerifyRequest(email="user@example.com", token="abc123")
def test_signup_verify_disallows_new_password() -> None:
with pytest.raises(ValidationError):
VerificationVerifyRequest(
type="signup",
email="user@example.com",
token="123456",
new_password="secret123",
)
def test_recovery_verify_requires_new_password() -> None:
with pytest.raises(ValidationError):
VerificationVerifyRequest(
type="recovery",
email="user@example.com",
token="123456",
)
def test_signup_resend_requires_valid_email() -> None:
with pytest.raises(ValidationError):
VerificationResendRequest(email="invalid")
+20 -50
View File
@@ -5,6 +5,8 @@ import pytest
import v1.auth.gateway as auth_gateway_module
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
@@ -44,40 +46,15 @@ class FakeGateway(AuthServiceGateway):
return None
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
return UserByEmailResponse(
id="user-1",
email=email,
created_at="2026-02-24T00:00:00Z",
email_confirmed_at=None,
)
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
@pytest.mark.asyncio
async def test_signup_maps_response() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
service = AuthService(gateway=FakeGateway(token_response))
start_result = await service.create_verification(
VerificationCreateRequest(
username="demo", email="user@example.com", password="secret123"
)
)
assert start_result.email == "user@example.com"
result = await service.verify_verification(
VerificationVerifyRequest(email="user@example.com", token="123456")
)
assert result.access_token == "access"
assert result.refresh_token == "refresh"
assert result.user.id == "user-1"
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
class LogoutAssertingGateway(AuthServiceGateway):
@@ -109,6 +86,14 @@ class LogoutAssertingGateway(AuthServiceGateway):
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
@pytest.mark.asyncio
async def test_logout_forwards_refresh_token() -> None:
@@ -117,23 +102,6 @@ async def test_logout_forwards_refresh_token() -> None:
await service.delete_session("refresh-token")
@pytest.mark.asyncio
async def test_get_user_by_email_forwards_to_gateway() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
service = AuthService(gateway=FakeGateway(token_response))
result = await service.get_user_by_email("user@example.com")
assert result.email == "user@example.com"
@pytest.mark.asyncio
async def test_signup_resend_returns_none() -> None:
user = AuthUser(id="user-1", email="user@example.com")
@@ -182,7 +150,9 @@ async def test_supabase_signup_passes_username_in_metadata(
class FakeClient:
auth = FakeSupabaseAuth()
monkeypatch.setattr(auth_gateway_module, "create_client", lambda *_: FakeClient())
monkeypatch.setattr(
auth_gateway_module.supabase_service, "get_client", lambda: FakeClient()
)
gateway = auth_gateway_module.SupabaseAuthGateway()
await gateway.create_verification(
@@ -1,285 +0,0 @@
from __future__ import annotations
import time
from typing import Any
from uuid import UUID
import jwt
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.profile.dependencies import get_current_user
class TestGetCurrentUser:
"""Tests for JWT validation in get_current_user dependency."""
@pytest.fixture
def jwt_secret(self) -> str:
return "super-secret-jwt-token-with-at-least-32-characters"
@pytest.fixture
def valid_user_id(self) -> str:
return "00000000-0000-0000-0000-000000000123"
@pytest.fixture
def valid_payload(self, valid_user_id: str) -> dict[str, Any]:
"""Valid JWT payload with all required claims."""
now = int(time.time())
return {
"sub": valid_user_id,
"aud": "authenticated",
"iss": "http://localhost:8001/auth/v1",
"exp": now + 3600, # 1 hour from now
"iat": now,
}
def _create_token(self, payload: dict[str, Any], secret: str) -> str:
return jwt.encode(payload, secret, algorithm="HS256")
def test_valid_token_returns_current_user(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
valid_user_id: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Valid JWT with correct aud/iss/exp should return CurrentUser."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
token = self._create_token(valid_payload, jwt_secret)
authorization = f"Bearer {token}"
result = get_current_user(authorization=authorization)
assert isinstance(result, CurrentUser)
assert result.id == UUID(valid_user_id)
def test_missing_authorization_raises_401(self) -> None:
"""Missing Authorization header should raise 401."""
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=None)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Unauthorized"
def test_invalid_scheme_raises_401(self) -> None:
"""Non-Bearer scheme should raise 401."""
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization="Basic dXNlcjpwYXNz")
assert exc_info.value.status_code == 401
def test_expired_token_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Expired JWT should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["exp"] = int(time.time()) - 3600 # 1 hour ago
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_invalid_audience_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with wrong audience should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["aud"] = "wrong-audience"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_invalid_issuer_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with wrong issuer should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["iss"] = "http://malicious-site.com/auth/v1"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_missing_subject_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT without 'sub' claim should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
del valid_payload["sub"]
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_wrong_secret_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT signed with wrong secret should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
token = self._create_token(
valid_payload, "wrong-secret-key-that-is-long-enough"
)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
def test_jwt_secret_not_configured_raises_503(
self, valid_payload: dict[str, Any], monkeypatch: pytest.MonkeyPatch
) -> None:
"""Missing JWT secret in config should raise 503."""
monkeypatch.setattr("v1.profile.dependencies.config.supabase.jwt_secret", None)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization="Bearer some-token")
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "JWT secret not configured"
def test_invalid_uuid_in_subject_raises_401(
self,
jwt_secret: str,
valid_payload: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""JWT with non-UUID 'sub' claim should raise 401."""
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.jwt_secret", jwt_secret
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_scheme",
"http",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.public_host",
"localhost",
)
monkeypatch.setattr(
"v1.profile.dependencies.config.supabase.kong_http_port",
8001,
)
valid_payload["sub"] = "not-a-valid-uuid"
token = self._create_token(valid_payload, jwt_secret)
with pytest.raises(HTTPException) as exc_info:
get_current_user(authorization=f"Bearer {token}")
assert exc_info.value.status_code == 401
@@ -1,170 +0,0 @@
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from models.profile import Profile
from v1.profile.repository import ProfileRepository
from v1.profile.schemas import ProfileUpdateRequest
from v1.profile.service import ProfileService
def _create_mock_profile(
user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"),
username: str = "demo",
avatar_url: str | None = None,
bio: str | None = None,
) -> Profile:
"""Create a mock Profile ORM object."""
profile = MagicMock(spec=Profile)
profile.id = user_id
profile.username = username
profile.avatar_url = avatar_url
profile.bio = bio
return profile
class FakeRepo:
"""Fake repository for testing that conforms to ProfileRepository protocol."""
def __init__(self, profile: Profile | None) -> None:
self._profile = profile
async def get_by_user_id(self, user_id: UUID) -> Profile | None:
if self._profile and user_id == self._profile.id:
return self._profile
return None
async def get_by_username(self, username: str) -> Profile | None:
if self._profile and username == self._profile.username:
return self._profile
return None
async def update_by_user_id(
self, user_id: UUID, update_data: dict[str, str | None]
) -> Profile | None:
if not self._profile or user_id != self._profile.id:
return None
# Apply updates to mock
for key, value in update_data.items():
if hasattr(self._profile, key):
setattr(self._profile, key, value)
return self._profile
# Verify FakeRepo implements the protocol
_repo_check: ProfileRepository = FakeRepo(None)
@pytest.fixture
def mock_session() -> AsyncMock:
"""Create a mock AsyncSession."""
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
return session
@pytest.mark.asyncio
async def test_get_me_returns_profile(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id, username="demo")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
result = await service.get_me()
assert result.username == "demo"
assert result.id == str(user_id)
@pytest.mark.asyncio
async def test_get_me_not_found_raises_404(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(None),
session=mock_session,
current_user=user,
)
with pytest.raises(HTTPException) as exc_info:
await service.get_me()
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_update_me_updates_fields(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id, username="demo")
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
result = await service.update_me(ProfileUpdateRequest(username="updated"))
assert result.username == "updated"
mock_session.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_update_me_no_fields_raises_400(mock_session: AsyncMock) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = _create_mock_profile(user_id=user_id)
user = CurrentUser(id=user_id)
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=user,
)
# Create a request with all None values by bypassing validation
update = MagicMock(spec=ProfileUpdateRequest)
update.username = None
update.avatar_url = None
update.bio = None
with pytest.raises(HTTPException) as exc_info:
await service.update_me(update)
assert exc_info.value.status_code == 400
@pytest.mark.asyncio
async def test_get_by_username_returns_profile(mock_session: AsyncMock) -> None:
profile = _create_mock_profile(username="demo")
service = ProfileService(
repository=FakeRepo(profile),
session=mock_session,
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
result = await service.get_by_username("demo")
assert result.username == "demo"
@pytest.mark.asyncio
async def test_get_by_username_not_found_raises_404(mock_session: AsyncMock) -> None:
service = ProfileService(
repository=FakeRepo(None),
session=mock_session,
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
with pytest.raises(HTTPException) as exc_info:
await service.get_by_username("unknown")
assert exc_info.value.status_code == 404
@@ -1,65 +0,0 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from v1.profile.schemas import ProfileResponse, ProfileUpdateRequest
def test_profile_response_maps_fields() -> None:
response = ProfileResponse(
id="user-1",
username="demo",
avatar_url=None,
bio=None,
)
assert response.id == "user-1"
assert response.username == "demo"
def test_profile_update_requires_one_field() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest()
def test_profile_update_accepts_valid_https_url() -> None:
request = ProfileUpdateRequest(avatar_url="https://example.com/avatar.png")
assert request.avatar_url == "https://example.com/avatar.png"
def test_profile_update_accepts_valid_http_url() -> None:
request = ProfileUpdateRequest(
avatar_url="http://localhost:8001/storage/avatar.png"
)
assert request.avatar_url == "http://localhost:8001/storage/avatar.png"
def test_profile_update_rejects_invalid_url() -> None:
with pytest.raises(ValidationError) as exc_info:
ProfileUpdateRequest(avatar_url="not-a-valid-url")
errors = exc_info.value.errors()
assert len(errors) == 1
assert "avatar_url" in str(errors[0]["loc"])
def test_profile_update_rejects_javascript_url() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest(avatar_url="javascript:alert('xss')")
def test_profile_update_rejects_data_url() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest(avatar_url="data:text/html,<script>alert('xss')</script>")
def test_profile_update_accepts_none_avatar_url_with_other_field() -> None:
request = ProfileUpdateRequest(username="tester", avatar_url=None)
assert request.avatar_url is None
assert request.username == "tester"
def test_profile_update_rejects_display_name_field() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest.model_validate({"display_name": "legacy"})