feat: 实现密码重置功能与用户搜索API,优化注册登录流程

- 新增忘记密码页面与重置密码确认流程(前端+后端)
- 修复注册验证码页登录跳转路由
- 新增用户搜索API(按邮箱查询)
- 简化infra脚本,统一为app.sh
- 补充密码重置与用户API测试覆盖
- 更新runtime文档与AGENTS配置
This commit is contained in:
qzl
2026-02-27 15:22:42 +08:00
parent 0d4811fee5
commit e4e995854d
37 changed files with 2101 additions and 222 deletions
+29 -25
View File
@@ -1,35 +1,39 @@
## Repository Structure # Project Development Guide
- `infra/`: Infrastructure and operations (Docker, scripts, deployment). This file serves as the entry point for project development, directing to appropriate constraint files based on development context.
- `backend/`: FastAPI backend.
- `apps/`: Flutter mobile app. ## Project Structure
- `docs/`: Documentation and design/planning artifacts.
```
social-app/
├── apps/ # Flutter mobile app
├── backend/ # FastAPI backend service
├── infra/ # Infrastructure (Docker, deployment scripts)
└── docs/ # Documentation and design/planning artifacts
```
## Rules Hierarchy ## Rules Hierarchy
- This root `AGENTS.md` defines global rules and applies to all changes. Follow this hierarchy when developing:
- When editing `backend/`, you must also follow `backend/AGENTS.md`.
- When editing `apps/`, you must also follow `apps/AGENTS.md`.
## Docker Startup ```
~/.config/opencode/AGENTS.md # Global core rules (skills, agents, process)
Always start services with the env file: ├── This file (root AGENTS.md) # Project-level entry
│ ├── backend/AGENTS.md # Backend-specific rules
```bash │ └── apps/AGENTS.md # Frontend-specific rules
docker compose --env-file .env -f infra/docker/docker-compose.yml up -d
``` ```
## Git Branch and Worktree Policy ## Development Guidance
- Use `dev` as the default base branch for day-to-day development. | Development Context | Follow Rules |
- New development worktrees must be created from `dev` (never from `main`). |--------------------|--------------|
- Do not develop or commit directly on `main` outside explicit release/merge workflows. | Backend Python dev | [backend/AGENTS.md](backend/AGENTS.md) |
- Do not rewrite `main` history unless explicitly requested (including reset and force push). | Flutter mobile dev | [apps/AGENTS.md](apps/AGENTS.md) |
| Infrastructure/ops | This file + infra/ directory conventions |
| API doc changes | Sync to `docs/runtime/runtime-route.md` |
## API Route Documentation ## Git Workflow
When modifying HTTP routes (adding, updating, or removing endpoints): - Default branch: `dev`
- Feature development: use worktree `git worktree add -b feature/xxx ../feature-xxx dev`
- Sync changes to `docs/runtime/runtime-route.md` - Never develop directly on `main`
- Include: HTTP method, path, request/response schema, status codes, error format
- Keep documentation in sync with actual implementation
+57 -20
View File
@@ -1,16 +1,55 @@
## Mobile Rules # Flutter Mobile Development Rules
- Flutter mobile rules are maintained here. This document defines Flutter mobile development constraints.
- If no more specific rule is defined here, follow the root `AGENTS.md`.
## Flutter Design-to-Code Workflow ## Design System
Before writing any Flutter UI code, follow this sequence: ### Design Tokens
1. **Get editor state**: Use `pencil_get_editor_state` to confirm the active design. All UI styling must use design tokens from `apps/lib/core/theme/design_tokens.dart`:
2. **Get structure**: Use `pencil_batch_get` to inspect node hierarchy and layout.
3. **Get variables**: Use `pencil_get_variables` to fetch colors, typography, and tokens. | Type | Usage |
4. **Implement**: Match design values and container hierarchy exactly. |------|-------|
| Colors | `AppColors.primary`, `AppColors.slate500`, `AppColors.background` |
| Spacing | `AppSpacing.xs`, `AppSpacing.sm`, `AppSpacing.md` |
| Radius | `AppRadius.sm`, `AppRadius.md`, `AppRadius.lg` |
**NEVER hardcode colors, sizes, or spacing values.**
### Reuse Existing Components
Use pre-built components instead of creating custom ones:
- Buttons: Use `AppButton` widget from `apps/lib/shared/widgets/app_button.dart`
- Input fields: Use standard Flutter `TextField` with `InputDecoration`
- Loading states: Use built-in loading indicators
## New Page Design Workflow
1. **Analyze existing pages**: Study login, register, home screens for:
- Layout structure (centered form, padding, spacing)
- Typography hierarchy (title 28px bold, label 13px, hint 14px)
- Component usage (AppButton, TextField style)
- Color and spacing tokens
2. **Use frontend-design skill for mockups**:
```
Use the `frontend-design` skill to create HTML/CSS mockups for review
Match colors to `apps/lib/core/theme/design_tokens.dart`
Match spacing to `AppSpacing` values
Match radius to `AppRadius` values
```
3. **Verify design tokens**:
- All colors from `AppColors`
- All spacing from `AppSpacing`
- All radius from `AppRadius`
- NO hardcoded values
4. **Code review checklist**:
- [ ] All colors/spacing/radius use design tokens
- [ ] Reuses existing components (AppButton)
- [ ] Consistent with existing page patterns
- [ ] No magic numbers
## Layout Mapping Rules ## Layout Mapping Rules
@@ -21,15 +60,13 @@ Map design layout properties to Flutter explicitly:
- `alignItems: start` -> `CrossAxisAlignment.start` - `alignItems: start` -> `CrossAxisAlignment.start`
- `alignItems: stretch` -> `CrossAxisAlignment.stretch` - `alignItems: stretch` -> `CrossAxisAlignment.stretch`
2. **Map full container chain**: From root to leaf, ensure each `alignItems` and `justifyContent` has a Flutter equivalent. 2. **Map full container chain**: From root to leaf, ensure each `alignItems` and `justifyContent` has a Flutter equivalent.
3. **Analyze before coding**: Use `pencil_snapshot_layout` or `pencil_batch_get` to verify each container's alignment settings. 3. **Analyze before coding**: Verify each container's alignment settings.
## Centering and Visual Balance ## Centering and Visual Balance
Apply these rules on any screen that relies on centered composition: 1. Centering must be evaluated inside **`SafeArea`** bounds, not full-screen bounds.
1. Centering must be evaluated inside **`SafeArea` bounds**, not full-screen bounds.
2. Avoid relying on proportional `Spacer` values as the only centering mechanism for critical content. 2. Avoid relying on proportional `Spacer` values as the only centering mechanism for critical content.
3. For layouts with persistent top/bottom regions (for example headers or footers), center the primary content in the remaining available region. 3. For layouts with persistent top/bottom regions (e.g., headers or footers), center the primary content in the remaining available region.
4. Distinguish geometric centering from visual centering; validate final visual balance with screenshot review. 4. Distinguish geometric centering from visual centering; validate final visual balance with screenshot review.
## Quality Gate ## Quality Gate
@@ -41,10 +78,10 @@ For important screens, add widget tests that reduce layout-regression risk:
## Prohibitions ## Prohibitions
- Do not use colors or themes not defined in the design. - DO NOT use colors not defined in design tokens
- Do not skip design container layers. - DO NOT skip design container layers
- Do not start implementation before retrieving design variables. - DO NOT start implementation before retrieving design variables
- Do not hardcode colors; use design variables. - DO NOT hardcode colors; use design variables
## UI Feedback System ## UI Feedback System
@@ -82,5 +119,5 @@ AppBanner(message: '请检查输入', type: ToastType.warning)
- Use `Toast` for transient feedback that auto-dismisses - Use `Toast` for transient feedback that auto-dismisses
- Use `AppBanner` for persistent inline messages (form errors) - Use `AppBanner` for persistent inline messages (form errors)
- Do NOT create custom SnackBar, Dialog, or Banner components - DO NOT create custom SnackBar, Dialog, or Banner components
- Do NOT use raw `ScaffoldMessenger` - DO NOT use raw `ScaffoldMessenger`
+5
View File
@@ -5,6 +5,7 @@ import 'go_router_refresh_stream.dart';
import '../../features/auth/ui/screens/login_screen.dart'; import '../../features/auth/ui/screens/login_screen.dart';
import '../../features/auth/ui/screens/register_screen.dart'; import '../../features/auth/ui/screens/register_screen.dart';
import '../../features/auth/ui/screens/register_verification_screen.dart'; import '../../features/auth/ui/screens/register_verification_screen.dart';
import '../../features/auth/ui/screens/reset_password_screen.dart';
import '../../features/home/ui/screens/home_screen.dart'; import '../../features/home/ui/screens/home_screen.dart';
import '../../features/messages/ui/screens/message_invite_list_screen.dart'; import '../../features/messages/ui/screens/message_invite_list_screen.dart';
import '../../features/messages/ui/screens/message_invite_detail_screen.dart'; import '../../features/messages/ui/screens/message_invite_detail_screen.dart';
@@ -67,6 +68,10 @@ GoRouter createAppRouter(AuthBloc authBloc) {
path: '/register/verification', path: '/register/verification',
builder: (context, state) => const RegisterVerificationScreen(), builder: (context, state) => const RegisterVerificationScreen(),
), ),
GoRoute(
path: '/reset-password',
builder: (context, state) => const ResetPasswordScreen(),
),
GoRoute(path: '/home', builder: (context, state) => const HomeScreen()), GoRoute(path: '/home', builder: (context, state) => const HomeScreen()),
GoRoute( GoRoute(
path: '/messages/invites', path: '/messages/invites',
+15
View File
@@ -50,4 +50,19 @@ class AuthApi {
Future<void> deleteSession(LogoutRequest request) async { Future<void> deleteSession(LogoutRequest request) async {
await _client.delete('$_prefix/sessions', data: request.toJson()); await _client.delete('$_prefix/sessions', data: request.toJson());
} }
Future<void> requestPasswordReset(String email) async {
await _client.post('$_prefix/password-reset', data: {'email': email});
}
Future<void> confirmPasswordReset({
required String email,
required String token,
required String newPassword,
}) async {
await _client.post(
'$_prefix/password-reset/confirm',
data: {'email': email, 'token': token, 'new_password': newPassword},
);
}
} }
@@ -14,4 +14,10 @@ abstract class AuthRepository {
Future<String?> getAccessToken(); Future<String?> getAccessToken();
Future<String?> getRefreshToken(); Future<String?> getRefreshToken();
Future<bool> isAuthenticated(); Future<bool> isAuthenticated();
Future<void> requestPasswordReset(String email);
Future<void> confirmPasswordReset({
required String email,
required String token,
required String newPassword,
});
} }
@@ -77,4 +77,22 @@ class AuthRepositoryImpl implements AuthRepository {
final token = await _tokenStorage.getAccessToken(); final token = await _tokenStorage.getAccessToken();
return token != null; return token != null;
} }
@override
Future<void> requestPasswordReset(String email) {
return _api.requestPasswordReset(email);
}
@override
Future<void> confirmPasswordReset({
required String email,
required String token,
required String newPassword,
}) {
return _api.confirmPasswordReset(
email: email,
token: token,
newPassword: newPassword,
);
}
} }
@@ -2,17 +2,20 @@ class SignupStartRequest {
final String username; final String username;
final String email; final String email;
final String password; final String password;
final String? inviteCode;
const SignupStartRequest({ const SignupStartRequest({
required this.username, required this.username,
required this.email, required this.email,
required this.password, required this.password,
this.inviteCode,
}); });
Map<String, dynamic> toJson() => { Map<String, dynamic> toJson() => {
'username': username, 'username': username,
'email': email, 'email': email,
'password': password, 'password': password,
if (inviteCode != null) 'invite_code': inviteCode,
}; };
} }
@@ -12,6 +12,7 @@ class RegisterState extends Equatable {
final Email email; final Email email;
final Password password; final Password password;
final VerificationCode verificationCode; final VerificationCode verificationCode;
final String inviteCode;
final FormzSubmissionStatus status; final FormzSubmissionStatus status;
final String? errorMessage; final String? errorMessage;
final String? pendingEmail; final String? pendingEmail;
@@ -23,6 +24,7 @@ class RegisterState extends Equatable {
this.email = const Email.pure(), this.email = const Email.pure(),
this.password = const Password.pure(), this.password = const Password.pure(),
this.verificationCode = const VerificationCode.pure(), this.verificationCode = const VerificationCode.pure(),
this.inviteCode = '',
this.status = FormzSubmissionStatus.initial, this.status = FormzSubmissionStatus.initial,
this.errorMessage, this.errorMessage,
this.pendingEmail, this.pendingEmail,
@@ -39,6 +41,7 @@ class RegisterState extends Equatable {
Email? email, Email? email,
Password? password, Password? password,
VerificationCode? verificationCode, VerificationCode? verificationCode,
String? inviteCode,
FormzSubmissionStatus? status, FormzSubmissionStatus? status,
String? errorMessage, String? errorMessage,
String? pendingEmail, String? pendingEmail,
@@ -50,6 +53,7 @@ class RegisterState extends Equatable {
email: email ?? this.email, email: email ?? this.email,
password: password ?? this.password, password: password ?? this.password,
verificationCode: verificationCode ?? this.verificationCode, verificationCode: verificationCode ?? this.verificationCode,
inviteCode: inviteCode ?? this.inviteCode,
status: status ?? this.status, status: status ?? this.status,
errorMessage: errorMessage, errorMessage: errorMessage,
pendingEmail: pendingEmail ?? this.pendingEmail, pendingEmail: pendingEmail ?? this.pendingEmail,
@@ -64,6 +68,7 @@ class RegisterState extends Equatable {
email, email,
password, password,
verificationCode, verificationCode,
inviteCode,
status, status,
errorMessage, errorMessage,
pendingEmail, pendingEmail,
@@ -93,6 +98,10 @@ class RegisterCubit extends Cubit<RegisterState> {
emit(state.copyWith(verificationCode: VerificationCode.dirty(value))); emit(state.copyWith(verificationCode: VerificationCode.dirty(value)));
} }
void inviteCodeChanged(String value) {
emit(state.copyWith(inviteCode: value));
}
Future<bool> submitStep1() async { Future<bool> submitStep1() async {
if (!state.isStep1Valid) return false; if (!state.isStep1Valid) return false;
@@ -104,6 +113,7 @@ class RegisterCubit extends Cubit<RegisterState> {
username: state.username.value, username: state.username.value,
email: state.email.value, email: state.email.value,
password: state.password.value, password: state.password.value,
inviteCode: state.inviteCode.isNotEmpty ? state.inviteCode : null,
), ),
); );
emit( emit(
@@ -202,6 +212,7 @@ class RegisterCubit extends Cubit<RegisterState> {
username: state.username.value, username: state.username.value,
email: state.email.value, email: state.email.value,
password: state.password.value, password: state.password.value,
inviteCode: state.inviteCode.isNotEmpty ? state.inviteCode : null,
), ),
); );
emit( emit(
@@ -0,0 +1,314 @@
import 'dart:async';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:equatable/equatable.dart';
import 'package:formz/formz.dart';
import '../../../../core/form_inputs/form_inputs.dart';
import '../../data/auth_repository.dart';
class ResetPasswordState extends Equatable {
final Email email;
final VerificationCode code;
final Password newPassword;
final Password confirmPassword;
final FormzSubmissionStatus status;
final String? errorMessage;
final bool isSuccess;
final int resendCountdown;
final bool codeSent;
const ResetPasswordState({
this.email = const Email.pure(),
this.code = const VerificationCode.pure(),
this.newPassword = const Password.pure(),
this.confirmPassword = const Password.pure(),
this.status = FormzSubmissionStatus.initial,
this.errorMessage,
this.isSuccess = false,
this.resendCountdown = 0,
this.codeSent = false,
});
bool get canSubmit {
if (!codeSent) {
return email.isValid && status != FormzSubmissionStatus.inProgress;
}
return email.isValid &&
code.isValid &&
newPassword.isValid &&
confirmPassword.isValid &&
newPassword.value == confirmPassword.value &&
status != FormzSubmissionStatus.inProgress;
}
ResetPasswordState copyWith({
Email? email,
VerificationCode? code,
Password? newPassword,
Password? confirmPassword,
FormzSubmissionStatus? status,
String? errorMessage,
bool? isSuccess,
int? resendCountdown,
bool? codeSent,
}) {
return ResetPasswordState(
email: email ?? this.email,
code: code ?? this.code,
newPassword: newPassword ?? this.newPassword,
confirmPassword: confirmPassword ?? this.confirmPassword,
status: status ?? this.status,
errorMessage: errorMessage,
isSuccess: isSuccess ?? this.isSuccess,
resendCountdown: resendCountdown ?? this.resendCountdown,
codeSent: codeSent ?? this.codeSent,
);
}
@override
List<Object?> get props => [
email,
code,
newPassword,
confirmPassword,
status,
errorMessage,
isSuccess,
resendCountdown,
codeSent,
];
}
class ResetPasswordCubit extends Cubit<ResetPasswordState> {
final AuthRepository _repository;
Timer? _resendTimer;
ResetPasswordCubit(this._repository) : super(const ResetPasswordState());
@override
Future<void> close() {
_resendTimer?.cancel();
return super.close();
}
void emailChanged(String value) {
emit(state.copyWith(email: Email.dirty(value), errorMessage: null));
}
void codeChanged(String value) {
emit(
state.copyWith(code: VerificationCode.dirty(value), errorMessage: null),
);
}
void newPasswordChanged(String value) {
emit(
state.copyWith(newPassword: Password.dirty(value), errorMessage: null),
);
}
void confirmPasswordChanged(String value) {
emit(
state.copyWith(
confirmPassword: Password.dirty(value),
errorMessage: null,
),
);
}
Future<void> sendCode() async {
if (state.status == FormzSubmissionStatus.inProgress ||
state.resendCountdown > 0) {
return;
}
if (!state.email.isValid) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: state.email.value.isEmpty ? '请输入邮箱' : '邮箱格式不正确',
),
);
return;
}
emit(
state.copyWith(
status: FormzSubmissionStatus.inProgress,
codeSent: true,
resendCountdown: 60,
errorMessage: null,
),
);
_startResendCountdown();
try {
await _repository.requestPasswordReset(state.email.value);
emit(
state.copyWith(
status: FormzSubmissionStatus.success,
errorMessage: 'CODE_SENT_SUCCESS',
),
);
} catch (e) {
_cancelResendCountdown();
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
codeSent: false,
resendCountdown: 0,
errorMessage: '网络错误,请稍后重试',
),
);
}
}
void _cancelResendCountdown() {
_resendTimer?.cancel();
}
void _startResendCountdown() {
_cancelResendCountdown();
_resendTimer = Timer.periodic(const Duration(seconds: 1), (timer) {
final newCountdown = state.resendCountdown - 1;
if (newCountdown <= 0) {
timer.cancel();
emit(state.copyWith(resendCountdown: 0));
} else {
emit(state.copyWith(resendCountdown: newCountdown));
}
});
}
Future<void> resendCode() async {
if (state.resendCountdown > 0 ||
state.status == FormzSubmissionStatus.inProgress) {
return;
}
if (!state.email.isValid) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: state.email.value.isEmpty ? '请输入邮箱' : '邮箱格式不正确',
),
);
return;
}
emit(
state.copyWith(
status: FormzSubmissionStatus.inProgress,
codeSent: true,
resendCountdown: 60,
errorMessage: null,
),
);
_startResendCountdown();
try {
await _repository.requestPasswordReset(state.email.value);
emit(
state.copyWith(
status: FormzSubmissionStatus.success,
errorMessage: 'CODE_SENT_SUCCESS',
),
);
} catch (e) {
_cancelResendCountdown();
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
resendCountdown: 0,
errorMessage: '网络错误,请稍后重试',
),
);
}
}
Future<void> submit() async {
if (!state.codeSent) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: '请先获取验证码',
),
);
return;
}
if (!state.email.isValid) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: '请输入有效的邮箱地址',
),
);
return;
}
if (!state.code.isValid) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: '请输入6位验证码',
),
);
return;
}
if (!state.newPassword.isValid) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: '新密码至少6位',
),
);
return;
}
if (!state.confirmPassword.isValid) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: '请输入确认密码',
),
);
return;
}
if (state.newPassword.value != state.confirmPassword.value) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: '两次密码输入不一致',
),
);
return;
}
emit(
state.copyWith(
status: FormzSubmissionStatus.inProgress,
errorMessage: null,
),
);
try {
await _repository.confirmPasswordReset(
email: state.email.value,
token: state.code.value,
newPassword: state.newPassword.value,
);
emit(
state.copyWith(status: FormzSubmissionStatus.success, isSuccess: true),
);
} catch (e) {
emit(
state.copyWith(
status: FormzSubmissionStatus.failure,
errorMessage: '密码重置失败,请检查验证码',
),
);
}
}
}
@@ -162,6 +162,8 @@ class _LoginViewState extends State<LoginView> {
? null ? null
: _handleLogin, : _handleLogin,
), ),
const SizedBox(height: 12),
_buildForgotPassword(),
], ],
), ),
); );
@@ -236,6 +238,20 @@ class _LoginViewState extends State<LoginView> {
); );
} }
Widget _buildForgotPassword() {
return GestureDetector(
onTap: () => context.push('/reset-password'),
child: const Text(
'忘记密码?',
style: TextStyle(
fontSize: 14,
fontWeight: FontWeight.w500,
color: AppColors.slate500,
),
),
);
}
Widget _buildFooter() { Widget _buildFooter() {
return GestureDetector( return GestureDetector(
onTap: () => context.push('/register'), onTap: () => context.push('/register'),
@@ -36,6 +36,7 @@ class _RegisterViewState extends State<RegisterView> {
final _nicknameController = TextEditingController(); final _nicknameController = TextEditingController();
final _emailController = TextEditingController(); final _emailController = TextEditingController();
final _passwordController = TextEditingController(); final _passwordController = TextEditingController();
final _inviteCodeController = TextEditingController();
bool _obscureText = true; bool _obscureText = true;
@override @override
@@ -43,6 +44,7 @@ class _RegisterViewState extends State<RegisterView> {
_nicknameController.dispose(); _nicknameController.dispose();
_emailController.dispose(); _emailController.dispose();
_passwordController.dispose(); _passwordController.dispose();
_inviteCodeController.dispose();
super.dispose(); super.dispose();
} }
@@ -51,6 +53,7 @@ class _RegisterViewState extends State<RegisterView> {
cubit.usernameChanged(_nicknameController.text); cubit.usernameChanged(_nicknameController.text);
cubit.emailChanged(_emailController.text); cubit.emailChanged(_emailController.text);
cubit.passwordChanged(_passwordController.text); cubit.passwordChanged(_passwordController.text);
cubit.inviteCodeChanged(_inviteCodeController.text);
if (!cubit.state.isStep1Valid || cubit.state.isSending) { if (!cubit.state.isStep1Valid || cubit.state.isSending) {
String? errorMsg; String? errorMsg;
@@ -159,6 +162,8 @@ class _RegisterViewState extends State<RegisterView> {
const SizedBox(height: 12), const SizedBox(height: 12),
_buildPasswordInput(), _buildPasswordInput(),
const SizedBox(height: 12), const SizedBox(height: 12),
_buildInput('邀请码(选填)', '请输入邀请码', _inviteCodeController),
const SizedBox(height: 12),
_buildStepIndicator(), _buildStepIndicator(),
if (state.errorMessage != null) if (state.errorMessage != null)
Padding( Padding(
@@ -48,10 +48,22 @@ class _RegisterVerificationViewState extends State<RegisterVerificationView> {
Timer? _countdownTimer; Timer? _countdownTimer;
int _countdown = 0; int _countdown = 0;
bool _firstSendCompleted = false; bool _firstSendCompleted = false;
bool _hintShown = false;
@override @override
void initState() { void initState() {
super.initState(); super.initState();
WidgetsBinding.instance.addPostFrameCallback((_) {
if (!_hintShown) {
_hintShown = true;
Toast.show(
context,
'验证码已发送,如未收到请检查垃圾邮件或确认邮箱已注册',
type: ToastType.info,
duration: const Duration(seconds: 5),
);
}
});
} }
@override @override
@@ -331,7 +343,7 @@ class _RegisterVerificationViewState extends State<RegisterVerificationView> {
Widget _buildFooter() { Widget _buildFooter() {
return GestureDetector( return GestureDetector(
onTap: () => context.pop(), onTap: () => context.go('/'),
child: const Text( child: const Text(
'已有账号?去登录', '已有账号?去登录',
style: TextStyle( style: TextStyle(
@@ -0,0 +1,355 @@
import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:formz/formz.dart';
import 'package:go_router/go_router.dart';
import '../../../../core/theme/design_tokens.dart';
import '../../../../core/di/injection.dart';
import '../../../../shared/widgets/app_button.dart';
import '../../../../shared/widgets/toast/toast.dart';
import '../../../../shared/widgets/toast/toast_type.dart';
import '../../presentation/cubits/reset_password_cubit.dart';
import '../../data/auth_repository.dart';
class ResetPasswordScreen extends StatelessWidget {
const ResetPasswordScreen({super.key});
@override
Widget build(BuildContext context) {
return BlocProvider(
create: (context) => ResetPasswordCubit(sl<AuthRepository>()),
child: const ResetPasswordView(),
);
}
}
class ResetPasswordView extends StatefulWidget {
const ResetPasswordView({super.key});
@override
State<ResetPasswordView> createState() => _ResetPasswordViewState();
}
class _ResetPasswordViewState extends State<ResetPasswordView> {
final _emailController = TextEditingController();
final _codeController = TextEditingController();
final _passwordController = TextEditingController();
final _confirmPasswordController = TextEditingController();
bool _obscurePassword = true;
bool _obscureConfirmPassword = true;
@override
void dispose() {
_emailController.dispose();
_codeController.dispose();
_passwordController.dispose();
_confirmPasswordController.dispose();
super.dispose();
}
Future<void> _handleSubmit() async {
final cubit = context.read<ResetPasswordCubit>();
cubit.emailChanged(_emailController.text);
cubit.codeChanged(_codeController.text);
cubit.newPasswordChanged(_passwordController.text);
cubit.confirmPasswordChanged(_confirmPasswordController.text);
await cubit.submit();
}
@override
Widget build(BuildContext context) {
return BlocListener<ResetPasswordCubit, ResetPasswordState>(
listenWhen: (previous, current) =>
previous.status != current.status ||
previous.errorMessage != current.errorMessage ||
previous.codeSent != current.codeSent,
listener: (context, state) {
if (state.status == FormzSubmissionStatus.success && state.isSuccess) {
Toast.show(context, '密码重置成功,请使用新密码登录', type: ToastType.success);
context.go('/');
} else if (state.status == FormzSubmissionStatus.success &&
state.codeSent &&
state.errorMessage == 'CODE_SENT_SUCCESS') {
Toast.show(context, '验证码已发送到您的邮箱', type: ToastType.success);
} else if (state.status == FormzSubmissionStatus.failure &&
state.errorMessage != null &&
state.errorMessage != '' &&
state.errorMessage != 'CODE_SENT_SUCCESS') {
Toast.show(context, state.errorMessage!, type: ToastType.error);
}
},
child: Scaffold(
backgroundColor: AppColors.background,
body: SafeArea(
child: Padding(
padding: const EdgeInsets.symmetric(horizontal: 24),
child: Column(
crossAxisAlignment: CrossAxisAlignment.center,
children: [
Expanded(
child: Center(
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.center,
children: [
_buildTitle(),
const SizedBox(height: 32),
_buildFormContainer(),
],
),
),
),
],
),
),
),
),
);
}
Widget _buildTitle() {
return const Text(
'忘记密码',
style: TextStyle(
fontSize: 28,
fontWeight: FontWeight.w700,
color: AppColors.slate900,
),
);
}
Widget _buildFormContainer() {
return BlocBuilder<ResetPasswordCubit, ResetPasswordState>(
builder: (context, state) {
return SizedBox(
width: 327,
child: Column(
crossAxisAlignment: CrossAxisAlignment.stretch,
children: [
_buildEmailInput(state.email.displayError != null),
const SizedBox(height: 12),
_buildCodeInput(state.code.displayError != null, state),
const SizedBox(height: 12),
_buildPasswordInput(state.newPassword.displayError != null),
const SizedBox(height: 12),
_buildConfirmPasswordInput(
state.confirmPassword.displayError != null,
),
const SizedBox(height: 24),
_buildSubmitButton(state),
const SizedBox(height: 16),
_buildBackToLogin(),
],
),
);
},
);
}
Widget _buildEmailInput(bool hasError) {
return Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
const Text(
'邮箱',
style: TextStyle(
fontSize: 13,
fontWeight: FontWeight.w500,
color: AppColors.slate600,
),
),
const SizedBox(height: 6),
TextField(
controller: _emailController,
keyboardType: TextInputType.emailAddress,
onChanged: (value) {
context.read<ResetPasswordCubit>().emailChanged(value);
},
decoration: InputDecoration(
hintText: '请输入邮箱',
errorText: hasError ? ' ' : null,
),
),
],
);
}
Widget _buildCodeInput(bool hasError, ResetPasswordState state) {
return Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
const Text(
'验证码',
style: TextStyle(
fontSize: 13,
fontWeight: FontWeight.w500,
color: AppColors.slate600,
),
),
const SizedBox(height: 6),
Row(
children: [
Expanded(
child: TextField(
controller: _codeController,
keyboardType: TextInputType.number,
onChanged: (value) {
context.read<ResetPasswordCubit>().codeChanged(value);
},
decoration: InputDecoration(
hintText: '请输入 6 位验证码',
errorText: hasError ? ' ' : null,
),
),
),
const SizedBox(width: 12),
SizedBox(
height: 40,
child: TextButton(
onPressed:
state.resendCountdown > 0 ||
state.status == FormzSubmissionStatus.inProgress
? null
: () {
if (state.codeSent) {
context.read<ResetPasswordCubit>().resendCode();
} else {
context.read<ResetPasswordCubit>().sendCode();
}
},
style: TextButton.styleFrom(
backgroundColor: state.codeSent
? AppColors.background
: AppColors.primary,
foregroundColor: state.codeSent
? AppColors.primary
: AppColors.primaryForeground,
shape: RoundedRectangleBorder(
borderRadius: BorderRadius.circular(AppRadius.sm),
),
padding: const EdgeInsets.symmetric(horizontal: 14),
),
child: Text(
state.resendCountdown > 0
? '${state.resendCountdown}'
: (state.codeSent ? '重新发送' : '发送验证码'),
style: const TextStyle(
fontSize: 13,
fontWeight: FontWeight.w500,
),
),
),
),
],
),
],
);
}
Widget _buildPasswordInput(bool hasError) {
return Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
const Text(
'新密码',
style: TextStyle(
fontSize: 13,
fontWeight: FontWeight.w500,
color: AppColors.slate600,
),
),
const SizedBox(height: 6),
TextField(
controller: _passwordController,
obscureText: _obscurePassword,
onChanged: (value) {
context.read<ResetPasswordCubit>().newPasswordChanged(value);
},
decoration: InputDecoration(
hintText: '请输入新密码(至少 6 位)',
errorText: hasError ? ' ' : null,
suffixIcon: IconButton(
icon: Icon(
_obscurePassword ? Icons.visibility_off : Icons.visibility,
size: 20,
color: AppColors.slate400,
),
onPressed: () {
setState(() {
_obscurePassword = !_obscurePassword;
});
},
),
),
),
],
);
}
Widget _buildConfirmPasswordInput(bool hasError) {
return Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
const Text(
'确认密码',
style: TextStyle(
fontSize: 13,
fontWeight: FontWeight.w500,
color: AppColors.slate600,
),
),
const SizedBox(height: 6),
TextField(
controller: _confirmPasswordController,
obscureText: _obscureConfirmPassword,
onChanged: (value) {
context.read<ResetPasswordCubit>().confirmPasswordChanged(value);
},
decoration: InputDecoration(
hintText: '请再次输入新密码',
errorText: hasError ? ' ' : null,
suffixIcon: IconButton(
icon: Icon(
_obscureConfirmPassword
? Icons.visibility_off
: Icons.visibility,
size: 20,
color: AppColors.slate400,
),
onPressed: () {
setState(() {
_obscureConfirmPassword = !_obscureConfirmPassword;
});
},
),
),
),
],
);
}
Widget _buildSubmitButton(ResetPasswordState state) {
final isLoading = state.status == FormzSubmissionStatus.inProgress;
final isDisabled = isLoading || !state.codeSent;
return AppButton(
text: '重置密码',
onPressed: isDisabled ? null : _handleSubmit,
);
}
Widget _buildBackToLogin() {
return GestureDetector(
onTap: () => context.go('/'),
child: const Text(
'返回登录',
style: TextStyle(
fontSize: 14,
fontWeight: FontWeight.w500,
color: AppColors.slate500,
),
textAlign: TextAlign.center,
),
);
}
}
+7 -3
View File
@@ -17,8 +17,12 @@ class UsersApi {
return UserResponse.fromJson(response.data); return UserResponse.fromJson(response.data);
} }
Future<UserResponse> getByUsername(String username) async { Future<List<UserResponse>> searchUsers(String query) async {
final response = await _client.get('$_prefix/$username'); final response = await _client.post(
return UserResponse.fromJson(response.data); '$_prefix/search',
data: {'query': query},
);
final List<dynamic> data = response.data;
return data.map((json) => UserResponse.fromJson(json)).toList();
} }
} }
@@ -3,5 +3,5 @@ import 'models/user_response.dart';
abstract class UsersRepository { abstract class UsersRepository {
Future<UserResponse> getMe(); Future<UserResponse> getMe();
Future<UserResponse> updateMe(UserUpdateRequest request); Future<UserResponse> updateMe(UserUpdateRequest request);
Future<UserResponse> getByUsername(String username); Future<List<UserResponse>> searchUsers(String query);
} }
@@ -18,7 +18,7 @@ class UsersRepositoryImpl implements UsersRepository {
} }
@override @override
Future<UserResponse> getByUsername(String username) { Future<List<UserResponse>> searchUsers(String query) {
return _api.getByUsername(username); return _api.searchUsers(query);
} }
} }
@@ -0,0 +1,65 @@
import 'dart:async';
import 'package:flutter_test/flutter_test.dart';
import 'package:formz/formz.dart';
import 'package:mocktail/mocktail.dart';
import 'package:social_app/features/auth/data/auth_repository.dart';
import 'package:social_app/features/auth/presentation/cubits/reset_password_cubit.dart';
class MockAuthRepository extends Mock implements AuthRepository {}
void main() {
late ResetPasswordCubit cubit;
late MockAuthRepository mockRepository;
setUp(() {
mockRepository = MockAuthRepository();
cubit = ResetPasswordCubit(mockRepository);
});
tearDown(() async {
await cubit.close();
});
test(
'sendCode enters countdown immediately and prevents duplicate clicks',
() async {
final completer = Completer<void>();
when(
() => mockRepository.requestPasswordReset(any()),
).thenAnswer((_) => completer.future);
cubit.emailChanged('test@example.com');
final firstRequest = cubit.sendCode();
await Future<void>.delayed(Duration.zero);
expect(cubit.state.status, FormzSubmissionStatus.inProgress);
expect(cubit.state.codeSent, isTrue);
expect(cubit.state.resendCountdown, 60);
await cubit.sendCode();
verify(
() => mockRepository.requestPasswordReset('test@example.com'),
).called(1);
completer.complete();
await firstRequest;
},
);
test('sendCode failure cancels countdown and restores retry state', () async {
when(
() => mockRepository.requestPasswordReset(any()),
).thenThrow(Exception('network error'));
cubit.emailChanged('test@example.com');
await cubit.sendCode();
expect(cubit.state.status, FormzSubmissionStatus.failure);
expect(cubit.state.codeSent, isFalse);
expect(cubit.state.resendCountdown, 0);
expect(cubit.state.errorMessage, '网络错误,请稍后重试');
});
}
+80 -14
View File
@@ -1,3 +1,7 @@
# Backend Development Rules
This document defines Python/FastAPI backend development constraints.
## Python Environment ## Python Environment
**MUST use uv for dependency management and virtual environment execution.** **MUST use uv for dependency management and virtual environment execution.**
@@ -43,11 +47,10 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
- Tests can set env vars via `monkeypatch.setenv`, and should read values via `Settings()` unless the test is explicitly validating env plumbing - Tests can set env vars via `monkeypatch.setenv`, and should read values via `Settings()` unless the test is explicitly validating env plumbing
- Canonical principle: one source of truth per setting; no duplicate/derived env vars in backend code - Canonical principle: one source of truth per setting; no duplicate/derived env vars in backend code
## TDD First Policy ## TDD Workflow
**Principle: tests before implementation.**
### Coverage Requirements ### Coverage Requirements
- Minimum coverage: 80% - Minimum coverage: 80%
- Required test types: - Required test types:
- Unit: isolated functions, utilities, components - Unit: isolated functions, utilities, components
@@ -55,12 +58,14 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
- E2E: critical user flows (Playwright) - E2E: critical user flows (Playwright)
### Limited Exceptions ### Limited Exceptions
- Docs-only changes (README, comments, formatting) may skip integration/E2E - Docs-only changes (README, comments, formatting) may skip integration/E2E
- Non-runtime config changes may skip E2E if no behavior changes - Non-runtime config changes may skip E2E if no behavior changes
- Any runtime code change requires unit + integration + E2E - Any runtime code change requires unit + integration + E2E
- If an exception is used, record the reason in the PR/test notes - If an exception is used, record the reason in the PR/test notes
### Mandatory TDD Workflow ### Mandatory TDD Workflow
1. Write tests (RED) - they must fail 1. Write tests (RED) - they must fail
2. Run tests - confirm failure 2. Run tests - confirm failure
3. Implement minimal code (GREEN) - only to pass 3. Implement minimal code (GREEN) - only to pass
@@ -69,19 +74,80 @@ Do not bypass or weaken checks (no ignores, disables, or config relaxations). Re
6. Verify coverage - must be 80%+ 6. Verify coverage - must be 80%+
### Enforcement ### Enforcement
- Must use the `tdd-guide` agent for new features - Must use the `tdd-guide` agent for new features
- Do not write implementation before tests - Do not write implementation before tests
- Do not lower coverage requirements - Do not lower coverage requirements
- Must include unit, integration, and E2E tests - Must include unit, integration, and E2E tests
## Code Style
### Immutability
**ALWAYS create new objects, NEVER mutate.**
```python
# WRONG: Mutation
def update_user(user, name):
user["name"] = name
return user
# CORRECT: Immutability
def update_user(user, name):
return {**user, "name": name}
```
### File Organization
- Many small files over few large files
- 200-400 lines typical, 800 max per file
- Extract utilities from large components
### Error Handling
Always handle errors comprehensively:
```python
try:
result = risky_operation()
return result
except Exception as exc:
logger.exception("Operation failed")
raise RuntimeError("Detailed user-friendly message") from exc
```
## Security
### Mandatory Security Checks
Before ANY commit:
- [ ] No hardcoded secrets (API keys, passwords, tokens)
- [ ] All user inputs validated (use Pydantic)
- [ ] SQL injection prevention (parameterized queries)
- [ ] Authentication/authorization verified
### Secret Management
```python
# NEVER: Hardcoded secrets
api_key = "sk-proj-xxxxx"
# ALWAYS: Environment variables
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY not configured")
```
## Database Development Rules ## Database Development Rules
### Core Principle ### Architecture
- **Supabase**: authentication (JWT source of truth) - **Supabase**: authentication (JWT source of truth)
- **Backend**: business authorization (service layer) - **Backend**: business authorization (service layer)
- **SQLAlchemy ORM**: data access layer (async + asyncpg, service_role connection) - **SQLAlchemy ORM**: data access layer (async + asyncpg, service_role connection)
### Architecture ### Code Organization
Use `schemas / repository / service` pattern: Use `schemas / repository / service` pattern:
- `schemas.py` — Pydantic models - `schemas.py` — Pydantic models
- `repository.py` — CRUD only, no auth, no commit (only flush), must receive session (never create session/engine) - `repository.py` — CRUD only, no auth, no commit (only flush), must receive session (never create session/engine)
@@ -89,6 +155,7 @@ Use `schemas / repository / service` pattern:
- `dependencies.py` — DI (`get_db`, `get_current_user`) - `dependencies.py` — DI (`get_db`, `get_current_user`)
### Auth & Data Access ### Auth & Data Access
- Backend must verify JWT signature and expiration (not just decode) - Backend must verify JWT signature and expiration (not just decode)
- Extract `user_id` from JWT `sub` claim - Extract `user_id` from JWT `sub` claim
- Backend connects with **service_role** (bypasses RLS) - Backend connects with **service_role** (bypasses RLS)
@@ -98,31 +165,28 @@ Use `schemas / repository / service` pattern:
- Prohibit calling Supabase Admin API (service_role key) from repository/service layers - Prohibit calling Supabase Admin API (service_role key) from repository/service layers
### Migrations ### Migrations
- **Alembic is the single source of truth** for schema migrations - **Alembic is the single source of truth** for schema migrations
- ORM model changes → `alembic revision --autogenerate` - ORM model changes → `alembic revision --autogenerate`
- Raw SQL (policies, triggers, functions) → `op.execute()` - Raw SQL (policies, triggers, functions) → `op.execute()`
- Migrations must be reversible; no reliance on generated IDs - Migrations must be reversible; no reliance on generated IDs
### Enum Storage Convention ### Enum Storage Convention
**Store enum names (strings), not integer values.** **Store enum names (strings), not integer values.**
- Use `VARCHAR(20)` + `CHECK` constraint in database - Use `VARCHAR(20)` + `CHECK` constraint in database
- Use Python `Enum` class with `str` base in code - Use Python `Enum` class with `str` base in code
- Benefits: debugging readability, easy to add new values without data migration, ORM-friendly
```python ```python
# Correct
class AgentType(str, Enum): class AgentType(str, Enum):
INTENT_RECOGNITION = "INTENT_RECOGNITION" INTENT_RECOGNITION = "INTENT_RECOGNITION"
TASK_EXECUTION = "TASK_EXECUTION" TASK_EXECUTION = "TASK_EXECUTION"
RESULT_REPORTING = "RESULT_REPORTING" RESULT_REPORTING = "RESULT_REPORTING"
# Migration
ALTER TABLE user_agents ADD CONSTRAINT chk_agent_type
CHECK (agent_type IN ('INTENT_RECOGNITION', 'TASK_EXECUTION', 'RESULT_REPORTING'));
``` ```
### RLS Guidance ### RLS Policy
- Backend does not rely on RLS for correctness (uses service_role), but RLS is mandatory as a defensive boundary for tables in PostgREST-exposed schemas. - Backend does not rely on RLS for correctness (uses service_role), but RLS is mandatory as a defensive boundary for tables in PostgREST-exposed schemas.
- **Mandatory default**: any new business table in `public` must enable RLS in the same Alembic migration. - **Mandatory default**: any new business table in `public` must enable RLS in the same Alembic migration.
- The same migration must create policies covering `SELECT/INSERT/UPDATE/DELETE` (minimum requirement). - The same migration must create policies covering `SELECT/INSERT/UPDATE/DELETE` (minimum requirement).
@@ -130,11 +194,13 @@ ALTER TABLE user_agents ADD CONSTRAINT chk_agent_type
- `alembic_version` must not be exposed to `anon` or `authenticated`. - `alembic_version` must not be exposed to `anon` or `authenticated`.
#### Exemption Rule (strict) #### Exemption Rule (strict)
- Exemptions are allowed only when a new `public` table is guaranteed not to be exposed to PostgREST clients. - Exemptions are allowed only when a new `public` table is guaranteed not to be exposed to PostgREST clients.
- Exemptions must be explicit in the migration file with rationale and verification notes (why safe, how exposure is prevented). - Exemptions must be explicit in the migration file with rationale and verification notes.
- If exposure is uncertain, do not exempt: enable defensive RLS by default. - If exposure is uncertain, do not exempt: enable defensive RLS by default.
#### Migration Acceptance Checklist (RLS) #### Migration Checklist
- [ ] New `public` business table has `ALTER TABLE ... ENABLE ROW LEVEL SECURITY` in migration - [ ] New `public` business table has `ALTER TABLE ... ENABLE ROW LEVEL SECURITY` in migration
- [ ] Policies for `SELECT/INSERT/UPDATE/DELETE` are present in migration - [ ] Policies for `SELECT/INSERT/UPDATE/DELETE` are present in migration
- [ ] Policy target roles are explicit (`anon`, `authenticated`, or both) - [ ] Policy target roles are explicit (`anon`, `authenticated`, or both)
@@ -40,7 +40,6 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"), sa.UniqueConstraint("name"),
) )
op.create_index("ix_llm_factory_name", "llm_factory", ["name"], unique=True)
_enable_rls("llm_factory") _enable_rls("llm_factory")
op.create_table( op.create_table(
@@ -65,7 +64,6 @@ def upgrade() -> None:
sa.UniqueConstraint("model_code"), sa.UniqueConstraint("model_code"),
) )
op.create_index("ix_llms_factory_id", "llms", ["factory_id"], unique=False) op.create_index("ix_llms_factory_id", "llms", ["factory_id"], unique=False)
op.create_index("ix_llms_model_code", "llms", ["model_code"], unique=True)
op.create_foreign_key( op.create_foreign_key(
"fk_llms_factory_id", "fk_llms_factory_id",
"llms", "llms",
+61
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
from fastapi import HTTPException from fastapi import HTTPException
@@ -10,6 +11,8 @@ from core.config.settings import SupabaseSettings, config
from core.logging import get_logger from core.logging import get_logger
from v1.auth.schemas import ( from v1.auth.schemas import (
AuthUser, AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest, SessionCreateRequest,
SessionRefreshRequest, SessionRefreshRequest,
SessionResponse, SessionResponse,
@@ -150,6 +153,64 @@ class SupabaseAuthGateway(AuthServiceGateway):
), ),
) )
async def request_password_reset(self, request: PasswordResetRequest) -> None:
try:
reset_email = cast(Any, self._client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
await asyncio.to_thread(reset_email, email, options=options)
else:
await asyncio.to_thread(reset_email, email)
except AuthError as exc:
logger.warning(
"Password reset request failed",
error_type=type(exc).__name__,
)
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
user_id = str(getattr(user, "id", "")) if user is not None else ""
if session is None or not user_id:
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
self._admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)
except AuthError as exc:
logger.warning(
"Password reset confirm failed", error_type=type(exc).__name__
)
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
) from exc
def _coerce_reset_email(value: object) -> str:
if isinstance(value, str):
return value
if isinstance(value, Mapping):
nested = value.get("email") or value.get("value")
if isinstance(nested, str):
return nested
raise HTTPException(status_code=422, detail="Invalid email")
def _map_auth_response(response: object, failure_message: str) -> SessionResponse: def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
session = getattr(response, "session", None) session = getattr(response, "session", None)
+32
View File
@@ -10,6 +10,8 @@ from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service from v1.auth.dependencies import get_auth_service
from v1.users.dependencies import get_current_user from v1.users.dependencies import get_current_user
from v1.auth.schemas import ( from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest, SessionCreateRequest,
SessionDeleteRequest, SessionDeleteRequest,
SessionRefreshRequest, SessionRefreshRequest,
@@ -123,3 +125,33 @@ async def get_user_by_email(
if current_user.role != "service_role" and current_user.email != email: if current_user.role != "service_role" and current_user.email != email:
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
return await service.get_user_by_email(email) return await service.get_user_by_email(email)
@router.post("/password-reset", status_code=204)
async def request_password_reset(
payload: PasswordResetRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_request",
identifier=payload.email,
limit=5,
window_seconds=60,
)
await service.request_password_reset(payload)
return Response(status_code=204)
@router.post("/password-reset/confirm", status_code=204)
async def confirm_password_reset(
payload: PasswordResetConfirmRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="password_reset_confirm",
identifier=payload.email,
limit=10,
window_seconds=600,
)
await service.confirm_password_reset(payload)
return Response(status_code=204)
+6
View File
@@ -61,5 +61,11 @@ class PasswordResetRequest(BaseModel):
redirect_to: str | None = None redirect_to: str | None = None
class PasswordResetConfirmRequest(BaseModel):
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str = Field(min_length=6)
class PasswordResetResponse(BaseModel): class PasswordResetResponse(BaseModel):
message: str = "Password reset email sent" message: str = "Password reset email sent"
+18
View File
@@ -3,6 +3,8 @@ from __future__ import annotations
from typing import Protocol from typing import Protocol
from v1.auth.schemas import ( from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest, SessionCreateRequest,
SessionRefreshRequest, SessionRefreshRequest,
SessionResponse, SessionResponse,
@@ -40,6 +42,14 @@ class AuthServiceGateway(Protocol):
async def get_user_by_email(self, email: str) -> UserByEmailResponse: async def get_user_by_email(self, email: str) -> UserByEmailResponse:
raise NotImplementedError raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
class AuthService: class AuthService:
_gateway: AuthServiceGateway _gateway: AuthServiceGateway
@@ -71,3 +81,11 @@ class AuthService:
async def get_user_by_email(self, email: str) -> UserByEmailResponse: async def get_user_by_email(self, email: str) -> UserByEmailResponse:
return await self._gateway.get_user_by_email(email) return await self._gateway.get_user_by_email(email)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
await self._gateway.request_password_reset(request)
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
await self._gateway.confirm_password_reset(request)
+18 -2
View File
@@ -11,11 +11,21 @@ from core.auth.models import CurrentUser
from core.config.settings import config from core.config.settings import config
from core.db import get_db from core.db import get_db
from core.logging import get_logger from core.logging import get_logger
from v1.auth.gateway import SupabaseAuthGateway
from v1.users.repository import SQLAlchemyUserRepository from v1.users.repository import SQLAlchemyUserRepository
from v1.users.service import UserService from v1.users.service import AuthLookupAdapter, UserService
logger = get_logger("v1.users.dependencies") logger = get_logger("v1.users.dependencies")
_auth_gateway: SupabaseAuthGateway | None = None
def get_auth_gateway() -> SupabaseAuthGateway:
global _auth_gateway
if _auth_gateway is None:
_auth_gateway = SupabaseAuthGateway()
return _auth_gateway
def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser: def get_current_user(authorization: str | None = Header(default=None)) -> CurrentUser:
if not authorization: if not authorization:
@@ -98,4 +108,10 @@ def get_user_service(
user: Annotated[CurrentUser, Depends(get_current_user)], user: Annotated[CurrentUser, Depends(get_current_user)],
) -> UserService: ) -> UserService:
repository = SQLAlchemyUserRepository(session) repository = SQLAlchemyUserRepository(session)
return UserService(repository=repository, session=session, current_user=user) auth_gateway = AuthLookupAdapter(get_auth_gateway())
return UserService(
repository=repository,
session=session,
current_user=user,
auth_gateway=auth_gateway,
)
+25 -2
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Protocol from typing import TYPE_CHECKING, Protocol
from uuid import UUID from uuid import UUID
from sqlalchemy import select from sqlalchemy import select, or_
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository from core.db.base_repository import BaseRepository
@@ -33,6 +33,10 @@ class UserRepository(Protocol):
"""Update user by user ID. Returns updated user or None if not found.""" """Update user by user ID. Returns updated user or None if not found."""
... ...
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
"""Search users by username (ilike) or email (exact match)."""
...
class SQLAlchemyUserRepository(BaseRepository[Profile]): class SQLAlchemyUserRepository(BaseRepository[Profile]):
"""SQLAlchemy implementation of UserRepository. """SQLAlchemy implementation of UserRepository.
@@ -77,5 +81,24 @@ class SQLAlchemyUserRepository(BaseRepository[Profile]):
try: try:
return await self.update_by_id(user_id, update_data) return await self.update_by_id(user_id, update_data)
except SQLAlchemyError: except SQLAlchemyError:
logger.exception("User update failed", user_id=str(user_id)) logger.exception("User update failed", user=str(user_id))
raise
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
try:
stmt = (
select(Profile)
.where(Profile.deleted_at.is_(None))
.where(
or_(
Profile.username.ilike(f"%{query}%"),
)
)
.order_by(Profile.created_at.asc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
except SQLAlchemyError:
logger.exception("User search failed", query=query)
raise raise
+7 -9
View File
@@ -2,10 +2,10 @@ from __future__ import annotations
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Path from fastapi import APIRouter, Depends
from v1.users.dependencies import get_user_service from v1.users.dependencies import get_user_service
from v1.users.schemas import UserResponse, UserUpdateRequest from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
from v1.users.service import UserService from v1.users.service import UserService
@@ -27,11 +27,9 @@ async def update_me(
return await service.update_me(payload) return await service.update_me(payload)
@router.get("/{username}", response_model=UserResponse) @router.post("/search", response_model=list[UserResponse])
async def get_by_username( async def search_users(
username: Annotated[ payload: UserSearchRequest,
str, Path(min_length=3, max_length=30, pattern="^[a-zA-Z0-9_]+$")
],
service: Annotated[UserService, Depends(get_user_service)], service: Annotated[UserService, Depends(get_user_service)],
) -> UserResponse: ) -> list[UserResponse]:
return await service.get_by_username(username) return await service.search_users(payload)
+11
View File
@@ -19,6 +19,17 @@ class UserResponse(BaseModel):
bio: str | None = None bio: str | None = None
class UserSearchRequest(BaseModel):
query: str = Field(min_length=1, max_length=100)
class UserSearchResult(BaseModel):
id: str
username: str
avatar_url: str | None = None
bio: str | None = None
class UserUpdateRequest(BaseModel): class UserUpdateRequest(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
+80 -2
View File
@@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING import re
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@@ -9,13 +11,37 @@ from core.auth.models import CurrentUser
from core.db.base_service import BaseService from core.db.base_service import BaseService
from core.logging import get_logger from core.logging import get_logger
from v1.users.repository import UserRepository from v1.users.repository import UserRepository
from v1.users.schemas import UserResponse, UserUpdateRequest from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from v1.auth.schemas import UserByEmailResponse
logger = get_logger("v1.users.service") logger = get_logger("v1.users.service")
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
class AuthLookupGateway(Protocol):
async def get_user_id_by_email(self, email: str) -> str | None: ...
class AuthByEmailGateway(Protocol):
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
class AuthLookupAdapter:
def __init__(self, gateway: AuthByEmailGateway) -> None:
self._gateway = gateway
async def get_user_id_by_email(self, email: str) -> str | None:
try:
response = await self._gateway.get_user_by_email(email)
return response.id
except HTTPException:
return None
class UserService(BaseService): class UserService(BaseService):
"""User service handling business logic and transactions. """User service handling business logic and transactions.
@@ -28,16 +54,19 @@ class UserService(BaseService):
_repository: UserRepository _repository: UserRepository
_session: AsyncSession _session: AsyncSession
_auth_gateway: AuthLookupGateway | None
def __init__( def __init__(
self, self,
repository: UserRepository, repository: UserRepository,
session: AsyncSession, session: AsyncSession,
current_user: CurrentUser | None, current_user: CurrentUser | None,
auth_gateway: AuthLookupGateway | None = None,
) -> None: ) -> None:
super().__init__(current_user=current_user) super().__init__(current_user=current_user)
self._repository = repository self._repository = repository
self._session = session self._session = session
self._auth_gateway = auth_gateway
async def get_me(self) -> UserResponse: async def get_me(self) -> UserResponse:
user_id = self.require_user_id() user_id = self.require_user_id()
@@ -101,3 +130,52 @@ class UserService(BaseService):
avatar_url=user.avatar_url, avatar_url=user.avatar_url,
bio=user.bio, bio=user.bio,
) )
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
query = request.query.strip()
if _EMAIL_PATTERN.match(query):
return await self._search_by_email(query)
return await self._search_by_username(query)
async def _search_by_email(self, email: str) -> list[UserResponse]:
if self._auth_gateway is None:
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
if user_id_str is None:
return []
try:
user = await self._repository.get_by_user_id(UUID(user_id_str))
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="User store unavailable")
if user is None:
return []
return [
UserResponse(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
)
]
async def _search_by_username(self, query: str) -> list[UserResponse]:
try:
users = await self._repository.search_users(query, limit=20)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="User store unavailable")
return [
UserResponse(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
)
for user in users
]
@@ -14,6 +14,8 @@ from v1.users.dependencies import get_current_user
from v1.auth.rate_limit import reset_rate_limit_state from v1.auth.rate_limit import reset_rate_limit_state
from v1.auth.schemas import ( from v1.auth.schemas import (
AuthUser, AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest, SessionCreateRequest,
SessionRefreshRequest, SessionRefreshRequest,
SessionResponse, SessionResponse,
@@ -71,6 +73,18 @@ class FakeAuthService(AuthService):
email_confirmed_at=None, email_confirmed_at=None,
) )
async def request_password_reset(self, request: PasswordResetRequest) -> None:
return None
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
if request.token == "000000":
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
)
return None
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]: def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
def _get_service() -> AuthService: def _get_service() -> AuthService:
@@ -665,3 +679,116 @@ def test_get_user_by_email_forbidden_when_querying_other_user() -> None:
assert body["detail"] == "Forbidden" assert body["detail"] == "Forbidden"
finally: finally:
app.dependency_overrides = {} app.dependency_overrides = {}
def test_password_reset_request_returns_204() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset",
json={"email": "user@example.com"},
)
assert response.status_code == 204
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_returns_204() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"email": "user@example.com",
"token": "123456",
"new_password": "newpassword123",
},
)
assert response.status_code == 204
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_invalid_token_returns_401() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"email": "user@example.com",
"token": "000000",
"new_password": "newpassword123",
},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_weak_password_returns_422() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"email": "user@example.com",
"token": "123456",
"new_password": "123",
},
)
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
finally:
app.dependency_overrides = {}
+77 -45
View File
@@ -9,7 +9,7 @@ from fastapi.testclient import TestClient
from app import app from app import app
from core.auth.models import CurrentUser from core.auth.models import CurrentUser
from v1.users.dependencies import get_current_user, get_user_service from v1.users.dependencies import get_current_user, get_user_service
from v1.users.schemas import UserResponse, UserUpdateRequest from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
from v1.users.service import UserService from v1.users.service import UserService
@@ -18,6 +18,10 @@ class FakeUserService:
def __init__(self, user: UserResponse) -> None: def __init__(self, user: UserResponse) -> None:
self._user = user self._user = user
self._search_results: list[UserResponse] = []
def set_search_results(self, results: list[UserResponse]) -> None:
self._search_results = results
async def get_me(self) -> UserResponse: async def get_me(self) -> UserResponse:
if self._user.id is None: if self._user.id is None:
@@ -45,6 +49,11 @@ class FakeUserService:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
return self._user return self._user
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
if request.query:
return self._search_results if self._search_results else [self._user]
return []
def _override_user_service( def _override_user_service(
service: FakeUserService, service: FakeUserService,
@@ -111,50 +120,6 @@ def test_patch_me_updates_user() -> None:
app.dependency_overrides = {} app.dependency_overrides = {}
def test_get_user_by_username() -> None:
user = UserResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.get("/api/v1/users/demo")
assert response.status_code == 200
body = response.json()
assert body["username"] == "demo"
finally:
app.dependency_overrides = {}
def test_user_not_found_returns_problem_details() -> None:
user = UserResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.get("/api/v1/users/unknown")
assert response.status_code == 404
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Not Found"
assert body["status"] == 404
finally:
app.dependency_overrides = {}
def test_patch_me_validation_error_returns_problem_details() -> None: def test_patch_me_validation_error_returns_problem_details() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001") user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse( user = UserResponse(
@@ -178,3 +143,70 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
assert body["status"] == 422 assert body["status"] == 422
finally: finally:
app.dependency_overrides = {} app.dependency_overrides = {}
def test_search_users_returns_list() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
id=str(user_id),
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/users/search",
json={"query": "demo"},
)
assert response.status_code == 200
body = response.json()
assert isinstance(body, list)
finally:
app.dependency_overrides = {}
def test_search_users_empty_query_returns_422() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
id=str(user_id),
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/users/search",
json={"query": ""},
)
assert response.status_code == 422
finally:
app.dependency_overrides = {}
def test_get_user_by_username_returns_404() -> None:
user = UserResponse(
id="00000000-0000-0000-0000-000000000001",
username="demo",
avatar_url=None,
bio=None,
)
app.dependency_overrides[get_user_service] = _override_user_service(
FakeUserService(user)
)
client = TestClient(app)
try:
response = client.get("/api/v1/users/demo")
assert response.status_code == 404
finally:
app.dependency_overrides = {}
@@ -0,0 +1,155 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
class TestSupabaseAuthGateway:
@pytest.fixture
def gateway(self) -> SupabaseAuthGateway:
with patch("v1.auth.gateway.create_client") as mock_create:
mock_client = MagicMock()
mock_admin_client = MagicMock()
mock_create.side_effect = [mock_client, mock_admin_client]
return SupabaseAuthGateway()
@pytest.mark.asyncio
async def test_request_password_reset_calls_email_with_string(
self, gateway: SupabaseAuthGateway
) -> None:
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
await gateway.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_with_redirect(
self, gateway: SupabaseAuthGateway
) -> None:
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
await gateway.request_password_reset(request)
mock_reset_email.assert_called_once_with(
"test@example.com",
options={"redirect_to": "http://localhost:3000/reset-password"},
)
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: SupabaseAuthGateway
) -> None:
from supabase import AuthError
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
result = await gateway.request_password_reset(request)
mock_reset_email.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_request_password_reset_extracts_email_from_mapping(
self, gateway: SupabaseAuthGateway
) -> None:
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest.model_construct(
email={"email": "test@example.com"},
redirect_to=None,
)
await gateway.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_rejects_invalid_email_shape(
self, gateway: SupabaseAuthGateway
) -> None:
request = PasswordResetRequest.model_construct(
email={"unexpected": "value"},
redirect_to=None,
)
with pytest.raises(HTTPException) as exc_info:
await gateway.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid email"
@pytest.mark.asyncio
async def test_confirm_password_reset_updates_password_by_user_id(
self, gateway: SupabaseAuthGateway
) -> None:
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id="user-1"),
)
mock_verify_otp = MagicMock(return_value=verify_response)
gateway._client.auth.verify_otp = mock_verify_otp
mock_update_user_by_id = MagicMock()
gateway._admin_client.auth.admin = SimpleNamespace(
update_user_by_id=mock_update_user_by_id
)
request = PasswordResetConfirmRequest(
email="test@example.com",
token="123456",
new_password="newpassword123",
)
await gateway.confirm_password_reset(request)
mock_verify_otp.assert_called_once_with(
{
"type": "recovery",
"email": "test@example.com",
"token": "123456",
}
)
mock_update_user_by_id.assert_called_once_with(
"user-1",
{"password": "newpassword123"},
)
@pytest.mark.asyncio
async def test_confirm_password_reset_raises_when_user_id_missing(
self, gateway: SupabaseAuthGateway
) -> None:
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id=""),
)
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
request = PasswordResetConfirmRequest(
email="test@example.com",
token="123456",
new_password="newpassword123",
)
with pytest.raises(HTTPException) as exc_info:
await gateway.confirm_password_reset(request)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"
@@ -0,0 +1,309 @@
# Invite Code Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 在现有 OTP 注册链路中引入邀请码能力,支持用户自动生成专属邀请码、注册时可选填邀请码并记录邀请关系与使用次数。
**Architecture:** 采用数据库中心实现:通过 Alembic 新增 `invite_codes` 表、扩展 `profiles` 字段,并在 `auth.users` 的现有 trigger 函数中完成邀请码校验与记账,保证注册与邀请关系写入尽量原子。应用层只负责透传 `invite_code` 到 Supabase `raw_user_meta_data`
**Tech Stack:** FastAPI, SQLAlchemy, Alembic, Supabase Auth, PostgreSQL PL/pgSQL, Pytest
---
### Task 1: 更新注册请求 SchemaTDD
**Files:**
- Modify: `backend/src/v1/auth/schemas.py`
- Modify: `backend/tests/integration/test_auth_routes.py`
**Step 1: Write the failing test**
`test_signup_start_returns_pending_response` 基础上新增断言路径:请求体带 `invite_code` 时返回仍为 202,且未触发 422。
**Step 2: Run test to verify it fails**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k signup_start_returns_pending_response -v`
Expected: FAIL`invite_code` 为额外字段或校验不通过)
**Step 3: Write minimal implementation**
`VerificationCreateRequest` 增加可选字段:
```python
invite_code: str | None = Field(default=None, min_length=8, max_length=8)
```
**Step 4: Run test to verify it passes**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k signup_start_returns_pending_response -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/auth/schemas.py backend/tests/integration/test_auth_routes.py
git commit -m "feat: accept invite code in signup request"
```
### Task 2: 透传 invite_code 到 Supabase metadataTDD
**Files:**
- Modify: `backend/src/v1/auth/gateway.py`
- Modify: `backend/tests/unit/v1/auth/test_auth_service.py`
**Step 1: Write the failing test**
`test_supabase_signup_passes_username_in_metadata` 增加 `invite_code` 并断言:
```python
assert captured_payload["data"] == {
"username": "demo",
"invite_code": "A1B2C3D4",
}
```
**Step 2: Run test to verify it fails**
Run: `cd backend && uv run pytest tests/unit/v1/auth/test_auth_service.py -k metadata -v`
Expected: FAILmetadata 未包含 `invite_code`
**Step 3: Write minimal implementation**
`create_verification` 中构建 metadata
```python
metadata = {"username": request.username}
if request.invite_code:
metadata["invite_code"] = request.invite_code
payload = {
"email": request.email,
"password": request.password,
"data": metadata,
}
```
**Step 4: Run test to verify it passes**
Run: `cd backend && uv run pytest tests/unit/v1/auth/test_auth_service.py -k metadata -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/auth/gateway.py backend/tests/unit/v1/auth/test_auth_service.py
git commit -m "feat: pass invite code through signup metadata"
```
### Task 3: 新增 invite_codes 表与 profiles.referred_by(迁移先行)
**Files:**
- Create: `backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py`
- Modify: `backend/src/models/profile.py`
- Create: `backend/src/models/invite_code.py`
- Modify: `backend/src/models/__init__.py`
**Step 1: Write the failing test**
`backend/tests/unit/database/test_profile_models.py` 新增 `referred_by` 读写测试;新增 `backend/tests/unit/database/test_invite_code_models.py` 验证 `InviteCode` 基本创建与约束字段。
**Step 2: Run test to verify it fails**
Run: `cd backend && uv run pytest tests/unit/database/test_profile_models.py tests/unit/database/test_invite_code_models.py -v`
Expected: FAIL(字段/模型不存在)
**Step 3: Write minimal implementation**
- Alembic 创建 `invite_codes`
- `code` 唯一索引
- `owner_id` 外键到 `profiles.id`(可空)
- `status``used_count``max_uses` check 约束
- `max_uses` 默认 `NULL`(无限制)
- `expires_at` 默认 `NULL`(无限制)
- `reward_config` JSONB 默认 `{}`
- 启用 RLS(按项目默认 deny-all
- **注意**:本期不开放 invite_codes 表直接读取,用户邀请码通过 profile 聚合接口返回(后续实现)
- Alembic 给 `profiles` 增加 `referred_by` + 索引 + 外键
- ORM 同步 `Profile.referred_by``InviteCode` 模型
**Step 4: Run test to verify it passes**
Run: `cd backend && uv run pytest tests/unit/database/test_profile_models.py tests/unit/database/test_invite_code_models.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py backend/src/models/profile.py backend/src/models/invite_code.py backend/src/models/__init__.py backend/tests/unit/database/test_profile_models.py backend/tests/unit/database/test_invite_code_models.py
git commit -m "feat: add invite code schema and profile referral fields"
```
### Task 4: 扩展注册 trigger 生成邀请码并消费邀请(TDD)
**Files:**
- Modify: `backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py`
- Modify: `backend/tests/integration/test_auth_routes.py`
**Step 1: Write the failing test**
新增集成测试(建议通过测试替身/fixture 验证行为):
- 注册不带邀请码时,profile 创建后存在 owner 邀请码
- 注册带有效邀请码时,`referred_by` 生效且 `used_count + 1`
**Step 2: Run test to verify it fails**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k invite -v`
Expected: FAIL(触发器逻辑尚未实现)
**Step 3: Write minimal implementation**
在迁移 SQL 中:
- 新增 helper function:生成 8 位随机码(排除易混淆字符 0/O/1/I/L,冲突重试)
- 重建 `public.create_profile_for_new_user()`
1. 插入 `profiles`
2. 创建该用户专属 `invite_codes``owner_id = NEW.id`
3. 读取 `NEW.raw_user_meta_data ->> 'invite_code'`
4. 校验邀请码状态/过期/次数
5. 若有效:更新 `profiles.referred_by`,并 `used_count = used_count + 1`
**Step 4: Run test to verify it passes**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k invite -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py backend/tests/integration/test_auth_routes.py
git commit -m "feat: extend signup trigger for invite code generation and usage"
```
### Task 5: 覆盖邀请码边界场景(TDD)
**Files:**
- Modify: `backend/tests/integration/test_auth_routes.py`
- Optional Modify: `backend/tests/e2e/test_auth_flow.py`
**Step 1: Write the failing test**
新增场景测试:
- 邀请码不存在
- 邀请码 disabled
- 邀请码 expires_at 已过期
- 邀请码达到 `max_uses`
断言:注册仍成功(202/200 链路正常),仅邀请关系不建立。
**Step 2: Run test to verify it fails**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k "invite and (expired or disabled or max_uses or invalid)" -v`
Expected: FAIL
**Step 3: Write minimal implementation**
修正 trigger 判断顺序和条件,确保“邀请无效不影响注册”原则。
**Step 4: Run test to verify it passes**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k invite -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/tests/integration/test_auth_routes.py backend/alembic/versions/20260227_0006_invite_codes_and_profile_referral.py
git commit -m "test: cover invite code edge cases in signup flow"
```
### Task 6: 文档同步与运行手册更新
**Files:**
- Modify: `docs/runtime/runtime-route.md`
- Modify: `docs/runtime/runtime-runbook.md`
**Step 1: Write the failing test**
无自动化测试;改为文档一致性检查清单(手工):
- 注册接口 request 字段包含 `invite_code`
- 说明邀请码消费时机与“无效码不阻断注册”
**Step 2: Run check to verify missing docs**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k signup_start -v`
Expected: PASS(作为行为基线),文档尚未同步
**Step 3: Write minimal implementation**
- 更新 `POST /auth/verifications` 请求字段
- 新增邀请码行为说明
- 在 runbook 变更日志添加本次改动记录
**Step 4: Run check after docs update**
Run: `cd backend && uv run pytest tests/integration/test_auth_routes.py -k signup_start -v`
Expected: PASS(行为与文档一致)
**Step 5: Commit**
```bash
git add docs/runtime/runtime-route.md docs/runtime/runtime-runbook.md
git commit -m "docs: document invite code behavior in signup flow"
```
### Task 7: 全量验证与风险审查(L2)
**Files:**
- Verify only
**Step 1: Run lint/type checks**
Run:
- `cd backend && uv run ruff check src tests`
- `cd backend && uv run basedpyright src`
Expected: 全部通过
**Step 2: Run test suites**
Run:
- `cd backend && uv run pytest tests/unit -v`
- `cd backend && uv run pytest tests/integration -v`
- `cd backend && uv run pytest tests/e2e/test_auth_flow.py -v`
Expected: 通过
**Step 3: Run mandatory review gates for L2**
- `refactor-cleaner` agent:确认无死代码/重复代码
- `code-reviewer` agent:检查 DB trigger、安全边界、可维护性
Expected: CRITICAL/HIGH 为 0
**Step 4: Security-specific sanity checks**
检查项:
- 未硬编码密钥
- SQL 逻辑无注入风险(trigger 中仅参数/列操作)
- 邀请码校验失败不泄露内部细节
**Step 5: Commit verification evidence (if needed in docs/PR notes)**
```bash
git add <updated verification notes if any>
git commit -m "chore: record invite code verification results"
```
---
## 交付验收标准
1. 新用户注册后必有 1 条专属邀请码。
2. 注册时传入有效邀请码会建立 `profiles.referred_by` 并增加 `used_count`
3. 无效邀请码不会阻断注册成功。
4. 支持运营码(`owner_id IS NULL`)与后续奖励扩展(`reward_config`)。
5. 文档已同步,测试与检查通过。
## 备注
- 本需求触发 L2(数据库迁移 + trigger + 多文件大改),必须走双审查 gate。
- 不在本期实现运营后台批量发码 API;仅完成数据层与注册链路支撑。
+65 -11
View File
@@ -171,6 +171,48 @@
--- ---
### POST /auth/password-reset
发送密码重置验证码。
**Request:**
```json
{
"email": "string (email)",
"redirect_to": "string? (optional)"
}
```
**Response:** 204 No Content
**Errors:**
- 422: 请求参数无效
- 429: 请求过于频繁
---
### POST /auth/password-reset/confirm
验证 recovery 验证码并完成改密。
**Request:**
```json
{
"email": "string (email)",
"token": "string (6 digits)",
"new_password": "string (min 6 chars)"
}
```
**Response:** 204 No Content
**Errors:**
- 401: 验证码无效或已过期
- 422: 请求参数无效
- 429: 请求过于频繁
---
### GET /auth/users ### GET /auth/users
按邮箱查询用户(需要认证)。 按邮箱查询用户(需要认证)。
@@ -245,26 +287,38 @@
--- ---
### GET /users/{username} ### POST /users/search
按用户名查询用户(需要认证)。 搜索用户(需要认证)。
**Path Parameters:** 支持两种查询模式:
- `username`: string (3-30 chars, alphanumeric and underscore) - **用户名查询**:模糊匹配,返回最多 20 个结果
- **邮箱查询**:精确匹配,返回 0 或 1 个结果
查询类型自动识别:包含 `@` 符号视为邮箱查询。
**Request:**
```json
{
"query": "string (1-100 chars)"
}
```
**Response:** 200 OK **Response:** 200 OK
```json ```json
{ [
"id": "string", {
"username": "string", "id": "string",
"avatar_url": "string?", "username": "string",
"bio": "string?" "avatar_url": "string?",
} "bio": "string?"
}
]
``` ```
**Errors:** **Errors:**
- 401: 未认证 - 401: 未认证
- 404: 用户不存在 - 503: Auth 服务不可用(仅邮箱查询)
- 422: 请求参数无效 - 422: 请求参数无效
--- ---
+1
View File
@@ -244,3 +244,4 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml up -d --force-
| 2026-02-25 | 简化启动方式:dev-app-up -> app-up,分离 bootstrap 与服务启动 | | 2026-02-25 | 简化启动方式:dev-app-up -> app-up,分离 bootstrap 与服务启动 |
| 2026-02-25 | 重构为运维分层手册:Bootstrap Gate、分层验证、故障与回滚流程 | | 2026-02-25 | 重构为运维分层手册:Bootstrap Gate、分层验证、故障与回滚流程 |
| 2026-02-25 | 新增配置漂移故障条目:修复 Auth 邮件模板失效与 signup 超时场景 | | 2026-02-25 | 新增配置漂移故障条目:修复 Auth 邮件模板失效与 signup 超时场景 |
| 2026-02-27 | 用户搜索支持邮箱精确匹配:query 含 @ 符号时走 auth.users → profiles 两步查询 |
-17
View File
@@ -1,17 +0,0 @@
#!/bin/bash
set -euo pipefail
SESSION_NAME="${SESSION_NAME:-social-dev}"
echo "=== App Down ==="
if ! tmux has-session -t "$SESSION_NAME" 2>/dev/null; then
echo "No tmux session '$SESSION_NAME' found."
exit 0
fi
echo "Stopping tmux session '$SESSION_NAME'..."
tmux kill-session -t "$SESSION_NAME"
echo "Session stopped and cleaned up."
-66
View File
@@ -1,66 +0,0 @@
#!/bin/bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
SESSION_NAME="${SESSION_NAME:-social-dev}"
COMPOSE_FILE="$ROOT_DIR/infra/docker/docker-compose.yml"
ENV_FILE="$ROOT_DIR/.env"
echo "=== App Up ==="
echo "This script starts web + worker processes in tmux."
echo "NOTE: Bootstrap (migrate + init-data) must be run separately."
echo ""
if ! command -v tmux >/dev/null 2>&1; then
echo "Error: tmux is required." >&2
exit 1
fi
if [ ! -f "$ENV_FILE" ]; then
echo "Error: env file not found at $ENV_FILE" >&2
exit 1
fi
if [ ! -f "$COMPOSE_FILE" ]; then
echo "Error: compose file not found at $COMPOSE_FILE" >&2
exit 1
fi
set -a
# shellcheck disable=SC1090
. "$ENV_FILE"
set +a
if tmux has-session -t "$SESSION_NAME" 2>/dev/null; then
echo "Error: tmux session '$SESSION_NAME' already exists." >&2
echo "Hint: tmux kill-session -t $SESSION_NAME" >&2
exit 1
fi
echo "Starting web + worker processes in tmux session '$SESSION_NAME'..."
WEB_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=web uv run gunicorn app:app --bind \
${SOCIAL_WEB__HOST:-0.0.0.0}:${SOCIAL_WEB__PORT:-8000} --workers \
${SOCIAL_WEB__GUNICORN__WORKERS:-2} --worker-class \
${SOCIAL_WEB__GUNICORN__WORKER_CLASS:-uvicorn.workers.UvicornWorker} --timeout \
${SOCIAL_WEB__GUNICORN__TIMEOUT:-60}"
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run celery -A core.celery.app worker --loglevel=info --queues=critical --concurrency=${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run celery -A core.celery.app worker --loglevel=info --queues=default --concurrency=${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run celery -A core.celery.app worker --loglevel=info --queues=bulk --concurrency=${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
tmux new-session -d -s "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\""
tmux new-window -t "$SESSION_NAME" -n worker-critical "bash -lc \"$WORKER_CRITICAL_CMD; echo '[worker-critical] exited'; exec bash\""
tmux new-window -t "$SESSION_NAME" -n worker-default "bash -lc \"$WORKER_DEFAULT_CMD; echo '[worker-default] exited'; exec bash\""
tmux new-window -t "$SESSION_NAME" -n worker-bulk "bash -lc \"$WORKER_BULK_CMD; echo '[worker-bulk] exited'; exec bash\""
echo ""
echo "=== App Started ==="
echo "Log files will be created in logs/ directory:"
echo " - web.log, web.error.log"
echo " - worker-critical.log, worker-critical.error.log"
echo " - worker-default.log, worker-default.error.log"
echo " - worker-bulk.log, worker-bulk.error.log"
echo ""
echo "tmux attach -t $SESSION_NAME"
echo "tmux list-windows -t $SESSION_NAME"
+107
View File
@@ -0,0 +1,107 @@
#!/bin/bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
SESSION_NAME="${SESSION_NAME:-social-dev}"
COMPOSE_FILE="$ROOT_DIR/infra/docker/docker-compose.yml"
ENV_FILE="$ROOT_DIR/.env"
usage() {
echo "Usage: $0 {start|stop}"
echo ""
echo "Commands:"
echo " start Start web + worker processes in tmux"
echo " stop Stop and clean up tmux session"
exit 1
}
start() {
echo "=== App Up ==="
echo "This script starts web + worker processes in tmux."
echo "NOTE: Bootstrap (migrate + init-data) must be run separately."
echo ""
if ! command -v tmux >/dev/null 2>&1; then
echo "Error: tmux is required." >&2
exit 1
fi
if [ ! -f "$ENV_FILE" ]; then
echo "Error: env file not found at $ENV_FILE" >&2
exit 1
fi
if [ ! -f "$COMPOSE_FILE" ]; then
echo "Error: compose file not found at $COMPOSE_FILE" >&2
exit 1
fi
set -a
# shellcheck disable=SC1090
. "$ENV_FILE"
set +a
if tmux has-session -t "$SESSION_NAME" 2>/dev/null; then
echo "Error: tmux session '$SESSION_NAME' already exists." >&2
echo "Hint: tmux kill-session -t $SESSION_NAME" >&2
exit 1
fi
echo "Starting web + worker processes in tmux session '$SESSION_NAME'..."
WEB_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=web uv run gunicorn app:app --bind \
${SOCIAL_WEB__HOST:-0.0.0.0}:${SOCIAL_WEB__PORT:-8000} --workers \
${SOCIAL_WEB__GUNICORN__WORKERS:-2} --worker-class \
${SOCIAL_WEB__GUNICORN__WORKER_CLASS:-uvicorn.workers.UvicornWorker} --timeout \
${SOCIAL_WEB__GUNICORN__TIMEOUT:-60} \
--log-level ${SOCIAL_RUNTIME__LOG_LEVEL:-info}"
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run celery -A core.celery.app worker --loglevel=info --queues=critical --concurrency=${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run celery -A core.celery.app worker --loglevel=info --queues=default --concurrency=${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run celery -A core.celery.app worker --loglevel=info --queues=bulk --concurrency=${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
tmux new-session -d -s "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\""
tmux new-window -t "$SESSION_NAME" -n worker-critical "bash -lc \"$WORKER_CRITICAL_CMD; echo '[worker-critical] exited'; exec bash\""
tmux new-window -t "$SESSION_NAME" -n worker-default "bash -lc \"$WORKER_DEFAULT_CMD; echo '[worker-default] exited'; exec bash\""
tmux new-window -t "$SESSION_NAME" -n worker-bulk "bash -lc \"$WORKER_BULK_CMD; echo '[worker-bulk] exited'; exec bash\""
echo ""
echo "=== App Started ==="
echo "Log files will be created in logs/ directory:"
echo " - web.log, web.error.log"
echo " - worker-critical.log, worker-critical.error.log"
echo " - worker-default.log, worker-default.error.log"
echo " - worker-bulk.log, worker-bulk.error.log"
echo ""
echo "tmux attach -t $SESSION_NAME"
echo "tmux list-windows -t $SESSION_NAME"
}
stop() {
echo "=== App Down ==="
if tmux has-session -t "$SESSION_NAME" 2>/dev/null; then
echo "Stopping tmux session '$SESSION_NAME'..."
tmux kill-session -t "$SESSION_NAME"
else
echo "No tmux session '$SESSION_NAME' found."
fi
echo "Checking for orphaned processes..."
if pgrep -f "gunicorn.*app:app" > /dev/null 2>&1; then
echo "Killing orphaned gunicorn processes..."
pkill -f "gunicorn.*app:app"
fi
if pgrep -f "celery.*worker" > /dev/null 2>&1; then
echo "Killing orphaned celery processes..."
pkill -f "celery.*worker"
fi
echo "Session stopped and cleaned up."
}
case "${1:-}" in
start) start ;;
stop) stop ;;
*) usage ;;
esac