feat: 实现密码重置功能与用户搜索API,优化注册登录流程
- 新增忘记密码页面与重置密码确认流程(前端+后端) - 修复注册验证码页登录跳转路由 - 新增用户搜索API(按邮箱查询) - 简化infra脚本,统一为app.sh - 补充密码重置与用户API测试覆盖 - 更新runtime文档与AGENTS配置
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -10,6 +11,8 @@ from core.config.settings import SupabaseSettings, config
|
||||
from core.logging import get_logger
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
@@ -150,6 +153,64 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
),
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
try:
|
||||
reset_email = cast(Any, self._client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {"redirect_to": request.redirect_to}
|
||||
await asyncio.to_thread(reset_email, email, options=options)
|
||||
else:
|
||||
await asyncio.to_thread(reset_email, email)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset request failed",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
user_id = str(getattr(user, "id", "")) if user is not None else ""
|
||||
if session is None or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self._admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset confirm failed", error_type=type(exc).__name__
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
) from exc
|
||||
|
||||
|
||||
def _coerce_reset_email(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
nested = value.get("email") or value.get("value")
|
||||
if isinstance(nested, str):
|
||||
return nested
|
||||
|
||||
raise HTTPException(status_code=422, detail="Invalid email")
|
||||
|
||||
|
||||
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
|
||||
session = getattr(response, "session", None)
|
||||
|
||||
@@ -10,6 +10,8 @@ from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
@@ -123,3 +125,33 @@ async def get_user_by_email(
|
||||
if current_user.role != "service_role" and current_user.email != email:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return await service.get_user_by_email(email)
|
||||
|
||||
|
||||
@router.post("/password-reset", status_code=204)
|
||||
async def request_password_reset(
|
||||
payload: PasswordResetRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="password_reset_request",
|
||||
identifier=payload.email,
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.request_password_reset(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", status_code=204)
|
||||
async def confirm_password_reset(
|
||||
payload: PasswordResetConfirmRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="password_reset_confirm",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=600,
|
||||
)
|
||||
await service.confirm_password_reset(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -61,5 +61,11 @@ class PasswordResetRequest(BaseModel):
|
||||
redirect_to: str | None = None
|
||||
|
||||
|
||||
class PasswordResetConfirmRequest(BaseModel):
|
||||
email: EmailStr
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str = Field(min_length=6)
|
||||
|
||||
|
||||
class PasswordResetResponse(BaseModel):
|
||||
message: str = "Password reset email sent"
|
||||
|
||||
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from typing import Protocol
|
||||
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
@@ -40,6 +42,14 @@ class AuthServiceGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
@@ -71,3 +81,11 @@ class AuthService:
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
return await self._gateway.get_user_by_email(email)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
await self._gateway.request_password_reset(request)
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
await self._gateway.confirm_password_reset(request)
|
||||
|
||||
@@ -11,11 +11,21 @@ from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from core.logging import get_logger
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.users.repository import SQLAlchemyUserRepository
|
||||
from v1.users.service import UserService
|
||||
from v1.users.service import AuthLookupAdapter, UserService
|
||||
|
||||
logger = get_logger("v1.users.dependencies")
|
||||
|
||||
_auth_gateway: SupabaseAuthGateway | None = None
|
||||
|
||||
|
||||
def get_auth_gateway() -> SupabaseAuthGateway:
|
||||
global _auth_gateway
|
||||
if _auth_gateway is None:
|
||||
_auth_gateway = SupabaseAuthGateway()
|
||||
return _auth_gateway
|
||||
|
||||
|
||||
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
|
||||
if not authorization:
|
||||
@@ -98,4 +108,10 @@ def get_user_service(
|
||||
user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> UserService:
|
||||
repository = SQLAlchemyUserRepository(session)
|
||||
return UserService(repository=repository, session=session, current_user=user)
|
||||
auth_gateway = AuthLookupAdapter(get_auth_gateway())
|
||||
return UserService(
|
||||
repository=repository,
|
||||
session=session,
|
||||
current_user=user,
|
||||
auth_gateway=auth_gateway,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, or_
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
@@ -33,6 +33,10 @@ class UserRepository(Protocol):
|
||||
"""Update user by user ID. Returns updated user or None if not found."""
|
||||
...
|
||||
|
||||
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
|
||||
"""Search users by username (ilike) or email (exact match)."""
|
||||
...
|
||||
|
||||
|
||||
class SQLAlchemyUserRepository(BaseRepository[Profile]):
|
||||
"""SQLAlchemy implementation of UserRepository.
|
||||
@@ -77,5 +81,24 @@ class SQLAlchemyUserRepository(BaseRepository[Profile]):
|
||||
try:
|
||||
return await self.update_by_id(user_id, update_data)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("User update failed", user_id=str(user_id))
|
||||
logger.exception("User update failed", user=str(user_id))
|
||||
raise
|
||||
|
||||
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
|
||||
try:
|
||||
stmt = (
|
||||
select(Profile)
|
||||
.where(Profile.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
Profile.username.ilike(f"%{query}%"),
|
||||
)
|
||||
)
|
||||
.order_by(Profile.created_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
except SQLAlchemyError:
|
||||
logger.exception("User search failed", query=query)
|
||||
raise
|
||||
|
||||
@@ -2,10 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from v1.users.dependencies import get_user_service
|
||||
from v1.users.schemas import UserResponse, UserUpdateRequest
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
@@ -27,11 +27,9 @@ async def update_me(
|
||||
return await service.update_me(payload)
|
||||
|
||||
|
||||
@router.get("/{username}", response_model=UserResponse)
|
||||
async def get_by_username(
|
||||
username: Annotated[
|
||||
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
|
||||
],
|
||||
@router.post("/search", response_model=list[UserResponse])
|
||||
async def search_users(
|
||||
payload: UserSearchRequest,
|
||||
service: Annotated[UserService, Depends(get_user_service)],
|
||||
) -> UserResponse:
|
||||
return await service.get_by_username(username)
|
||||
) -> list[UserResponse]:
|
||||
return await service.search_users(payload)
|
||||
|
||||
@@ -19,6 +19,17 @@ class UserResponse(BaseModel):
|
||||
bio: str | None = None
|
||||
|
||||
|
||||
class UserSearchRequest(BaseModel):
|
||||
query: str = Field(min_length=1, max_length=100)
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
|
||||
|
||||
class UserUpdateRequest(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -9,13 +11,37 @@ from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from v1.users.repository import UserRepository
|
||||
from v1.users.schemas import UserResponse, UserUpdateRequest
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.auth.schemas import UserByEmailResponse
|
||||
|
||||
logger = get_logger("v1.users.service")
|
||||
|
||||
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
class AuthLookupGateway(Protocol):
|
||||
async def get_user_id_by_email(self, email: str) -> str | None: ...
|
||||
|
||||
|
||||
class AuthByEmailGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
|
||||
|
||||
|
||||
class AuthLookupAdapter:
|
||||
def __init__(self, gateway: AuthByEmailGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def get_user_id_by_email(self, email: str) -> str | None:
|
||||
try:
|
||||
response = await self._gateway.get_user_by_email(email)
|
||||
return response.id
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
class UserService(BaseService):
|
||||
"""User service handling business logic and transactions.
|
||||
@@ -28,16 +54,19 @@ class UserService(BaseService):
|
||||
|
||||
_repository: UserRepository
|
||||
_session: AsyncSession
|
||||
_auth_gateway: AuthLookupGateway | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: UserRepository,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
auth_gateway: AuthLookupGateway | None = None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
self._auth_gateway = auth_gateway
|
||||
|
||||
async def get_me(self) -> UserResponse:
|
||||
user_id = self.require_user_id()
|
||||
@@ -101,3 +130,52 @@ class UserService(BaseService):
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
|
||||
query = request.query.strip()
|
||||
|
||||
if _EMAIL_PATTERN.match(query):
|
||||
return await self._search_by_email(query)
|
||||
|
||||
return await self._search_by_username(query)
|
||||
|
||||
async def _search_by_email(self, email: str) -> list[UserResponse]:
|
||||
if self._auth_gateway is None:
|
||||
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
|
||||
|
||||
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
|
||||
if user_id_str is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
user = await self._repository.get_by_user_id(UUID(user_id_str))
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="User store unavailable")
|
||||
|
||||
if user is None:
|
||||
return []
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
]
|
||||
|
||||
async def _search_by_username(self, query: str) -> list[UserResponse]:
|
||||
try:
|
||||
users = await self._repository.search_users(query, limit=20)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="User store unavailable")
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user