feat: 实现密码重置功能与用户搜索API,优化注册登录流程
- 新增忘记密码页面与重置密码确认流程(前端+后端) - 修复注册验证码页登录跳转路由 - 新增用户搜索API(按邮箱查询) - 简化infra脚本,统一为app.sh - 补充密码重置与用户API测试覆盖 - 更新runtime文档与AGENTS配置
This commit is contained in:
+80
-14
@@ -1,3 +1,7 @@
|
||||
# Backend Development Rules
|
||||
|
||||
This document defines Python/FastAPI backend development constraints.
|
||||
|
||||
## Python Environment
|
||||
|
||||
**MUST use uv for dependency management and virtual environment execution.**
|
||||
@@ -43,11 +47,10 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
|
||||
- Tests can set env vars via `monkeypatch.setenv`, and should read values via `Settings()` unless the test is explicitly validating env plumbing
|
||||
- Canonical principle: one source of truth per setting; no duplicate/derived env vars in backend code
|
||||
|
||||
## TDD First Policy
|
||||
|
||||
**Principle: tests before implementation.**
|
||||
## TDD Workflow
|
||||
|
||||
### Coverage Requirements
|
||||
|
||||
- Minimum coverage: 80%
|
||||
- Required test types:
|
||||
- Unit: isolated functions, utilities, components
|
||||
@@ -55,12 +58,14 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
|
||||
- E2E: critical user flows (Playwright)
|
||||
|
||||
### Limited Exceptions
|
||||
|
||||
- Docs-only changes (README, comments, formatting) may skip integration/E2E
|
||||
- Non-runtime config changes may skip E2E if no behavior changes
|
||||
- Any runtime code change requires unit + integration + E2E
|
||||
- If an exception is used, record the reason in the PR/test notes
|
||||
|
||||
### Mandatory TDD Workflow
|
||||
|
||||
1. Write tests (RED) - they must fail
|
||||
2. Run tests - confirm failure
|
||||
3. Implement minimal code (GREEN) - only to pass
|
||||
@@ -69,19 +74,80 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
|
||||
6. Verify coverage - must be 80%+
|
||||
|
||||
### Enforcement
|
||||
|
||||
- Must use the `tdd-guide` agent for new features
|
||||
- Do not write implementation before tests
|
||||
- Do not lower coverage requirements
|
||||
- Must include unit, integration, and E2E tests
|
||||
|
||||
## Code Style
|
||||
|
||||
### Immutability
|
||||
|
||||
**ALWAYS create new objects, NEVER mutate.**
|
||||
|
||||
```python
|
||||
# WRONG: Mutation
|
||||
def update_user(user, name):
|
||||
user["name"] = name
|
||||
return user
|
||||
|
||||
# CORRECT: Immutability
|
||||
def update_user(user, name):
|
||||
return {**user, "name": name}
|
||||
```
|
||||
|
||||
### File Organization
|
||||
|
||||
- Many small files over few large files
|
||||
- 200-400 lines typical, 800 max per file
|
||||
- Extract utilities from large components
|
||||
|
||||
### Error Handling
|
||||
|
||||
Always handle errors comprehensively:
|
||||
|
||||
```python
|
||||
try:
|
||||
result = risky_operation()
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.exception("Operation failed")
|
||||
raise RuntimeError("Detailed user-friendly message") from exc
|
||||
```
|
||||
|
||||
## Security
|
||||
|
||||
### Mandatory Security Checks
|
||||
|
||||
Before ANY commit:
|
||||
- [ ] No hardcoded secrets (API keys, passwords, tokens)
|
||||
- [ ] All user inputs validated (use Pydantic)
|
||||
- [ ] SQL injection prevention (parameterized queries)
|
||||
- [ ] Authentication/authorization verified
|
||||
|
||||
### Secret Management
|
||||
|
||||
```python
|
||||
# NEVER: Hardcoded secrets
|
||||
api_key = "sk-proj-xxxxx"
|
||||
|
||||
# ALWAYS: Environment variables
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("OPENAI_API_KEY not configured")
|
||||
```
|
||||
|
||||
## Database Development Rules
|
||||
|
||||
### Core Principle
|
||||
### Architecture
|
||||
|
||||
- **Supabase**: authentication (JWT source of truth)
|
||||
- **Backend**: business authorization (service layer)
|
||||
- **SQLAlchemy ORM**: data access layer (async + asyncpg, service_role connection)
|
||||
|
||||
### Architecture
|
||||
### Code Organization
|
||||
|
||||
Use `schemas / repository / service` pattern:
|
||||
- `schemas.py` — Pydantic models
|
||||
- `repository.py` — CRUD only, no auth, no commit (only flush), must receive session (never create session/engine)
|
||||
@@ -89,6 +155,7 @@ Use `schemas / repository / service` pattern:
|
||||
- `dependencies.py` — DI (`get_db`, `get_current_user`)
|
||||
|
||||
### Auth & Data Access
|
||||
|
||||
- Backend must verify JWT signature and expiration (not just decode)
|
||||
- Extract `user_id` from JWT `sub` claim
|
||||
- Backend connects with **service_role** (bypasses RLS)
|
||||
@@ -98,31 +165,28 @@ Use `schemas / repository / service` pattern:
|
||||
- Prohibit calling Supabase Admin API (service_role key) from repository/service layers
|
||||
|
||||
### Migrations
|
||||
|
||||
- **Alembic is the single source of truth** for schema migrations
|
||||
- ORM model changes → `alembic revision --autogenerate`
|
||||
- Raw SQL (policies, triggers, functions) → `op.execute()`
|
||||
- Migrations must be reversible; no reliance on generated IDs
|
||||
|
||||
### Enum Storage Convention
|
||||
|
||||
**Store enum names (strings), not integer values.**
|
||||
|
||||
- Use `VARCHAR(20)` + `CHECK` constraint in database
|
||||
- Use Python `Enum` class with `str` base in code
|
||||
- Benefits: debugging readability, easy to add new values without data migration, ORM-friendly
|
||||
|
||||
```python
|
||||
# Correct
|
||||
class AgentType(str, Enum):
|
||||
INTENT_RECOGNITION = "INTENT_RECOGNITION"
|
||||
TASK_EXECUTION = "TASK_EXECUTION"
|
||||
RESULT_REPORTING = "RESULT_REPORTING"
|
||||
|
||||
# Migration
|
||||
ALTER TABLE user_agents ADD CONSTRAINT chk_agent_type
|
||||
CHECK (agent_type IN ('INTENT_RECOGNITION', 'TASK_EXECUTION', 'RESULT_REPORTING'));
|
||||
```
|
||||
|
||||
### RLS Guidance
|
||||
### RLS Policy
|
||||
|
||||
- Backend does not rely on RLS for correctness (uses service_role), but RLS is mandatory as a defensive boundary for tables in PostgREST-exposed schemas.
|
||||
- **Mandatory default**: any new business table in `public` must enable RLS in the same Alembic migration.
|
||||
- The same migration must create policies covering `SELECT/INSERT/UPDATE/DELETE` (minimum requirement).
|
||||
@@ -130,11 +194,13 @@ ALTER TABLE user_agents ADD CONSTRAINT chk_agent_type
|
||||
- `alembic_version` must not be exposed to `anon` or `authenticated`.
|
||||
|
||||
#### Exemption Rule (strict)
|
||||
|
||||
- Exemptions are allowed only when a new `public` table is guaranteed not to be exposed to PostgREST clients.
|
||||
- Exemptions must be explicit in the migration file with rationale and verification notes (why safe, how exposure is prevented).
|
||||
- Exemptions must be explicit in the migration file with rationale and verification notes.
|
||||
- If exposure is uncertain, do not exempt: enable defensive RLS by default.
|
||||
|
||||
#### Migration Acceptance Checklist (RLS)
|
||||
#### Migration Checklist
|
||||
|
||||
- [ ] New `public` business table has `ALTER TABLE ... ENABLE ROW LEVEL SECURITY` in migration
|
||||
- [ ] Policies for `SELECT/INSERT/UPDATE/DELETE` are present in migration
|
||||
- [ ] Policy target roles are explicit (`anon`, `authenticated`, or both)
|
||||
|
||||
@@ -40,7 +40,6 @@ def upgrade() -> None:
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
op.create_index("ix_llm_factory_name", "llm_factory", ["name"], unique=True)
|
||||
_enable_rls("llm_factory")
|
||||
|
||||
op.create_table(
|
||||
@@ -65,7 +64,6 @@ def upgrade() -> None:
|
||||
sa.UniqueConstraint("model_code"),
|
||||
)
|
||||
op.create_index("ix_llms_factory_id", "llms", ["factory_id"], unique=False)
|
||||
op.create_index("ix_llms_model_code", "llms", ["model_code"], unique=True)
|
||||
op.create_foreign_key(
|
||||
"fk_llms_factory_id",
|
||||
"llms",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -10,6 +11,8 @@ from core.config.settings import SupabaseSettings, config
|
||||
from core.logging import get_logger
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
@@ -150,6 +153,64 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
),
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
try:
|
||||
reset_email = cast(Any, self._client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {"redirect_to": request.redirect_to}
|
||||
await asyncio.to_thread(reset_email, email, options=options)
|
||||
else:
|
||||
await asyncio.to_thread(reset_email, email)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset request failed",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
user_id = str(getattr(user, "id", "")) if user is not None else ""
|
||||
if session is None or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self._admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset confirm failed", error_type=type(exc).__name__
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
) from exc
|
||||
|
||||
|
||||
def _coerce_reset_email(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
nested = value.get("email") or value.get("value")
|
||||
if isinstance(nested, str):
|
||||
return nested
|
||||
|
||||
raise HTTPException(status_code=422, detail="Invalid email")
|
||||
|
||||
|
||||
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
|
||||
session = getattr(response, "session", None)
|
||||
|
||||
@@ -10,6 +10,8 @@ from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
@@ -123,3 +125,33 @@ async def get_user_by_email(
|
||||
if current_user.role != "service_role" and current_user.email != email:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return await service.get_user_by_email(email)
|
||||
|
||||
|
||||
@router.post("/password-reset", status_code=204)
|
||||
async def request_password_reset(
|
||||
payload: PasswordResetRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="password_reset_request",
|
||||
identifier=payload.email,
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.request_password_reset(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", status_code=204)
|
||||
async def confirm_password_reset(
|
||||
payload: PasswordResetConfirmRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="password_reset_confirm",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=600,
|
||||
)
|
||||
await service.confirm_password_reset(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -61,5 +61,11 @@ class PasswordResetRequest(BaseModel):
|
||||
redirect_to: str | None = None
|
||||
|
||||
|
||||
class PasswordResetConfirmRequest(BaseModel):
|
||||
email: EmailStr
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str = Field(min_length=6)
|
||||
|
||||
|
||||
class PasswordResetResponse(BaseModel):
|
||||
message: str = "Password reset email sent"
|
||||
|
||||
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from typing import Protocol
|
||||
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
@@ -40,6 +42,14 @@ class AuthServiceGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
@@ -71,3 +81,11 @@ class AuthService:
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
return await self._gateway.get_user_by_email(email)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
await self._gateway.request_password_reset(request)
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
await self._gateway.confirm_password_reset(request)
|
||||
|
||||
@@ -11,11 +11,21 @@ from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from core.logging import get_logger
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.users.repository import SQLAlchemyUserRepository
|
||||
from v1.users.service import UserService
|
||||
from v1.users.service import AuthLookupAdapter, UserService
|
||||
|
||||
logger = get_logger("v1.users.dependencies")
|
||||
|
||||
_auth_gateway: SupabaseAuthGateway | None = None
|
||||
|
||||
|
||||
def get_auth_gateway() -> SupabaseAuthGateway:
|
||||
global _auth_gateway
|
||||
if _auth_gateway is None:
|
||||
_auth_gateway = SupabaseAuthGateway()
|
||||
return _auth_gateway
|
||||
|
||||
|
||||
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
|
||||
if not authorization:
|
||||
@@ -98,4 +108,10 @@ def get_user_service(
|
||||
user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> UserService:
|
||||
repository = SQLAlchemyUserRepository(session)
|
||||
return UserService(repository=repository, session=session, current_user=user)
|
||||
auth_gateway = AuthLookupAdapter(get_auth_gateway())
|
||||
return UserService(
|
||||
repository=repository,
|
||||
session=session,
|
||||
current_user=user,
|
||||
auth_gateway=auth_gateway,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, or_
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
@@ -33,6 +33,10 @@ class UserRepository(Protocol):
|
||||
"""Update user by user ID. Returns updated user or None if not found."""
|
||||
...
|
||||
|
||||
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
|
||||
"""Search users by username (ilike) or email (exact match)."""
|
||||
...
|
||||
|
||||
|
||||
class SQLAlchemyUserRepository(BaseRepository[Profile]):
|
||||
"""SQLAlchemy implementation of UserRepository.
|
||||
@@ -77,5 +81,24 @@ class SQLAlchemyUserRepository(BaseRepository[Profile]):
|
||||
try:
|
||||
return await self.update_by_id(user_id, update_data)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("User update failed", user_id=str(user_id))
|
||||
logger.exception("User update failed", user=str(user_id))
|
||||
raise
|
||||
|
||||
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
|
||||
try:
|
||||
stmt = (
|
||||
select(Profile)
|
||||
.where(Profile.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
Profile.username.ilike(f"%{query}%"),
|
||||
)
|
||||
)
|
||||
.order_by(Profile.created_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
except SQLAlchemyError:
|
||||
logger.exception("User search failed", query=query)
|
||||
raise
|
||||
|
||||
@@ -2,10 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from v1.users.dependencies import get_user_service
|
||||
from v1.users.schemas import UserResponse, UserUpdateRequest
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
@@ -27,11 +27,9 @@ async def update_me(
|
||||
return await service.update_me(payload)
|
||||
|
||||
|
||||
@router.get("/{username}", response_model=UserResponse)
|
||||
async def get_by_username(
|
||||
username: Annotated[
|
||||
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
|
||||
],
|
||||
@router.post("/search", response_model=list[UserResponse])
|
||||
async def search_users(
|
||||
payload: UserSearchRequest,
|
||||
service: Annotated[UserService, Depends(get_user_service)],
|
||||
) -> UserResponse:
|
||||
return await service.get_by_username(username)
|
||||
) -> list[UserResponse]:
|
||||
return await service.search_users(payload)
|
||||
|
||||
@@ -19,6 +19,17 @@ class UserResponse(BaseModel):
|
||||
bio: str | None = None
|
||||
|
||||
|
||||
class UserSearchRequest(BaseModel):
|
||||
query: str = Field(min_length=1, max_length=100)
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
|
||||
|
||||
class UserUpdateRequest(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -9,13 +11,37 @@ from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from v1.users.repository import UserRepository
|
||||
from v1.users.schemas import UserResponse, UserUpdateRequest
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.auth.schemas import UserByEmailResponse
|
||||
|
||||
logger = get_logger("v1.users.service")
|
||||
|
||||
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
class AuthLookupGateway(Protocol):
|
||||
async def get_user_id_by_email(self, email: str) -> str | None: ...
|
||||
|
||||
|
||||
class AuthByEmailGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
|
||||
|
||||
|
||||
class AuthLookupAdapter:
|
||||
def __init__(self, gateway: AuthByEmailGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def get_user_id_by_email(self, email: str) -> str | None:
|
||||
try:
|
||||
response = await self._gateway.get_user_by_email(email)
|
||||
return response.id
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
class UserService(BaseService):
|
||||
"""User service handling business logic and transactions.
|
||||
@@ -28,16 +54,19 @@ class UserService(BaseService):
|
||||
|
||||
_repository: UserRepository
|
||||
_session: AsyncSession
|
||||
_auth_gateway: AuthLookupGateway | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: UserRepository,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
auth_gateway: AuthLookupGateway | None = None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
self._auth_gateway = auth_gateway
|
||||
|
||||
async def get_me(self) -> UserResponse:
|
||||
user_id = self.require_user_id()
|
||||
@@ -101,3 +130,52 @@ class UserService(BaseService):
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
|
||||
query = request.query.strip()
|
||||
|
||||
if _EMAIL_PATTERN.match(query):
|
||||
return await self._search_by_email(query)
|
||||
|
||||
return await self._search_by_username(query)
|
||||
|
||||
async def _search_by_email(self, email: str) -> list[UserResponse]:
|
||||
if self._auth_gateway is None:
|
||||
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
|
||||
|
||||
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
|
||||
if user_id_str is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
user = await self._repository.get_by_user_id(UUID(user_id_str))
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="User store unavailable")
|
||||
|
||||
if user is None:
|
||||
return []
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
]
|
||||
|
||||
async def _search_by_username(self, query: str) -> list[UserResponse]:
|
||||
try:
|
||||
users = await self._repository.search_users(query, limit=20)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="User store unavailable")
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
@@ -14,6 +14,8 @@ from v1.users.dependencies import get_current_user
|
||||
from v1.auth.rate_limit import reset_rate_limit_state
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
@@ -71,6 +73,18 @@ class FakeAuthService(AuthService):
|
||||
email_confirmed_at=None,
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
return None
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
if request.token == "000000":
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
|
||||
def _get_service() -> AuthService:
|
||||
@@ -665,3 +679,116 @@ def test_get_user_by_email_forbidden_when_querying_other_user() -> None:
|
||||
assert body["detail"] == "Forbidden"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_request_returns_204() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/password-reset",
|
||||
json={"email": "user@example.com"},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_confirm_returns_204() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"email": "user@example.com",
|
||||
"token": "123456",
|
||||
"new_password": "newpassword123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_confirm_invalid_token_returns_401() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"email": "user@example.com",
|
||||
"token": "000000",
|
||||
"new_password": "newpassword123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_confirm_weak_password_returns_422() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"email": "user@example.com",
|
||||
"token": "123456",
|
||||
"new_password": "123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -9,7 +9,7 @@ from fastapi.testclient import TestClient
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.users.dependencies import get_current_user, get_user_service
|
||||
from v1.users.schemas import UserResponse, UserUpdateRequest
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
@@ -18,6 +18,10 @@ class FakeUserService:
|
||||
|
||||
def __init__(self, user: UserResponse) -> None:
|
||||
self._user = user
|
||||
self._search_results: list[UserResponse] = []
|
||||
|
||||
def set_search_results(self, results: list[UserResponse]) -> None:
|
||||
self._search_results = results
|
||||
|
||||
async def get_me(self) -> UserResponse:
|
||||
if self._user.id is None:
|
||||
@@ -45,6 +49,11 @@ class FakeUserService:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return self._user
|
||||
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
|
||||
if request.query:
|
||||
return self._search_results if self._search_results else [self._user]
|
||||
return []
|
||||
|
||||
|
||||
def _override_user_service(
|
||||
service: FakeUserService,
|
||||
@@ -111,50 +120,6 @@ def test_patch_me_updates_user() -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_get_user_by_username() -> None:
|
||||
user = UserResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_user_service] = _override_user_service(
|
||||
FakeUserService(user)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/users/demo")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["username"] == "demo"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_user_not_found_returns_problem_details() -> None:
|
||||
user = UserResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_user_service] = _override_user_service(
|
||||
FakeUserService(user)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/users/unknown")
|
||||
assert response.status_code == 404
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Not Found"
|
||||
assert body["status"] == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_patch_me_validation_error_returns_problem_details() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
@@ -178,3 +143,70 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
|
||||
assert body["status"] == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_search_users_returns_list() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_user_service] = _override_user_service(
|
||||
FakeUserService(user)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/users/search",
|
||||
json={"query": "demo"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert isinstance(body, list)
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_search_users_empty_query_returns_422() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_user_service] = _override_user_service(
|
||||
FakeUserService(user)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/users/search",
|
||||
json={"query": ""},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_get_user_by_username_returns_404() -> None:
|
||||
user = UserResponse(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
app.dependency_overrides[get_user_service] = _override_user_service(
|
||||
FakeUserService(user)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/users/demo")
|
||||
assert response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
|
||||
|
||||
|
||||
class TestSupabaseAuthGateway:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> SupabaseAuthGateway:
|
||||
with patch("v1.auth.gateway.create_client") as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_admin_client = MagicMock()
|
||||
mock_create.side_effect = [mock_client, mock_admin_client]
|
||||
return SupabaseAuthGateway()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_calls_email_with_string(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
) -> None:
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
await gateway.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_with_redirect(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
) -> None:
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(
|
||||
email="test@example.com",
|
||||
redirect_to="http://localhost:3000/reset-password",
|
||||
)
|
||||
await gateway.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with(
|
||||
"test@example.com",
|
||||
options={"redirect_to": "http://localhost:3000/reset-password"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_swallows_auth_error(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
) -> None:
|
||||
from supabase import AuthError
|
||||
|
||||
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
|
||||
result = await gateway.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_extracts_email_from_mapping(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
) -> None:
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"email": "test@example.com"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
await gateway.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_rejects_invalid_email_shape(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
) -> None:
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"unexpected": "value"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.request_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid email"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_updates_password_by_user_id(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
) -> None:
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id="user-1"),
|
||||
)
|
||||
mock_verify_otp = MagicMock(return_value=verify_response)
|
||||
gateway._client.auth.verify_otp = mock_verify_otp
|
||||
|
||||
mock_update_user_by_id = MagicMock()
|
||||
gateway._admin_client.auth.admin = SimpleNamespace(
|
||||
update_user_by_id=mock_update_user_by_id
|
||||
)
|
||||
|
||||
request = PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
token="123456",
|
||||
new_password="newpassword123",
|
||||
)
|
||||
|
||||
await gateway.confirm_password_reset(request)
|
||||
|
||||
mock_verify_otp.assert_called_once_with(
|
||||
{
|
||||
"type": "recovery",
|
||||
"email": "test@example.com",
|
||||
"token": "123456",
|
||||
}
|
||||
)
|
||||
mock_update_user_by_id.assert_called_once_with(
|
||||
"user-1",
|
||||
{"password": "newpassword123"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_raises_when_user_id_missing(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
) -> None:
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id=""),
|
||||
)
|
||||
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
|
||||
request = PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
token="123456",
|
||||
new_password="newpassword123",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.confirm_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid or expired verification code"
|
||||
Reference in New Issue
Block a user