feat: complete auth/profile username migration and runtime safeguards

This commit is contained in:
qzl
2026-02-25 10:20:43 +08:00
parent 8bdcb674bb
commit 7d6dda57c1
24 changed files with 720 additions and 166 deletions
+2
View File
@@ -7,3 +7,5 @@ from uuid import UUID
@dataclass(frozen=True)
class CurrentUser:
id: UUID
email: str | None = None
role: str | None = None
-5
View File
@@ -26,14 +26,9 @@ class Profile(TimestampMixin, SoftDeleteMixin, Base):
)
username: Mapped[str] = mapped_column(
String(30),
unique=True,
nullable=False,
index=True,
)
display_name: Mapped[str | None] = mapped_column(
String(50),
nullable=True,
)
avatar_url: Mapped[str | None] = mapped_column(
Text,
nullable=True,
+2 -1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from v1.auth.service import AuthService, SupabaseAuthGateway
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.service import AuthService
def get_auth_service() -> AuthService:
+154
View File
@@ -0,0 +1,154 @@
from __future__ import annotations
import asyncio
from typing import Any, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from v1.auth.schemas import (
AuthTokenResponse,
AuthUser,
AuthUserByEmailResponse,
LoginRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthServiceGateway
logger = get_logger("v1.auth.gateway")
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
_admin_client: Any
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
self._admin_client = create_client(settings.url, settings.service_role_key)
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
"data": {"username": request.username},
}
try:
sign_up = cast(Any, self._client.auth.sign_up)
response = await asyncio.to_thread(sign_up, payload)
return _map_auth_response(response, "Authentication failed")
except AuthError as exc:
logger.warning("Signup failed", error_type=type(exc).__name__)
raise HTTPException(
status_code=401, detail="Authentication failed"
) from exc
async def login(self, request: LoginRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error_type=type(exc).__name__)
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
except AuthError as exc:
logger.warning("Refresh failed", error_type=type(exc).__name__)
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def logout(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
normalized_email = email.lower()
user = next(
(
candidate
for candidate in users
if str(getattr(candidate, "email", "")).lower() == normalized_email
),
None,
)
if user is None:
raise HTTPException(status_code=404, detail="User not found")
return AuthUserByEmailResponse(
id=str(getattr(user, "id", "")),
email=str(getattr(user, "email", "")),
created_at=str(getattr(user, "created_at", "")),
email_confirmed_at=(
str(getattr(user, "email_confirmed_at", ""))
if getattr(user, "email_confirmed_at", None)
else None
),
)
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
return AuthTokenResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
while True:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = list(getattr(response, "users", []))
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users
+18 -1
View File
@@ -1,10 +1,16 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Response
from typing import Annotated
from fastapi import APIRouter, Depends, Response
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.auth.dependencies import get_auth_service
from v1.profile.dependencies import get_current_user
from v1.auth.schemas import (
AuthTokenResponse,
AuthUserByEmailResponse,
LoginRequest,
LogoutRequest,
RefreshRequest,
@@ -47,3 +53,14 @@ async def logout(
) -> Response:
await service.logout(payload.refresh_token)
return Response(status_code=204)
@router.get("/users/by-email", response_model=AuthUserByEmailResponse)
async def get_user_by_email(
email: str,
current_user: Annotated[CurrentUser, Depends(get_current_user)],
service: AuthService = Depends(get_auth_service),
) -> AuthUserByEmailResponse:
if current_user.role != "service_role" and current_user.email != email:
raise HTTPException(status_code=403, detail="Forbidden")
return await service.get_user_by_email(email)
+8 -1
View File
@@ -6,9 +6,9 @@ from pydantic import BaseModel, EmailStr, Field
class SignupRequest(BaseModel):
username: str = Field(min_length=3, max_length=30)
email: EmailStr
password: str = Field(min_length=6)
display_name: str | None = None
redirect_to: str | None = None
@@ -38,6 +38,13 @@ class AuthTokenResponse(BaseModel):
user: AuthUser
class AuthUserByEmailResponse(BaseModel):
id: str
email: EmailStr
created_at: str
email_confirmed_at: str | None = None
class SignupPendingResponse(BaseModel):
status: Literal["pending_verification"] = "pending_verification"
user: AuthUser
+6 -103
View File
@@ -1,25 +1,16 @@
from __future__ import annotations
import asyncio
from typing import Any, Protocol, cast
from typing import Protocol
from fastapi import HTTPException
from supabase import AuthError, create_client
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from v1.auth.schemas import (
AuthTokenResponse,
AuthUser,
AuthUserByEmailResponse,
LoginRequest,
RefreshRequest,
SignupRequest,
)
logger = get_logger("v1.auth.service")
class AuthServiceGateway(Protocol):
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
raise NotImplementedError
@@ -33,79 +24,8 @@ class AuthServiceGateway(Protocol):
async def logout(self, refresh_token: str | None) -> None:
raise NotImplementedError
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
}
if request.display_name:
payload = {
**payload,
"data": {"display_name": request.display_name},
}
try:
sign_up = cast(Any, self._client.auth.sign_up)
response = await asyncio.to_thread(sign_up, payload)
return _map_auth_response(response, "Authentication failed")
except AuthError as exc:
logger.warning("Signup failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Authentication failed"
) from exc
async def login(self, request: LoginRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error=str(exc))
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
except AuthError as exc:
logger.warning("Refresh failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def logout(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
raise NotImplementedError
class AuthService:
@@ -126,22 +46,5 @@ class AuthService:
async def logout(self, refresh_token: str | None) -> None:
await self._gateway.logout(refresh_token)
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
return AuthTokenResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
return await self._gateway.get_user_by_email(email)
+3 -1
View File
@@ -82,7 +82,9 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
raise HTTPException(status_code=401, detail="Unauthorized")
logger.debug("JWT validation successful", user_id=str(user_id))
return CurrentUser(id=user_id)
email = payload.get("email") if isinstance(payload.get("email"), str) else None
role = payload.get("role") if isinstance(payload.get("role"), str) else None
return CurrentUser(id=user_id, email=email, role=role)
def get_profile_service(
+10 -1
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
@@ -54,7 +55,15 @@ class SQLAlchemyProfileRepository(BaseRepository[Profile]):
async def get_by_username(self, username: str) -> Profile | None:
try:
return await self.get_one(Profile.username == username)
stmt = (
select(Profile)
.where(Profile.username == username)
.where(Profile.deleted_at.is_(None))
.order_by(Profile.created_at.asc())
.limit(1)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
except SQLAlchemyError:
logger.exception("Profile lookup failed", username=username)
raise
+12 -4
View File
@@ -1,18 +1,26 @@
from __future__ import annotations
from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator
from pydantic import (
AnyHttpUrl,
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
class ProfileResponse(BaseModel):
id: str
username: str
display_name: str | None = None
avatar_url: str | None = None
bio: str | None = None
class ProfileUpdateRequest(BaseModel):
display_name: str | None = Field(default=None, max_length=50)
model_config = ConfigDict(extra="forbid")
username: str | None = Field(default=None, min_length=3, max_length=30)
avatar_url: str | None = Field(default=None)
bio: str | None = Field(default=None, max_length=200)
@@ -28,6 +36,6 @@ class ProfileUpdateRequest(BaseModel):
@model_validator(mode="after")
def require_one_field(self) -> "ProfileUpdateRequest":
if self.display_name is None and self.avatar_url is None and self.bio is None:
if self.username is None and self.avatar_url is None and self.bio is None:
raise ValueError("At least one field must be provided")
return self
+1 -4
View File
@@ -51,7 +51,6 @@ class ProfileService(BaseService):
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
@@ -61,7 +60,7 @@ class ProfileService(BaseService):
update_data: dict[str, str | None] = {
key: value
for key, value in {
"display_name": update.display_name,
"username": update.username,
"avatar_url": update.avatar_url,
"bio": update.bio,
}.items()
@@ -84,7 +83,6 @@ class ProfileService(BaseService):
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)
@@ -100,7 +98,6 @@ class ProfileService(BaseService):
return ProfileResponse(
id=str(profile.id),
username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url,
bio=profile.bio,
)