feat: 实现密码重置功能与用户搜索API,优化注册登录流程

- 新增忘记密码页面与重置密码确认流程(前端+后端)
- 修复注册验证码页登录跳转路由
- 新增用户搜索API(按邮箱查询)
- 简化infra脚本,统一为app.sh
- 补充密码重置与用户API测试覆盖
- 更新runtime文档与AGENTS配置
This commit is contained in:
qzl
2026-02-27 15:22:42 +08:00
parent 0d4811fee5
commit e4e995854d
37 changed files with 2101 additions and 222 deletions
+80 -14
View File
@@ -1,3 +1,7 @@
# Backend Development Rules
This document defines Python/FastAPI backend development constraints.
## Python Environment
**MUST use uv for dependency management and virtual environment execution.**
@@ -43,11 +47,10 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
- Tests can set env vars via `monkeypatch.setenv`, and should read values via `Settings()` unless the test is explicitly validating env plumbing
- Canonical principle: one source of truth per setting; no duplicate/derived env vars in backend code
## TDD First Policy
**Principle: tests before implementation.**
## TDD Workflow
### Coverage Requirements
- Minimum coverage: 80%
- Required test types:
- Unit: isolated functions, utilities, components
@@ -55,12 +58,14 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
- E2E: critical user flows (Playwright)
### Limited Exceptions
- Docs-only changes (README, comments, formatting) may skip integration/E2E
- Non-runtime config changes may skip E2E if no behavior changes
- Any runtime code change requires unit + integration + E2E
- If an exception is used, record the reason in the PR/test notes
### Mandatory TDD Workflow
1. Write tests (RED) - they must fail
2. Run tests - confirm failure
3. Implement minimal code (GREEN) - only to pass
@@ -69,19 +74,80 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
6. Verify coverage - must be 80%+
### Enforcement
- Must use the `tdd-guide` agent for new features
- Do not write implementation before tests
- Do not lower coverage requirements
- Must include unit, integration, and E2E tests
## Code Style
### Immutability
**ALWAYS create new objects, NEVER mutate.**
```python
# WRONG: Mutation
def update_user(user, name):
user["name"] = name
return user
# CORRECT: Immutability
def update_user(user, name):
return {**user, "name": name}
```
### File Organization
- Many small files over few large files
- 200-400 lines typical, 800 max per file
- Extract utilities from large components
### Error Handling
Always handle errors comprehensively:
```python
try:
result = risky_operation()
return result
except Exception as exc:
logger.exception("Operation failed")
raise RuntimeError("Detailed user-friendly message") from exc
```
## Security
### Mandatory Security Checks
Before ANY commit:
- [ ] No hardcoded secrets (API keys, passwords, tokens)
- [ ] All user inputs validated (use Pydantic)
- [ ] SQL injection prevention (parameterized queries)
- [ ] Authentication/authorization verified
### Secret Management
```python
# NEVER: Hardcoded secrets
api_key = "sk-proj-xxxxx"
# ALWAYS: Environment variables
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY not configured")
```
## Database Development Rules
### Core Principle
### Architecture
- **Supabase**: authentication (JWT source of truth)
- **Backend**: business authorization (service layer)
- **SQLAlchemy ORM**: data access layer (async + asyncpg, service_role connection)
### Architecture
### Code Organization
Use `schemas / repository / service` pattern:
- `schemas.py` — Pydantic models
- `repository.py` — CRUD only, no auth, no commit (only flush), must receive session (never create session/engine)
@@ -89,6 +155,7 @@ Use `schemas / repository / service` pattern:
- `dependencies.py` — DI (`get_db`, `get_current_user`)
### Auth & Data Access
- Backend must verify JWT signature and expiration (not just decode)
- Extract `user_id` from JWT `sub` claim
- Backend connects with **service_role** (bypasses RLS)
@@ -98,31 +165,28 @@ Use `schemas / repository / service` pattern:
- Prohibit calling Supabase Admin API (service_role key) from repository/service layers
### Migrations
- **Alembic is the single source of truth** for schema migrations
- ORM model changes → `alembic revision --autogenerate`
- Raw SQL (policies, triggers, functions) → `op.execute()`
- Migrations must be reversible; no reliance on generated IDs
### Enum Storage Convention
**Store enum names (strings), not integer values.**
- Use `VARCHAR(20)` + `CHECK` constraint in database
- Use Python `Enum` class with `str` base in code
- Benefits: debugging readability, easy to add new values without data migration, ORM-friendly
```python
# Correct
class AgentType(str, Enum):
INTENT_RECOGNITION = "INTENT_RECOGNITION"
TASK_EXECUTION = "TASK_EXECUTION"
RESULT_REPORTING = "RESULT_REPORTING"
# Migration
ALTER TABLE user_agents ADD CONSTRAINT chk_agent_type
CHECK (agent_type IN ('INTENT_RECOGNITION', 'TASK_EXECUTION', 'RESULT_REPORTING'));
```
### RLS Guidance
### RLS Policy
- Backend does not rely on RLS for correctness (uses service_role), but RLS is mandatory as a defensive boundary for tables in PostgREST-exposed schemas.
- **Mandatory default**: any new business table in `public` must enable RLS in the same Alembic migration.
- The same migration must create policies covering `SELECT/INSERT/UPDATE/DELETE` (minimum requirement).
@@ -130,11 +194,13 @@ ALTER TABLE user_agents ADD CONSTRAINT chk_agent_type
- `alembic_version` must not be exposed to `anon` or `authenticated`.
#### Exemption Rule (strict)
- Exemptions are allowed only when a new `public` table is guaranteed not to be exposed to PostgREST clients.
- Exemptions must be explicit in the migration file with rationale and verification notes (why safe, how exposure is prevented).
- Exemptions must be explicit in the migration file with rationale and verification notes.
- If exposure is uncertain, do not exempt: enable defensive RLS by default.
#### Migration Acceptance Checklist (RLS)
#### Migration Checklist
- [ ] New `public` business table has `ALTER TABLE ... ENABLE ROW LEVEL SECURITY` in migration
- [ ] Policies for `SELECT/INSERT/UPDATE/DELETE` are present in migration
- [ ] Policy target roles are explicit (`anon`, `authenticated`, or both)
@@ -40,7 +40,6 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_index("ix_llm_factory_name", "llm_factory", ["name"], unique=True)
_enable_rls("llm_factory")
op.create_table(
@@ -65,7 +64,6 @@ def upgrade() -> None:
sa.UniqueConstraint("model_code"),
)
op.create_index("ix_llms_factory_id", "llms", ["factory_id"], unique=False)
op.create_index("ix_llms_model_code", "llms", ["model_code"], unique=True)
op.create_foreign_key(
"fk_llms_factory_id",
"llms",
+61
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Mapping
from typing import Any, cast
from fastapi import HTTPException
@@ -10,6 +11,8 @@ from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
@@ -150,6 +153,64 @@ class SupabaseAuthGateway(AuthServiceGateway):
),
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
try:
reset_email = cast(Any, self._client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
await asyncio.to_thread(reset_email, email, options=options)
else:
await asyncio.to_thread(reset_email, email)
except AuthError as exc:
logger.warning(
"Password reset request failed",
error_type=type(exc).__name__,
)
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
user_id = str(getattr(user, "id", "")) if user is not None else ""
if session is None or not user_id:
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
self._admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)
except AuthError as exc:
logger.warning(
"Password reset confirm failed", error_type=type(exc).__name__
)
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
) from exc
def _coerce_reset_email(value: object) -> str:
if isinstance(value, str):
return value
if isinstance(value, Mapping):
nested = value.get("email") or value.get("value")
if isinstance(nested, str):
return nested
raise HTTPException(status_code=422, detail="Invalid email")
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
session = getattr(response, "session", None)
+32
View File
@@ -10,6 +10,8 @@ from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
@@ -123,3 +125,33 @@ async def get_user_by_email(
if current_user.role != "service_role" and current_user.email != email:
raise HTTPException(status_code=403, detail="Forbidden")
return await service.get_user_by_email(email)
@router.post("/password-reset", status_code=204)
async def request_password_reset(
payload: PasswordResetRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_request",
identifier=payload.email,
limit=5,
window_seconds=60,
)
await service.request_password_reset(payload)
return Response(status_code=204)
@router.post("/password-reset/confirm", status_code=204)
async def confirm_password_reset(
payload: PasswordResetConfirmRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_confirm",
identifier=payload.email,
limit=10,
window_seconds=600,
)
await service.confirm_password_reset(payload)
return Response(status_code=204)
+6
View File
@@ -61,5 +61,11 @@ class PasswordResetRequest(BaseModel):
redirect_to: str | None = None
class PasswordResetConfirmRequest(BaseModel):
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str = Field(min_length=6)
class PasswordResetResponse(BaseModel):
message: str = "Password reset email sent"
+18
View File
@@ -3,6 +3,8 @@ from __future__ import annotations
from typing import Protocol
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
@@ -40,6 +42,14 @@ class AuthServiceGateway(Protocol):
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
class AuthService:
_gateway: AuthServiceGateway
@@ -71,3 +81,11 @@ class AuthService:
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
return await self._gateway.get_user_by_email(email)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
await self._gateway.request_password_reset(request)
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
await self._gateway.confirm_password_reset(request)
+18 -2
View File
@@ -11,11 +11,21 @@ from core.auth.models import CurrentUser
from core.config.settings import config
from core.db import get_db
from core.logging import get_logger
from v1.auth.gateway import SupabaseAuthGateway
from v1.users.repository import SQLAlchemyUserRepository
from v1.users.service import UserService
from v1.users.service import AuthLookupAdapter, UserService
logger = get_logger("v1.users.dependencies")
_auth_gateway: SupabaseAuthGateway | None = None
def get_auth_gateway() -> SupabaseAuthGateway:
global _auth_gateway
if _auth_gateway is None:
_auth_gateway = SupabaseAuthGateway()
return _auth_gateway
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
if not authorization:
@@ -98,4 +108,10 @@ def get_user_service(
user: Annotated[CurrentUser, Depends(get_current_user)],
) -> UserService:
repository = SQLAlchemyUserRepository(session)
return UserService(repository=repository, session=session, current_user=user)
auth_gateway = AuthLookupAdapter(get_auth_gateway())
return UserService(
repository=repository,
session=session,
current_user=user,
auth_gateway=auth_gateway,
)
+25 -2
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import select, or_
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
@@ -33,6 +33,10 @@ class UserRepository(Protocol):
"""Update user by user ID. Returns updated user or None if not found."""
...
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
"""Search users by username (ilike) or email (exact match)."""
...
class SQLAlchemyUserRepository(BaseRepository[Profile]):
"""SQLAlchemy implementation of UserRepository.
@@ -77,5 +81,24 @@ class SQLAlchemyUserRepository(BaseRepository[Profile]):
try:
return await self.update_by_id(user_id, update_data)
except SQLAlchemyError:
logger.exception("User update failed", user_id=str(user_id))
logger.exception("User update failed", user=str(user_id))
raise
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
try:
stmt = (
select(Profile)
.where(Profile.deleted_at.is_(None))
.where(
or_(
Profile.username.ilike(f"%{query}%"),
)
)
.order_by(Profile.created_at.asc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
except SQLAlchemyError:
logger.exception("User search failed", query=query)
raise
+7 -9
View File
@@ -2,10 +2,10 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, Path
from fastapi import APIRouter, Depends
from v1.users.dependencies import get_user_service
from v1.users.schemas import UserResponse, UserUpdateRequest
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
from v1.users.service import UserService
@@ -27,11 +27,9 @@ async def update_me(
return await service.update_me(payload)
@router.get("/{username}", response_model=UserResponse)
async def get_by_username(
username: Annotated[
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
],
@router.post("/search", response_model=list[UserResponse])
async def search_users(
payload: UserSearchRequest,
service: Annotated[UserService, Depends(get_user_service)],
) -> UserResponse:
return await service.get_by_username(username)
) -> list[UserResponse]:
return await service.search_users(payload)
+11
View File
@@ -19,6 +19,17 @@ class UserResponse(BaseModel):
bio: str | None = None
class UserSearchRequest(BaseModel):
query: str = Field(min_length=1, max_length=100)
class UserSearchResult(BaseModel):
id: str
username: str
avatar_url: str | None = None
bio: str | None = None
class UserUpdateRequest(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
+80 -2
View File
@@ -1,6 +1,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import re
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
@@ -9,13 +11,37 @@ from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from v1.users.repository import UserRepository
from v1.users.schemas import UserResponse, UserUpdateRequest
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from v1.auth.schemas import UserByEmailResponse
logger = get_logger("v1.users.service")
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
class AuthLookupGateway(Protocol):
async def get_user_id_by_email(self, email: str) -> str | None: ...
class AuthByEmailGateway(Protocol):
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
class AuthLookupAdapter:
def __init__(self, gateway: AuthByEmailGateway) -> None:
self._gateway = gateway
async def get_user_id_by_email(self, email: str) -> str | None:
try:
response = await self._gateway.get_user_by_email(email)
return response.id
except HTTPException:
return None
class UserService(BaseService):
"""User service handling business logic and transactions.
@@ -28,16 +54,19 @@ class UserService(BaseService):
_repository: UserRepository
_session: AsyncSession
_auth_gateway: AuthLookupGateway | None
def __init__(
self,
repository: UserRepository,
session: AsyncSession,
current_user: CurrentUser | None,
auth_gateway: AuthLookupGateway | None = None,
) -> None:
super().__init__(current_user=current_user)
self._repository = repository
self._session = session
self._auth_gateway = auth_gateway
async def get_me(self) -> UserResponse:
user_id = self.require_user_id()
@@ -101,3 +130,52 @@ class UserService(BaseService):
avatar_url=user.avatar_url,
bio=user.bio,
)
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
query = request.query.strip()
if _EMAIL_PATTERN.match(query):
return await self._search_by_email(query)
return await self._search_by_username(query)
async def _search_by_email(self, email: str) -> list[UserResponse]:
if self._auth_gateway is None:
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
if user_id_str is None:
return []
try:
user = await self._repository.get_by_user_id(UUID(user_id_str))
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="User store unavailable")
if user is None:
return []
return [
UserResponse(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
)
]
async def _search_by_username(self, query: str) -> list[UserResponse]:
try:
users = await self._repository.search_users(query, limit=20)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="User store unavailable")
return [
UserResponse(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
)
for user in users
]
@@ -14,6 +14,8 @@ from v1.users.dependencies import get_current_user
from v1.auth.rate_limit import reset_rate_limit_state
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
@@ -71,6 +73,18 @@ class FakeAuthService(AuthService):
email_confirmed_at=None,
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
return None
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
if request.token == "000000":
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
)
return None
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
def _get_service() -> AuthService:
@@ -665,3 +679,116 @@ def test_get_user_by_email_forbidden_when_querying_other_user() -> None:
assert body["detail"] == "Forbidden"
finally:
app.dependency_overrides = {}
def test_password_reset_request_returns_204() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset",
json={"email": "user@example.com"},
)
assert response.status_code == 204
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_returns_204() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"email": "user@example.com",
"token": "123456",
"new_password": "newpassword123",
},
)
assert response.status_code == 204
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_invalid_token_returns_401() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"email": "user@example.com",
"token": "000000",
"new_password": "newpassword123",
},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_weak_password_returns_422() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"email": "user@example.com",
"token": "123456",
"new_password": "123",
},
)
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
finally:
app.dependency_overrides = {}
+77 -45
View File
@@ -9,7 +9,7 @@ from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.users.dependencies import get_current_user, get_user_service
from v1.users.schemas import UserResponse, UserUpdateRequest
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
from v1.users.service import UserService
@@ -18,6 +18,10 @@ class FakeUserService:
def __init__(self, user: UserResponse) -> None:
self._user = user
self._search_results: list[UserResponse] = []
def set_search_results(self, results: list[UserResponse]) -> None:
self._search_results = results
async def get_me(self) -> UserResponse:
if self._user.id is None:
@@ -45,6 +49,11 @@ class FakeUserService:
raise HTTPException(status_code=404, detail="User not found")
return self._user
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
if request.query:
return self._search_results if self._search_results else [self._user]
return []
def _override_user_service(
service: FakeUserService,
@@ -111,50 +120,6 @@ def test_patch_me_updates_user() -> None:
app.dependency_overrides = {}
def test_get_user_by_username() -> None:
user = UserResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.get("/api/v1/users/demo")
assert response.status_code == 200
body = response.json()
assert body["username"] == "demo"
finally:
app.dependency_overrides = {}
def test_user_not_found_returns_problem_details() -> None:
user = UserResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.get("/api/v1/users/unknown")
assert response.status_code == 404
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Not Found"
assert body["status"] == 404
finally:
app.dependency_overrides = {}
def test_patch_me_validation_error_returns_problem_details() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
@@ -178,3 +143,70 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
assert body["status"] == 422
finally:
app.dependency_overrides = {}
def test_search_users_returns_list() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
id=str(user_id),
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/users/search",
json={"query": "demo"},
)
assert response.status_code == 200
body = response.json()
assert isinstance(body, list)
finally:
app.dependency_overrides = {}
def test_search_users_empty_query_returns_422() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
id=str(user_id),
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/users/search",
json={"query": ""},
)
assert response.status_code == 422
finally:
app.dependency_overrides = {}
def test_get_user_by_username_returns_404() -> None:
user = UserResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.get("/api/v1/users/demo")
assert response.status_code == 404
finally:
app.dependency_overrides = {}
@@ -0,0 +1,155 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
class TestSupabaseAuthGateway:
@pytest.fixture
def gateway(self) -> SupabaseAuthGateway:
with patch("v1.auth.gateway.create_client") as mock_create:
mock_client = MagicMock()
mock_admin_client = MagicMock()
mock_create.side_effect = [mock_client, mock_admin_client]
return SupabaseAuthGateway()
@pytest.mark.asyncio
async def test_request_password_reset_calls_email_with_string(
self, gateway: SupabaseAuthGateway
) -> None:
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
await gateway.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_with_redirect(
self, gateway: SupabaseAuthGateway
) -> None:
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
await gateway.request_password_reset(request)
mock_reset_email.assert_called_once_with(
"test@example.com",
options={"redirect_to": "http://localhost:3000/reset-password"},
)
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: SupabaseAuthGateway
) -> None:
from supabase import AuthError
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
result = await gateway.request_password_reset(request)
mock_reset_email.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_request_password_reset_extracts_email_from_mapping(
self, gateway: SupabaseAuthGateway
) -> None:
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest.model_construct(
email={"email": "test@example.com"},
redirect_to=None,
)
await gateway.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_rejects_invalid_email_shape(
self, gateway: SupabaseAuthGateway
) -> None:
request = PasswordResetRequest.model_construct(
email={"unexpected": "value"},
redirect_to=None,
)
with pytest.raises(HTTPException) as exc_info:
await gateway.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid email"
@pytest.mark.asyncio
async def test_confirm_password_reset_updates_password_by_user_id(
self, gateway: SupabaseAuthGateway
) -> None:
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id="user-1"),
)
mock_verify_otp = MagicMock(return_value=verify_response)
gateway._client.auth.verify_otp = mock_verify_otp
mock_update_user_by_id = MagicMock()
gateway._admin_client.auth.admin = SimpleNamespace(
update_user_by_id=mock_update_user_by_id
)
request = PasswordResetConfirmRequest(
email="test@example.com",
token="123456",
new_password="newpassword123",
)
await gateway.confirm_password_reset(request)
mock_verify_otp.assert_called_once_with(
{
"type": "recovery",
"email": "test@example.com",
"token": "123456",
}
)
mock_update_user_by_id.assert_called_once_with(
"user-1",
{"password": "newpassword123"},
)
@pytest.mark.asyncio
async def test_confirm_password_reset_raises_when_user_id_missing(
self, gateway: SupabaseAuthGateway
) -> None:
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id=""),
)
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
request = PasswordResetConfirmRequest(
email="test@example.com",
token="123456",
new_password="newpassword123",
)
with pytest.raises(HTTPException) as exc_info:
await gateway.confirm_password_reset(request)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"