Merge branch 'feature/auth-profile-impl' into dev

This commit is contained in:
qzl
2026-02-25 10:21:09 +08:00
24 changed files with 720 additions and 166 deletions
@@ -0,0 +1,121 @@
"""drop_profile_display_name_and_trigger_username
Revision ID: 20260224_drop_profile
Revises: 20260224_bind_profiles_auth
Create Date: 2026-02-24 22:10:00.000000
"""
from typing import Sequence, Union
from alembic import op
revision: str = "20260224_drop_profile"
down_revision: Union[str, Sequence[str], None] = "20260224_bind_profiles_auth"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
ALTER TABLE public.profiles
DROP CONSTRAINT IF EXISTS uq_profiles_username
"""
)
op.execute(
"""
ALTER TABLE public.profiles
DROP COLUMN IF EXISTS display_name
"""
)
op.execute(
"""
CREATE OR REPLACE FUNCTION public.create_profile_for_new_user()
RETURNS trigger
LANGUAGE plpgsql
SECURITY DEFINER
SET search_path = public
AS $$
BEGIN
INSERT INTO public.profiles (id, username)
VALUES (
NEW.id,
COALESCE(
NULLIF(NEW.raw_user_meta_data->>'username', ''),
'user_' || substr(replace(NEW.id::text, '-', ''), 1, 25)
)
)
ON CONFLICT (id) DO NOTHING;
RETURN NEW;
END;
$$
"""
)
def downgrade() -> None:
op.execute(
"""
ALTER TABLE public.profiles
ADD COLUMN IF NOT EXISTS display_name VARCHAR(50)
"""
)
op.execute(
"""
WITH ranked AS (
SELECT
id,
username,
row_number() OVER (
PARTITION BY username
ORDER BY created_at ASC, id ASC
) AS rn
FROM public.profiles
WHERE username IS NOT NULL
)
UPDATE public.profiles p
SET username = LEFT(p.username, 24) || '_' || (ranked.rn - 1)::text
FROM ranked
WHERE p.id = ranked.id
AND ranked.rn > 1
"""
)
op.execute(
"""
ALTER TABLE public.profiles
ADD CONSTRAINT uq_profiles_username UNIQUE (username)
"""
)
op.execute(
"""
CREATE OR REPLACE FUNCTION public.create_profile_for_new_user()
RETURNS trigger
LANGUAGE plpgsql
SECURITY DEFINER
SET search_path = public
AS $$
BEGIN
INSERT INTO public.profiles (id, username, display_name)
VALUES (
NEW.id,
'user_' || substr(replace(NEW.id::text, '-', ''), 1, 25),
COALESCE(
NULLIF(NEW.raw_user_meta_data->>'display_name', ''),
NULLIF(NEW.raw_user_meta_data->>'full_name', '')
)
)
ON CONFLICT (id) DO NOTHING;
RETURN NEW;
END;
$$
"""
)
+2
View File
@@ -7,3 +7,5 @@ from uuid import UUID
@dataclass(frozen=True) @dataclass(frozen=True)
class CurrentUser: class CurrentUser:
id: UUID id: UUID
email: str | None = None
role: str | None = None
-5
View File
@@ -26,14 +26,9 @@ class Profile(TimestampMixin, SoftDeleteMixin, Base):
) )
username: Mapped[str] = mapped_column( username: Mapped[str] = mapped_column(
String(30), String(30),
unique=True,
nullable=False, nullable=False,
index=True, index=True,
) )
display_name: Mapped[str | None] = mapped_column(
String(50),
nullable=True,
)
avatar_url: Mapped[str | None] = mapped_column( avatar_url: Mapped[str | None] = mapped_column(
Text, Text,
nullable=True, nullable=True,
+2 -1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from v1.auth.service import AuthService, SupabaseAuthGateway from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.service import AuthService
def get_auth_service() -> AuthService: def get_auth_service() -> AuthService:
+154
View File
@@ -0,0 +1,154 @@
from __future__ import annotations
import asyncio
from typing import Any, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from v1.auth.schemas import (
AuthTokenResponse,
AuthUser,
AuthUserByEmailResponse,
LoginRequest,
RefreshRequest,
SignupRequest,
)
from v1.auth.service import AuthServiceGateway
logger = get_logger("v1.auth.gateway")
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
_admin_client: Any
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
self._admin_client = create_client(settings.url, settings.service_role_key)
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
"data": {"username": request.username},
}
try:
sign_up = cast(Any, self._client.auth.sign_up)
response = await asyncio.to_thread(sign_up, payload)
return _map_auth_response(response, "Authentication failed")
except AuthError as exc:
logger.warning("Signup failed", error_type=type(exc).__name__)
raise HTTPException(
status_code=401, detail="Authentication failed"
) from exc
async def login(self, request: LoginRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error_type=type(exc).__name__)
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
except AuthError as exc:
logger.warning("Refresh failed", error_type=type(exc).__name__)
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def logout(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
normalized_email = email.lower()
user = next(
(
candidate
for candidate in users
if str(getattr(candidate, "email", "")).lower() == normalized_email
),
None,
)
if user is None:
raise HTTPException(status_code=404, detail="User not found")
return AuthUserByEmailResponse(
id=str(getattr(user, "id", "")),
email=str(getattr(user, "email", "")),
created_at=str(getattr(user, "created_at", "")),
email_confirmed_at=(
str(getattr(user, "email_confirmed_at", ""))
if getattr(user, "email_confirmed_at", None)
else None
),
)
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
return AuthTokenResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
while True:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = list(getattr(response, "users", []))
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users
+18 -1
View File
@@ -1,10 +1,16 @@
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Depends, Response from typing import Annotated
from fastapi import APIRouter, Depends, Response
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.auth.dependencies import get_auth_service from v1.auth.dependencies import get_auth_service
from v1.profile.dependencies import get_current_user
from v1.auth.schemas import ( from v1.auth.schemas import (
AuthTokenResponse, AuthTokenResponse,
AuthUserByEmailResponse,
LoginRequest, LoginRequest,
LogoutRequest, LogoutRequest,
RefreshRequest, RefreshRequest,
@@ -47,3 +53,14 @@ async def logout(
) -> Response: ) -> Response:
await service.logout(payload.refresh_token) await service.logout(payload.refresh_token)
return Response(status_code=204) return Response(status_code=204)
@router.get("/users/by-email", response_model=AuthUserByEmailResponse)
async def get_user_by_email(
email: str,
current_user: Annotated[CurrentUser, Depends(get_current_user)],
service: AuthService = Depends(get_auth_service),
) -> AuthUserByEmailResponse:
if current_user.role != "service_role" and current_user.email != email:
raise HTTPException(status_code=403, detail="Forbidden")
return await service.get_user_by_email(email)
+8 -1
View File
@@ -6,9 +6,9 @@ from pydantic import BaseModel, EmailStr, Field
class SignupRequest(BaseModel): class SignupRequest(BaseModel):
username: str = Field(min_length=3, max_length=30)
email: EmailStr email: EmailStr
password: str = Field(min_length=6) password: str = Field(min_length=6)
display_name: str | None = None
redirect_to: str | None = None redirect_to: str | None = None
@@ -38,6 +38,13 @@ class AuthTokenResponse(BaseModel):
user: AuthUser user: AuthUser
class AuthUserByEmailResponse(BaseModel):
id: str
email: EmailStr
created_at: str
email_confirmed_at: str | None = None
class SignupPendingResponse(BaseModel): class SignupPendingResponse(BaseModel):
status: Literal["pending_verification"] = "pending_verification" status: Literal["pending_verification"] = "pending_verification"
user: AuthUser user: AuthUser
+6 -103
View File
@@ -1,25 +1,16 @@
from __future__ import annotations from __future__ import annotations
import asyncio from typing import Protocol
from typing import Any, Protocol, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from v1.auth.schemas import ( from v1.auth.schemas import (
AuthTokenResponse, AuthTokenResponse,
AuthUser, AuthUserByEmailResponse,
LoginRequest, LoginRequest,
RefreshRequest, RefreshRequest,
SignupRequest, SignupRequest,
) )
logger = get_logger("v1.auth.service")
class AuthServiceGateway(Protocol): class AuthServiceGateway(Protocol):
async def signup(self, request: SignupRequest) -> AuthTokenResponse: async def signup(self, request: SignupRequest) -> AuthTokenResponse:
raise NotImplementedError raise NotImplementedError
@@ -33,79 +24,8 @@ class AuthServiceGateway(Protocol):
async def logout(self, refresh_token: str | None) -> None: async def logout(self, refresh_token: str | None) -> None:
raise NotImplementedError raise NotImplementedError
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
class SupabaseAuthGateway(AuthServiceGateway): raise NotImplementedError
_client: Any
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
async def signup(self, request: SignupRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
}
if request.display_name:
payload = {
**payload,
"data": {"display_name": request.display_name},
}
try:
sign_up = cast(Any, self._client.auth.sign_up)
response = await asyncio.to_thread(sign_up, payload)
return _map_auth_response(response, "Authentication failed")
except AuthError as exc:
logger.warning("Signup failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Authentication failed"
) from exc
async def login(self, request: LoginRequest) -> AuthTokenResponse:
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error=str(exc))
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh(self, request: RefreshRequest) -> AuthTokenResponse:
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
except AuthError as exc:
logger.warning("Refresh failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
async def logout(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error=str(exc))
raise HTTPException(
status_code=401, detail="Invalid refresh token"
) from exc
class AuthService: class AuthService:
@@ -126,22 +46,5 @@ class AuthService:
async def logout(self, refresh_token: str | None) -> None: async def logout(self, refresh_token: str | None) -> None:
await self._gateway.logout(refresh_token) await self._gateway.logout(refresh_token)
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
def _map_auth_response(response: object, failure_message: str) -> AuthTokenResponse: return await self._gateway.get_user_by_email(email)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
return AuthTokenResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
+3 -1
View File
@@ -82,7 +82,9 @@ def get_current_user(authorization: str | None = Header(default=None)) -> Curren
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
logger.debug("JWT validation successful", user_id=str(user_id)) logger.debug("JWT validation successful", user_id=str(user_id))
return CurrentUser(id=user_id) email = payload.get("email") if isinstance(payload.get("email"), str) else None
role = payload.get("role") if isinstance(payload.get("role"), str) else None
return CurrentUser(id=user_id, email=email, role=role)
def get_profile_service( def get_profile_service(
+10 -1
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Protocol from typing import TYPE_CHECKING, Protocol
from uuid import UUID from uuid import UUID
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository from core.db.base_repository import BaseRepository
@@ -54,7 +55,15 @@ class SQLAlchemyProfileRepository(BaseRepository[Profile]):
async def get_by_username(self, username: str) -> Profile | None: async def get_by_username(self, username: str) -> Profile | None:
try: try:
return await self.get_one(Profile.username == username) stmt = (
select(Profile)
.where(Profile.username == username)
.where(Profile.deleted_at.is_(None))
.order_by(Profile.created_at.asc())
.limit(1)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
except SQLAlchemyError: except SQLAlchemyError:
logger.exception("Profile lookup failed", username=username) logger.exception("Profile lookup failed", username=username)
raise raise
+12 -4
View File
@@ -1,18 +1,26 @@
from __future__ import annotations from __future__ import annotations
from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator from pydantic import (
AnyHttpUrl,
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
class ProfileResponse(BaseModel): class ProfileResponse(BaseModel):
id: str id: str
username: str username: str
display_name: str | None = None
avatar_url: str | None = None avatar_url: str | None = None
bio: str | None = None bio: str | None = None
class ProfileUpdateRequest(BaseModel): class ProfileUpdateRequest(BaseModel):
display_name: str | None = Field(default=None, max_length=50) model_config = ConfigDict(extra="forbid")
username: str | None = Field(default=None, min_length=3, max_length=30)
avatar_url: str | None = Field(default=None) avatar_url: str | None = Field(default=None)
bio: str | None = Field(default=None, max_length=200) bio: str | None = Field(default=None, max_length=200)
@@ -28,6 +36,6 @@ class ProfileUpdateRequest(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def require_one_field(self) -> "ProfileUpdateRequest": def require_one_field(self) -> "ProfileUpdateRequest":
if self.display_name is None and self.avatar_url is None and self.bio is None: if self.username is None and self.avatar_url is None and self.bio is None:
raise ValueError("At least one field must be provided") raise ValueError("At least one field must be provided")
return self return self
+1 -4
View File
@@ -51,7 +51,6 @@ class ProfileService(BaseService):
return ProfileResponse( return ProfileResponse(
id=str(profile.id), id=str(profile.id),
username=profile.username, username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url, avatar_url=profile.avatar_url,
bio=profile.bio, bio=profile.bio,
) )
@@ -61,7 +60,7 @@ class ProfileService(BaseService):
update_data: dict[str, str | None] = { update_data: dict[str, str | None] = {
key: value key: value
for key, value in { for key, value in {
"display_name": update.display_name, "username": update.username,
"avatar_url": update.avatar_url, "avatar_url": update.avatar_url,
"bio": update.bio, "bio": update.bio,
}.items() }.items()
@@ -84,7 +83,6 @@ class ProfileService(BaseService):
return ProfileResponse( return ProfileResponse(
id=str(profile.id), id=str(profile.id),
username=profile.username, username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url, avatar_url=profile.avatar_url,
bio=profile.bio, bio=profile.bio,
) )
@@ -100,7 +98,6 @@ class ProfileService(BaseService):
return ProfileResponse( return ProfileResponse(
id=str(profile.id), id=str(profile.id),
username=profile.username, username=profile.username,
display_name=profile.display_name,
avatar_url=profile.avatar_url, avatar_url=profile.avatar_url,
bio=profile.bio, bio=profile.bio,
) )
+5 -1
View File
@@ -95,7 +95,11 @@ def test_auth_flow_e2e() -> None:
signup = request_context.post( signup = request_context.post(
"/api/v1/auth/signup", "/api/v1/auth/signup",
data=json.dumps( data=json.dumps(
{"email": "user@example.com", "password": "secret123"} {
"username": "demo",
"email": "user@example.com",
"password": "secret123",
}
), ),
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
+6 -8
View File
@@ -27,11 +27,10 @@ class FakeProfileService:
async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse: async def update_me(self, update: ProfileUpdateRequest) -> ProfileResponse:
return ProfileResponse( return ProfileResponse(
id=self._profile.id, id=self._profile.id,
username=self._profile.username, username=(
display_name=( update.username
update.display_name if update.username is not None
if update.display_name is not None else self._profile.username
else self._profile.display_name
), ),
avatar_url=( avatar_url=(
update.avatar_url update.avatar_url
@@ -75,7 +74,6 @@ def test_profile_flow_e2e() -> None:
profile = ProfileResponse( profile = ProfileResponse(
id=str(user_id), id=str(user_id),
username="demo", username="demo",
display_name="Demo User",
avatar_url=None, avatar_url=None,
bio=None, bio=None,
) )
@@ -98,11 +96,11 @@ def test_profile_flow_e2e() -> None:
updated = request_context.patch( updated = request_context.patch(
"/api/v1/profile/me", "/api/v1/profile/me",
data=json.dumps({"display_name": "Updated"}), data=json.dumps({"username": "updated"}),
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
assert updated.status == 200 assert updated.status == 200
assert updated.json()["display_name"] == "Updated" assert updated.json()["username"] == "updated"
public = request_context.get("/api/v1/profile/demo") public = request_context.get("/api/v1/profile/demo")
assert public.status == 200 assert public.status == 200
+145 -1
View File
@@ -1,14 +1,18 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable from typing import Callable
from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app import app from app import app
from core.auth.models import CurrentUser
from v1.auth.dependencies import get_auth_service from v1.auth.dependencies import get_auth_service
from v1.profile.dependencies import get_current_user
from v1.auth.schemas import ( from v1.auth.schemas import (
AuthTokenResponse, AuthTokenResponse,
AuthUserByEmailResponse,
AuthUser, AuthUser,
LoginRequest, LoginRequest,
RefreshRequest, RefreshRequest,
@@ -33,6 +37,16 @@ class FakeAuthService(AuthService):
async def logout(self, refresh_token: str | None) -> None: async def logout(self, refresh_token: str | None) -> None:
return None return None
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
if email == "missing@example.com":
raise HTTPException(status_code=404, detail="User not found")
return AuthUserByEmailResponse(
id="user-1",
email=email,
created_at="2026-02-24T00:00:00Z",
email_confirmed_at=None,
)
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]: def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
def _get_service() -> AuthService: def _get_service() -> AuthService:
@@ -58,7 +72,11 @@ def test_signup_returns_token_response() -> None:
try: try:
response = client.post( response = client.post(
"/api/v1/auth/signup", "/api/v1/auth/signup",
json={"email": "user@example.com", "password": "secret123"}, json={
"username": "demo",
"email": "user@example.com",
"password": "secret123",
},
) )
assert response.status_code == 200 assert response.status_code == 200
body = response.json() body = response.json()
@@ -176,3 +194,129 @@ def test_signup_validation_error_returns_problem_details() -> None:
assert body["detail"] == "Invalid request" assert body["detail"] == "Invalid request"
finally: finally:
app.dependency_overrides = {} app.dependency_overrides = {}
def test_signup_missing_username_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/signup",
json={"email": "user@example.com", "password": "secret123"},
)
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Content"
assert body["status"] == 422
assert body["detail"] == "Invalid request"
finally:
app.dependency_overrides = {}
def test_get_user_by_email_returns_user() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="user@example.com",
)
client = TestClient(app)
try:
response = client.get(
"/api/v1/auth/users/by-email",
params={"email": "user@example.com"},
)
assert response.status_code == 200
body = response.json()
assert body["email"] == "user@example.com"
assert body["id"] == "user-1"
finally:
app.dependency_overrides = {}
def test_get_user_by_email_not_found_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="missing@example.com",
)
client = TestClient(app)
try:
response = client.get(
"/api/v1/auth/users/by-email",
params={"email": "missing@example.com"},
)
assert response.status_code == 404
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Not Found"
assert body["status"] == 404
assert body["detail"] == "User not found"
finally:
app.dependency_overrides = {}
def test_get_user_by_email_forbidden_when_querying_other_user() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="self@example.com",
)
client = TestClient(app)
try:
response = client.get(
"/api/v1/auth/users/by-email",
params={"email": "target@example.com"},
)
assert response.status_code == 403
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Forbidden"
assert body["status"] == 403
assert body["detail"] == "Forbidden"
finally:
app.dependency_overrides = {}
@@ -29,11 +29,10 @@ class FakeProfileService:
raise HTTPException(status_code=404, detail="Profile not found") raise HTTPException(status_code=404, detail="Profile not found")
return ProfileResponse( return ProfileResponse(
id=self._profile.id, id=self._profile.id,
username=self._profile.username, username=(
display_name=( update.username
update.display_name if update.username is not None
if update.display_name is not None else self._profile.username
else self._profile.display_name
), ),
avatar_url=( avatar_url=(
update.avatar_url update.avatar_url
@@ -70,7 +69,6 @@ def test_get_me_returns_profile() -> None:
profile = ProfileResponse( profile = ProfileResponse(
id=str(user_id), id=str(user_id),
username="demo", username="demo",
display_name="Demo User",
avatar_url=None, avatar_url=None,
bio=None, bio=None,
) )
@@ -94,7 +92,6 @@ def test_patch_me_updates_profile() -> None:
profile = ProfileResponse( profile = ProfileResponse(
id=str(user_id), id=str(user_id),
username="demo", username="demo",
display_name="Demo User",
avatar_url=None, avatar_url=None,
bio=None, bio=None,
) )
@@ -107,11 +104,11 @@ def test_patch_me_updates_profile() -> None:
try: try:
response = client.patch( response = client.patch(
"/api/v1/profile/me", "/api/v1/profile/me",
json={"display_name": "Updated"}, json={"username": "updated"},
) )
assert response.status_code == 200 assert response.status_code == 200
body = response.json() body = response.json()
assert body["display_name"] == "Updated" assert body["username"] == "updated"
finally: finally:
app.dependency_overrides = {} app.dependency_overrides = {}
@@ -120,7 +117,6 @@ def test_get_profile_by_username() -> None:
profile = ProfileResponse( profile = ProfileResponse(
id="00000000-0000-0000-0000-000000000001", id="00000000-0000-0000-0000-000000000001",
username="demo", username="demo",
display_name="Demo User",
avatar_url=None, avatar_url=None,
bio=None, bio=None,
) )
@@ -142,7 +138,6 @@ def test_profile_not_found_returns_problem_details() -> None:
profile = ProfileResponse( profile = ProfileResponse(
id="00000000-0000-0000-0000-000000000001", id="00000000-0000-0000-0000-000000000001",
username="demo", username="demo",
display_name="Demo User",
avatar_url=None, avatar_url=None,
bio=None, bio=None,
) )
@@ -167,7 +162,6 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
profile = ProfileResponse( profile = ProfileResponse(
id=str(user_id), id=str(user_id),
username="demo", username="demo",
display_name="Demo User",
avatar_url=None, avatar_url=None,
bio=None, bio=None,
) )
@@ -186,3 +180,25 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
assert body["status"] == 422 assert body["status"] == 422
finally: finally:
app.dependency_overrides = {} app.dependency_overrides = {}
def test_patch_me_rejects_display_name_field() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
profile = ProfileResponse(
id=str(user_id),
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_profile_service] = _override_profile_service(
FakeProfileService(profile)
)
app.dependency_overrides[get_current_user] = _override_current_user(user_id)
client = TestClient(app)
try:
response = client.patch("/api/v1/profile/me", json={"display_name": "x"})
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
finally:
app.dependency_overrides = {}
@@ -4,7 +4,7 @@ from datetime import datetime, timezone
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import pytest import pytest
from sqlalchemy import String from sqlalchemy import Column, String, Table
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@@ -21,10 +21,19 @@ class Widget(SoftDeleteMixin, Base):
@pytest.fixture @pytest.fixture
async def db_engine(): async def db_engine():
auth_users = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
yield engine yield engine
Base.metadata.remove(auth_users)
await engine.dispose() await engine.dispose()
@@ -0,0 +1,16 @@
from __future__ import annotations
from pathlib import Path
def test_drop_display_name_migration_exists_and_uses_username_metadata() -> None:
versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
migration = (
versions_dir / "20260224_drop_profile_display_name_and_trigger_username.py"
)
assert migration.exists()
content = migration.read_text(encoding="utf-8")
assert "DROP COLUMN" in content and "display_name" in content
assert "raw_user_meta_data->>'username'" in content
@@ -3,7 +3,7 @@ from __future__ import annotations
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from sqlalchemy import select from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base from core.db.base import Base
@@ -13,13 +13,22 @@ from models.profile import Profile
@pytest.fixture @pytest.fixture
async def db_engine(): async def db_engine():
"""Create in-memory SQLite engine for testing.""" """Create in-memory SQLite engine for testing."""
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine( engine = create_async_engine(
"sqlite+aiosqlite:///:memory:", "sqlite+aiosqlite:///:memory:",
echo=False, echo=False,
) )
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
yield engine yield engine
Base.metadata.remove(users_table)
await engine.dispose() await engine.dispose()
@@ -43,7 +52,6 @@ async def test_profile_model_create(db_session: AsyncSession) -> None:
profile = Profile( profile = Profile(
id=profile_id, id=profile_id,
username="testuser", username="testuser",
display_name="Test User",
) )
db_session.add(profile) db_session.add(profile)
await db_session.commit() await db_session.commit()
@@ -51,7 +59,6 @@ async def test_profile_model_create(db_session: AsyncSession) -> None:
assert profile.id == profile_id assert profile.id == profile_id
assert profile.username == "testuser" assert profile.username == "testuser"
assert profile.display_name == "Test User"
assert profile.created_at is not None assert profile.created_at is not None
assert profile.updated_at is not None assert profile.updated_at is not None
assert profile.deleted_at is None assert profile.deleted_at is None
@@ -64,7 +71,6 @@ async def test_profile_model_get_by_id(db_session: AsyncSession) -> None:
profile = Profile( profile = Profile(
id=profile_id, id=profile_id,
username="testuser", username="testuser",
display_name="Test User",
) )
db_session.add(profile) db_session.add(profile)
await db_session.commit() await db_session.commit()
@@ -80,7 +86,6 @@ async def test_profile_model_get_by_username(db_session: AsyncSession) -> None:
profile = Profile( profile = Profile(
id=uuid4(), id=uuid4(),
username="testuser", username="testuser",
display_name="Test User",
) )
db_session.add(profile) db_session.add(profile)
await db_session.commit() await db_session.commit()
@@ -99,16 +104,31 @@ async def test_profile_model_update(db_session: AsyncSession) -> None:
profile = Profile( profile = Profile(
id=uuid4(), id=uuid4(),
username="testuser", username="testuser",
display_name="Test User",
bio="Old bio", bio="Old bio",
) )
db_session.add(profile) db_session.add(profile)
await db_session.commit() await db_session.commit()
profile.display_name = "Updated User"
profile.bio = "New bio" profile.bio = "New bio"
await db_session.commit() await db_session.commit()
await db_session.refresh(profile) await db_session.refresh(profile)
assert profile.display_name == "Updated User"
assert profile.bio == "New bio" assert profile.bio == "New bio"
@pytest.mark.asyncio
async def test_profile_model_allows_duplicate_usernames(
db_session: AsyncSession,
) -> None:
first = Profile(id=uuid4(), username="same_name")
second = Profile(id=uuid4(), username="same_name")
db_session.add(first)
db_session.add(second)
await db_session.commit()
result = await db_session.execute(
select(Profile).where(Profile.username == "same_name")
)
found = result.scalars().all()
assert len(found) == 2
@@ -14,7 +14,14 @@ from v1.auth.schemas import (
def test_signup_requires_valid_email() -> None: def test_signup_requires_valid_email() -> None:
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
SignupRequest(email="not-an-email", password="secret123") SignupRequest(username="demo", email="not-an-email", password="secret123")
def test_signup_requires_username() -> None:
with pytest.raises(ValidationError):
SignupRequest.model_validate(
{"email": "user@example.com", "password": "secret123"}
)
def test_login_requires_valid_email() -> None: def test_login_requires_valid_email() -> None:
@@ -2,8 +2,10 @@ from __future__ import annotations
import pytest import pytest
import v1.auth.gateway as auth_gateway_module
from v1.auth.schemas import ( from v1.auth.schemas import (
AuthTokenResponse, AuthTokenResponse,
AuthUserByEmailResponse,
AuthUser, AuthUser,
LoginRequest, LoginRequest,
RefreshRequest, RefreshRequest,
@@ -28,6 +30,14 @@ class FakeGateway(AuthServiceGateway):
async def logout(self, refresh_token: str | None) -> None: async def logout(self, refresh_token: str | None) -> None:
return None return None
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
return AuthUserByEmailResponse(
id="user-1",
email=email,
created_at="2026-02-24T00:00:00Z",
email_confirmed_at=None,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_signup_maps_response() -> None: async def test_signup_maps_response() -> None:
@@ -42,7 +52,7 @@ async def test_signup_maps_response() -> None:
service = AuthService(gateway=FakeGateway(token_response)) service = AuthService(gateway=FakeGateway(token_response))
result = await service.signup( result = await service.signup(
SignupRequest(email="user@example.com", password="secret123") SignupRequest(username="demo", email="user@example.com", password="secret123")
) )
assert result.access_token == "access" assert result.access_token == "access"
@@ -66,9 +76,72 @@ class LogoutAssertingGateway(AuthServiceGateway):
async def logout(self, refresh_token: str | None) -> None: async def logout(self, refresh_token: str | None) -> None:
assert refresh_token == self._expected_refresh_token assert refresh_token == self._expected_refresh_token
async def get_user_by_email(self, email: str) -> AuthUserByEmailResponse:
raise NotImplementedError
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_logout_forwards_refresh_token() -> None: async def test_logout_forwards_refresh_token() -> None:
service = AuthService(gateway=LogoutAssertingGateway("refresh-token")) service = AuthService(gateway=LogoutAssertingGateway("refresh-token"))
await service.logout("refresh-token") await service.logout("refresh-token")
@pytest.mark.asyncio
async def test_get_user_by_email_forwards_to_gateway() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
service = AuthService(gateway=FakeGateway(token_response))
result = await service.get_user_by_email("user@example.com")
assert result.email == "user@example.com"
@pytest.mark.asyncio
async def test_supabase_signup_passes_username_in_metadata(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured_payload: dict[str, object] = {}
class FakeSupabaseAuth:
def sign_up(self, payload: dict[str, object]) -> object:
captured_payload.update(payload)
class _User:
id = "user-1"
email = "user@example.com"
class _Session:
access_token = "access"
refresh_token = "refresh"
expires_in = 3600
token_type = "bearer"
class _Response:
user = _User()
session = _Session()
return _Response()
class FakeClient:
auth = FakeSupabaseAuth()
monkeypatch.setattr(auth_gateway_module, "create_client", lambda *_: FakeClient())
gateway = auth_gateway_module.SupabaseAuthGateway()
await gateway.signup(
SignupRequest(
username="demo",
email="user@example.com",
password="secret123",
)
)
assert captured_payload["data"] == {"username": "demo"}
@@ -16,7 +16,6 @@ from v1.profile.service import ProfileService
def _create_mock_profile( def _create_mock_profile(
user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"), user_id: UUID = UUID("00000000-0000-0000-0000-000000000001"),
username: str = "demo", username: str = "demo",
display_name: str | None = "Demo User",
avatar_url: str | None = None, avatar_url: str | None = None,
bio: str | None = None, bio: str | None = None,
) -> Profile: ) -> Profile:
@@ -24,7 +23,6 @@ def _create_mock_profile(
profile = MagicMock(spec=Profile) profile = MagicMock(spec=Profile)
profile.id = user_id profile.id = user_id
profile.username = username profile.username = username
profile.display_name = display_name
profile.avatar_url = avatar_url profile.avatar_url = avatar_url
profile.bio = bio profile.bio = bio
return profile return profile
@@ -115,9 +113,9 @@ async def test_update_me_updates_fields(mock_session: AsyncMock) -> None:
current_user=user, current_user=user,
) )
result = await service.update_me(ProfileUpdateRequest(display_name="Updated")) result = await service.update_me(ProfileUpdateRequest(username="updated"))
assert result.display_name == "Updated" assert result.username == "updated"
mock_session.commit.assert_awaited_once() mock_session.commit.assert_awaited_once()
@@ -134,7 +132,7 @@ async def test_update_me_no_fields_raises_400(mock_session: AsyncMock) -> None:
# Create a request with all None values by bypassing validation # Create a request with all None values by bypassing validation
update = MagicMock(spec=ProfileUpdateRequest) update = MagicMock(spec=ProfileUpdateRequest)
update.display_name = None update.username = None
update.avatar_url = None update.avatar_url = None
update.bio = None update.bio = None
@@ -10,7 +10,6 @@ def test_profile_response_maps_fields() -> None:
response = ProfileResponse( response = ProfileResponse(
id="user-1", id="user-1",
username="demo", username="demo",
display_name="Demo User",
avatar_url=None, avatar_url=None,
bio=None, bio=None,
) )
@@ -56,6 +55,11 @@ def test_profile_update_rejects_data_url() -> None:
def test_profile_update_accepts_none_avatar_url_with_other_field() -> None: def test_profile_update_accepts_none_avatar_url_with_other_field() -> None:
request = ProfileUpdateRequest(display_name="Test", avatar_url=None) request = ProfileUpdateRequest(username="tester", avatar_url=None)
assert request.avatar_url is None assert request.avatar_url is None
assert request.display_name == "Test" assert request.username == "tester"
def test_profile_update_rejects_display_name_field() -> None:
with pytest.raises(ValidationError):
ProfileUpdateRequest.model_validate({"display_name": "legacy"})
+53 -4
View File
@@ -28,18 +28,43 @@ tmux attach -t social-dev
docker compose --env-file .env -f infra/docker/docker-compose.yml up -d docker compose --env-file .env -f infra/docker/docker-compose.yml up -d
# 2. 运行迁移和初始化 # 2. 运行迁移和初始化
docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job run --rm init-job docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job run --rm --build init-job
# 3. 一键执行应用层启动(bootstrap + web + workers # 3. 一键执行应用层启动(bootstrap + web + workers
bash infra/scripts/dev-app-up.sh bash infra/scripts/dev-app-up.sh
``` ```
### 生产环境迁移防遗漏(必读)
- 生产发布前必须先通过 bootstrap gate,再启动业务进程;禁止绕过 gate 直接起服务。
- 使用容器执行迁移时必须带 `--build`,确保最新 Alembic 迁移已进入镜像。
- 建议在迁移后做一次版本核对,确认已到预期 head。
```bash
# 1) 先执行 bootstrap gate
make runtime-bootstrap-gate
# 2) 如采用 init-job 单跑,必须带 --build
docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job run --rm --build init-job
# 3) 核对 Alembic 版本
docker compose --env-file .env -f infra/docker/docker-compose.yml exec -T db \
psql -U postgres -d postgres -c "SELECT version_num FROM public.alembic_version;"
```
### 本地 CLI (开发调试) ### 本地 CLI (开发调试)
> 适用于本地开发调试,不依赖 Docker。 > 适用于本地开发调试,不依赖 Docker。
> 开发调试阶段推荐直接使用本地一次性迁移脚本,不通过 Docker 触发数据库迁移,避免反复重建镜像。
```bash ```bash
# 初始化/迁移 # 推荐:一次性迁移(开发调试)
PYTHONPATH=backend/src uv run python -m core.runtime.cli migrate
# 需要初始化数据时再执行
PYTHONPATH=backend/src uv run python -m core.runtime.cli init-data
# 或一键执行(migrate + init-data
PYTHONPATH=backend/src uv run python -m core.runtime.cli bootstrap PYTHONPATH=backend/src uv run python -m core.runtime.cli bootstrap
# 启动 Web (gunicorn) # 启动 Web (gunicorn)
@@ -102,7 +127,7 @@ tmux kill-session -t social-dev
curl -fsS http://127.0.0.1:${SOCIAL_SUPABASE__KONG_HTTP_PORT:-8000}/health curl -fsS http://127.0.0.1:${SOCIAL_SUPABASE__KONG_HTTP_PORT:-8000}/health
# 数据库迁移与初始化 # 数据库迁移与初始化
docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job run --rm init-job docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job run --rm --build init-job
``` ```
## 查看服务状态 ## 查看服务状态
@@ -112,7 +137,30 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml ps
docker compose --env-file .env -f infra/docker/docker-compose.yml logs -f db docker compose --env-file .env -f infra/docker/docker-compose.yml logs -f db
# init-job 为一次性任务(run --rm),如需查看日志请重跑: # init-job 为一次性任务(run --rm),如需查看日志请重跑:
docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job run --rm init-job docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job run --rm --build init-job
```
## Auth/Profile 验证
```bash
# signup: username + email + password
curl -sS -X POST http://127.0.0.1:8000/api/v1/auth/signup \
-H 'Content-Type: application/json' \
-d '{"username":"demo","email":"demo@example.com","password":"secret123"}'
# login: email + password
curl -sS -X POST http://127.0.0.1:8000/api/v1/auth/login \
-H 'Content-Type: application/json' \
-d '{"email":"demo@example.com","password":"secret123"}'
# by-email lookup
curl -sS "http://127.0.0.1:8000/api/v1/auth/users/by-email?email=demo@example.com"
# patch profile: username/avatar_url/bio only
curl -sS -X PATCH http://127.0.0.1:8000/api/v1/profile/me \
-H 'Content-Type: application/json' \
-H "Authorization: Bearer <access_token>" \
-d '{"username":"demo2","bio":"hello"}'
``` ```
--- ---
@@ -125,3 +173,4 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml --profile job
| 2026-02-24 | 清理配置:合并 AppSettings 到 WebSettings,删除 Worker 旧配置 (enabled_queues/queues),统一使用 SOCIAL_WEB__GUNICORN__* 命名 | | 2026-02-24 | 清理配置:合并 AppSettings 到 WebSettings,删除 Worker 旧配置 (enabled_queues/queues),统一使用 SOCIAL_WEB__GUNICORN__* 命名 |
| 2026-02-24 | 开发阶段 compose 暂不编排 web/worker,仅保留 redis/supabase 与 init-job | | 2026-02-24 | 开发阶段 compose 暂不编排 web/worker,仅保留 redis/supabase 与 init-job |
| 2026-02-24 | 新增 dev-app-up 脚本:手动基础设施后,一键 bootstrap + tmux 拉起 web/worker | | 2026-02-24 | 新增 dev-app-up 脚本:手动基础设施后,一键 bootstrap + tmux 拉起 web/worker |
| 2026-02-25 | 补充迁移防遗漏规则:容器迁移命令统一追加 --build;开发调试优先使用本地 CLI 一次性迁移脚本 |