diff --git a/backend/src/core/auth/models.py b/backend/src/core/auth/models.py index f7a31d5..fb12451 100644 --- a/backend/src/core/auth/models.py +++ b/backend/src/core/auth/models.py @@ -7,5 +7,5 @@ from uuid import UUID @dataclass(frozen=True) class CurrentUser: id: UUID - email: str | None = None + phone: str | None = None role: str | None = None diff --git a/backend/src/core/config/settings.py b/backend/src/core/config/settings.py index 722fe48..5dcd117 100644 --- a/backend/src/core/config/settings.py +++ b/backend/src/core/config/settings.py @@ -61,6 +61,7 @@ class RuntimeSettings(BaseModel): ] ) sql_log_queries: bool = False + trusted_proxy_ips: list[str] = Field(default_factory=list) @field_validator("log_dir", mode="before") @classmethod @@ -162,6 +163,12 @@ class AgentRuntimeSettings(BaseModel): user_context_cache_max_turns: int = Field(default=6, ge=1, le=100) +class AutomationSchedulerSettings(BaseModel): + enabled: bool = True + interval_seconds: int = Field(default=60, ge=5, le=3600) + batch_limit: int = Field(default=100, ge=1, le=1000) + + class LlmSettings(BaseModel): provider_keys: dict[str, str] = Field(default_factory=dict) @@ -225,7 +232,7 @@ class AppVersionSettings(BaseModel): class TestSettings(BaseModel): - email: str = "" + phone: str = "" password: str = "" @@ -250,6 +257,7 @@ class Settings(BaseSettings): llm: LlmSettings = LlmSettings() litellm: LiteLLMSettings = LiteLLMSettings() agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings() + automation_scheduler: AutomationSchedulerSettings = AutomationSchedulerSettings() taskiq: TaskiqSettings = TaskiqSettings() database: DatabaseSettings = DatabaseSettings() app_version: AppVersionSettings = AppVersionSettings() diff --git a/backend/src/core/runtime/cli.py b/backend/src/core/runtime/cli.py index d12e039..e1b266b 100644 --- a/backend/src/core/runtime/cli.py +++ b/backend/src/core/runtime/cli.py @@ -5,6 +5,8 @@ import subprocess import sys from pathlib import Path +from core.agentscope.runtime.tasks import run_automation_scheduler_scan +from core.config.settings import config from core.config.initial.init_data import initialize_data from core.logging import get_logger @@ -101,12 +103,34 @@ async def bootstrap() -> bool: return True +async def run_automation_scheduler_forever() -> bool: + if not config.automation_scheduler.enabled: + logger.info("Automation scheduler disabled by config") + return True + + interval_seconds = int(config.automation_scheduler.interval_seconds) + batch_limit = int(config.automation_scheduler.batch_limit) + logger.info( + "Starting automation scheduler loop", + interval_seconds=interval_seconds, + batch_limit=batch_limit, + ) + while True: + try: + await run_automation_scheduler_scan(limit=batch_limit) + except Exception as exc: + logger.exception("Automation scheduler scan failed", error=str(exc)) + await asyncio.sleep(interval_seconds) + + def main() -> int: """CLI entry point.""" if len(sys.argv) < 2: logger.error("No command provided") logger.info("Usage: python -m core.runtime.cli ") - logger.info("Available commands: migrate, init-data, bootstrap") + logger.info( + "Available commands: migrate, init-data, bootstrap, automation-scheduler" + ) return 1 command = sys.argv[1] @@ -117,9 +141,13 @@ def main() -> int: success = asyncio.run(run_init_data()) elif command == "bootstrap": success = asyncio.run(bootstrap()) + elif command == "automation-scheduler": + success = asyncio.run(run_automation_scheduler_forever()) else: logger.error("Unknown command", command=command) - logger.info("Available commands: migrate, init-data, bootstrap") + logger.info( + "Available commands: migrate, init-data, bootstrap, automation-scheduler" + ) return 1 return 0 if success else 1 diff --git a/backend/src/schemas/messages/chat_message.py b/backend/src/schemas/messages/chat_message.py index b2ee290..db555ce 100644 --- a/backend/src/schemas/messages/chat_message.py +++ b/backend/src/schemas/messages/chat_message.py @@ -6,7 +6,7 @@ from typing import Any, ClassVar from uuid import UUID from pydantic import BaseModel, ConfigDict, Field -from schemas.agent.runtime_models import AgentOutput +from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput from ..agent import AgentType, ToolAgentOutput @@ -24,6 +24,7 @@ class AgentChatMessageMetadata(BaseModel): run_id: str agent_type: AgentType | None = None user_message_attachments: list[UserMessageAttachment] | None = None + router_agent_output: RouterAgentOutput | None = None tool_agent_output: ToolAgentOutput | None = None agent_output: AgentOutput | None = None diff --git a/backend/src/schemas/user/context.py b/backend/src/schemas/user/context.py index 716a433..9000c16 100644 --- a/backend/src/schemas/user/context.py +++ b/backend/src/schemas/user/context.py @@ -66,7 +66,7 @@ class UserContext(BaseModel): id: str username: str - email: str | None = None + phone: str | None = None avatar_url: str | None = None bio: str | None = None settings: ProfileSettingsUnion | None = None diff --git a/backend/src/v1/agent/repository.py b/backend/src/v1/agent/repository.py index c1380b1..06a50da 100644 --- a/backend/src/v1/agent/repository.py +++ b/backend/src/v1/agent/repository.py @@ -6,7 +6,7 @@ from typing import Protocol from uuid import UUID, uuid4 from fastapi import HTTPException -from sqlalchemy import select +from sqlalchemy import Select, select from sqlalchemy.ext.asyncio import AsyncSession from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole @@ -95,6 +95,7 @@ class AgentRepository: session_id: str, content: str, metadata: AgentChatMessageMetadata | None, + visibility_mask: int, ) -> None: from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage @@ -124,6 +125,7 @@ class AgentRepository: seq=next_seq, role=AgentChatMessageRole.USER, content=content, + visibility_mask=max(int(visibility_mask), 0), metadata_json=metadata.model_dump(by_alias=True) if metadata else None, ) self._session.add(message) @@ -132,7 +134,11 @@ class AgentRepository: await self._session.flush() async def get_history_day( - self, *, session_id: str, before: date | None + self, + *, + session_id: str, + before: date | None, + visibility_mask: int | None = None, ) -> dict[str, object] | None: try: session_uuid = UUID(session_id) @@ -152,6 +158,10 @@ class AgentRepository: .order_by(AgentChatMessage.created_at.desc()) .limit(1) ) + target_created_at_stmt = self._apply_visibility_filter( + stmt=target_created_at_stmt, + visibility_mask=visibility_mask, + ) if before_start is not None: target_created_at_stmt = target_created_at_stmt.where( AgentChatMessage.created_at < before_start @@ -175,6 +185,10 @@ class AgentRepository: .where(AgentChatMessage.created_at < end) .order_by(AgentChatMessage.seq.asc()) ) + message_stmt = self._apply_visibility_filter( + stmt=message_stmt, + visibility_mask=visibility_mask, + ) messages = (await self._session.execute(message_stmt)).scalars().all() has_more_stmt = ( select(AgentChatMessage.id) @@ -183,6 +197,10 @@ class AgentRepository: .where(AgentChatMessage.created_at < start) .limit(1) ) + has_more_stmt = self._apply_visibility_filter( + stmt=has_more_stmt, + visibility_mask=visibility_mask, + ) has_more = ( await self._session.execute(has_more_stmt) ).scalar_one_or_none() is not None @@ -196,7 +214,11 @@ class AgentRepository: } async def get_recent_messages_by_user_window( - self, *, session_id: str, user_message_limit: int + self, + *, + session_id: str, + user_message_limit: int, + visibility_mask: int | None = None, ) -> list[dict[str, object]]: try: session_uuid = UUID(session_id) @@ -210,6 +232,10 @@ class AgentRepository: .where(AgentChatMessage.deleted_at.is_(None)) .order_by(AgentChatMessage.seq.desc()) ) + message_stmt = self._apply_visibility_filter( + stmt=message_stmt, + visibility_mask=visibility_mask, + ) messages_desc = (await self._session.execute(message_stmt)).scalars().all() if not messages_desc: return [] @@ -294,6 +320,21 @@ class AgentRepository: ) return payload_model.model_dump(mode="json", exclude_none=True) + def _apply_visibility_filter( + self, + *, + stmt: Select, + visibility_mask: int | None, + ) -> Select: + if visibility_mask is None: + return stmt + required_mask = max(int(visibility_mask), 0) + if required_mask == 0: + return stmt + return stmt.where( + (AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0 + ) + def _has_title(title: object) -> bool: return isinstance(title, str) and bool(title.strip()) diff --git a/backend/src/v1/agent/router.py b/backend/src/v1/agent/router.py index 708caf0..796d4a2 100644 --- a/backend/src/v1/agent/router.py +++ b/backend/src/v1/agent/router.py @@ -47,6 +47,7 @@ logger = get_logger("v1.agent.router") _LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$") _MAX_SSE_CONNECTIONS_PER_USER = 3 _SSE_SLOT_TTL_SECONDS = 15 * 60 +_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"} _MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024 _TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024 _MULTIPART_OVERHEAD_BYTES = 64 * 1024 @@ -72,8 +73,14 @@ async def _acquire_sse_slot(*, user_id: str) -> bool: count = await redis.incr(key) if count == 1: await redis.expire(key, _SSE_SLOT_TTL_SECONDS) + else: + ttl = await redis.ttl(key) + if int(ttl) < 0: + await redis.expire(key, _SSE_SLOT_TTL_SECONDS) if int(count) > _MAX_SSE_CONNECTIONS_PER_USER: - await redis.decr(key) + after_decr = await redis.decr(key) + if int(after_decr) <= 0: + await redis.delete(key) return False return True except Exception as exc: # noqa: BLE001 @@ -82,7 +89,7 @@ async def _acquire_sse_slot(*, user_id: str) -> bool: user_id=user_id, reason=str(exc), ) - return False + return True async def _release_sse_slot(*, user_id: str) -> None: @@ -92,10 +99,21 @@ async def _release_sse_slot(*, user_id: str) -> None: count = await redis.decr(key) if int(count) <= 0: await redis.delete(key) + return None + ttl = await redis.ttl(key) + if int(ttl) < 0: + await redis.expire(key, _SSE_SLOT_TTL_SECONDS) except Exception: # noqa: BLE001 return None +def _is_terminal_run_event(event: dict[str, object]) -> bool: + raw_event_type = event.get("type") + return ( + isinstance(raw_event_type, str) and raw_event_type in _TERMINAL_RUN_EVENT_TYPES + ) + + @router.post( "/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED ) @@ -145,8 +163,13 @@ async def stream_events( async def _event_iter() -> AsyncIterator[str]: cursor = last_event_id idle_polls = 0 + terminal_event_reached = False try: - while not await request.is_disconnected() and idle_polls < idle_limit: + while ( + not terminal_event_reached + and not await request.is_disconnected() + and idle_polls < idle_limit + ): try: rows = await service.stream_events( thread_id=thread_id, @@ -181,6 +204,9 @@ async def stream_events( continue cursor = row_id yield to_sse_event(row_id, event) + if _is_terminal_run_event(event): + terminal_event_reached = True + break finally: await _release_sse_slot(user_id=str(current_user.id)) diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index 2bba279..01a471c 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -17,6 +17,9 @@ from core.auth.models import CurrentUser from core.agentscope.schemas.agui_input import extract_latest_user_payload from core.config.settings import config from core.logging import get_logger +from schemas.agent.forwarded_props import parse_forwarded_props_agent_type +from schemas.agent.system_agent import SystemAgentLLMConfig +from schemas.agent.visibility import SystemVisibilityBit, bit_mask from schemas.messages.chat_message import ( AgentChatMessageMetadata, UserMessageAttachment, @@ -51,7 +54,11 @@ class AgentRepositoryLike(Protocol): async def rollback(self) -> None: ... async def get_history_day( - self, *, session_id: str, before: date | None + self, + *, + session_id: str, + before: date | None, + visibility_mask: int | None = None, ) -> dict[str, object] | None: ... async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ... @@ -62,8 +69,13 @@ class AgentRepositoryLike(Protocol): session_id: str, content: str, metadata: AgentChatMessageMetadata | None, + visibility_mask: int, ) -> None: ... + async def get_system_agent_config( + self, *, agent_type: str + ) -> dict[str, object] | None: ... + class QueueClientLike(Protocol): async def enqueue( @@ -138,6 +150,17 @@ class AgentService: created = False thread_id = run_input.thread_id run_id = run_input.run_id + forwarded_props = getattr(run_input, "forwarded_props", None) + try: + agent_type = parse_forwarded_props_agent_type(forwarded_props) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + if agent_type == "memory": + raise HTTPException( + status_code=422, + detail="memory mode is automation-only", + ) + try: owner = await self._repository.get_session_owner(session_id=thread_id) except HTTPException as exc: @@ -161,25 +184,21 @@ class AgentService: run_input=run_input, current_user=current_user, ) + visibility_mask = await self._resolve_user_message_visibility_mask( + agent_type=agent_type + ) await self._repository.persist_user_message( session_id=thread_id, content=user_message_text, metadata=user_message_metadata, + visibility_mask=visibility_mask, ) await self._repository.commit() - forwarded_props = getattr(run_input, "forwarded_props", None) - system_agent_mode = "worker" - if isinstance(forwarded_props, dict): - raw_mode = forwarded_props.get("system_agent_mode") - if isinstance(raw_mode, str) and raw_mode.strip(): - system_agent_mode = raw_mode.strip().lower() - task_id = await self._queue.enqueue( command={ "command": "run", "owner_id": str(current_user.id), - "system_agent_mode": system_agent_mode, "run_input": run_input.model_dump( mode="json", by_alias=True, exclude_none=True ), @@ -193,6 +212,61 @@ class AgentService: created=created, ) + async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int: + normalized_agent_type = agent_type.strip().lower() + history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) + + if normalized_agent_type == "memory": + return bit_mask(bit=18) + + agent_config = await self._repository.get_system_agent_config( + agent_type=normalized_agent_type + ) + if agent_config is None: + raise HTTPException( + status_code=422, detail="invalid forwarded_props.agent_type" + ) + llm_config = SystemAgentLLMConfig.model_validate( + (agent_config.get("config") if isinstance(agent_config, dict) else {}) or {} + ) + agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit) + + if normalized_agent_type == "worker": + router_config = await self._repository.get_system_agent_config( + agent_type="router" + ) + worker_config = await self._repository.get_system_agent_config( + agent_type="worker" + ) + if router_config is None or worker_config is None: + raise HTTPException( + status_code=500, + detail="system agent visibility config missing", + ) + router_mask = bit_mask( + bit=SystemAgentLLMConfig.model_validate( + ( + router_config.get("config") + if isinstance(router_config, dict) + else {} + ) + or {} + ).visibility_consumer_bit + ) + worker_mask = bit_mask( + bit=SystemAgentLLMConfig.model_validate( + ( + worker_config.get("config") + if isinstance(worker_config, dict) + else {} + ) + or {} + ).visibility_consumer_bit + ) + return history_bit_mask | router_mask | worker_mask + + return history_bit_mask | agent_mask + async def _prepare_user_message( self, *, @@ -408,6 +482,7 @@ class AgentService: day_payload = await self._repository.get_history_day( session_id=thread_id, before=before, + visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)), ) messages: list[HistoryMessage] = [] diff --git a/backend/src/v1/auth/gateway.py b/backend/src/v1/auth/gateway.py index 702bdff..56c2a09 100644 --- a/backend/src/v1/auth/gateway.py +++ b/backend/src/v1/auth/gateway.py @@ -2,28 +2,22 @@ from __future__ import annotations import asyncio import time -from collections.abc import Mapping from typing import Any, cast -from urllib.parse import urlparse + +from pydantic import ValidationError from fastapi import HTTPException from supabase import AuthError -from core.config.settings import config from core.logging import get_logger from services.base.supabase import supabase_service from v1.auth.schemas import ( AuthUser, - PasswordResetConfirmRequest, - PasswordResetRequest, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, SessionRefreshRequest, SessionResponse, - UserByEmailResponse, - VerificationCreateRequest, - VerificationCreateResponse, - VerificationResendRequest, - VerificationVerifyRequest, + UserByPhoneResponse, ) from v1.auth.service import AuthServiceGateway @@ -36,7 +30,7 @@ class SupabaseAuthGateway(AuthServiceGateway): def __init__(self) -> None: self._user_lookup_cache_ttl_seconds: int = 60 self._user_lookup_cache_expires_at: float = 0.0 - self._users_by_email: dict[str, Any] = {} + self._users_by_phone: dict[str, Any] = {} def _get_client(self) -> Any: return supabase_service.get_client() @@ -44,47 +38,31 @@ class SupabaseAuthGateway(AuthServiceGateway): def _get_admin_client(self) -> Any: return supabase_service.get_admin_client() - async def create_verification( - self, request: VerificationCreateRequest - ) -> VerificationCreateResponse: + async def send_otp(self, request: OtpSendRequest) -> None: client = self._get_client() - metadata: dict[str, Any] = {"username": request.username} - if request.invite_code: - metadata["invite_code"] = request.invite_code payload: dict[str, Any] = { - "email": request.email, - "password": request.password, - "data": metadata, + "phone": request.phone, + "options": {"should_create_user": True}, } - if request.redirect_to: - payload["options"] = { - "email_redirect_to": _validate_redirect_url(request.redirect_to) - } try: - sign_up = cast(Any, client.auth.sign_up) - await asyncio.to_thread(sign_up, payload) - return VerificationCreateResponse(email=request.email) + sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp) + await asyncio.to_thread(sign_in_with_otp, payload) except AuthError as exc: - logger.warning("Signup failed", error_type=type(exc).__name__) + logger.warning("Send otp failed", error_type=type(exc).__name__) if _is_auth_upstream_unavailable(exc): raise HTTPException( status_code=503, detail=AUTH_UNAVAILABLE_DETAIL, ) from exc - raise HTTPException( - status_code=422, detail="Invalid signup request" - ) from exc + raise HTTPException(status_code=429, detail="Too many requests") from exc - async def verify_verification( - self, request: VerificationVerifyRequest + async def create_phone_session( + self, request: PhoneSessionCreateRequest ) -> SessionResponse: - if request.type != "signup": - raise HTTPException(status_code=422, detail="Invalid request") - client = self._get_client() payload: dict[str, Any] = { - "type": request.type, - "email": request.email, + "type": "sms", + "phone": request.phone, "token": request.token, } try: @@ -92,7 +70,7 @@ class SupabaseAuthGateway(AuthServiceGateway): response = await asyncio.to_thread(verify_otp, payload) return _map_auth_response(response, "Invalid verification code") except AuthError as exc: - logger.warning("Signup verify failed", error_type=type(exc).__name__) + logger.warning("Create phone session failed", error_type=type(exc).__name__) if _is_auth_upstream_unavailable(exc): raise HTTPException( status_code=503, @@ -102,45 +80,6 @@ class SupabaseAuthGateway(AuthServiceGateway): status_code=401, detail="Invalid verification code" ) from exc - async def resend_verification(self, request: VerificationResendRequest) -> None: - client = self._get_client() - if request.type == "recovery": - await self.request_password_reset( - PasswordResetRequest( - email=request.email, - redirect_to=request.redirect_to, - ) - ) - return - - payload: dict[str, Any] = {"type": request.type, "email": request.email} - try: - resend = cast(Any, client.auth.resend) - await asyncio.to_thread(resend, payload) - except AuthError as exc: - logger.warning("Signup resend failed", error_type=type(exc).__name__) - if _is_auth_upstream_unavailable(exc): - raise HTTPException( - status_code=503, - detail=AUTH_UNAVAILABLE_DETAIL, - ) from exc - - async def create_session(self, request: SessionCreateRequest) -> SessionResponse: - client = self._get_client() - payload: dict[str, Any] = {"email": request.email, "password": request.password} - try: - sign_in = cast(Any, client.auth.sign_in_with_password) - response = await asyncio.to_thread(sign_in, payload) - return _map_auth_response(response, "Invalid credentials") - except AuthError as exc: - logger.warning("Login failed", error_type=type(exc).__name__) - if _is_auth_upstream_unavailable(exc): - raise HTTPException( - status_code=503, - detail=AUTH_UNAVAILABLE_DETAIL, - ) from exc - raise HTTPException(status_code=401, detail="Invalid credentials") from exc - async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: client = self._get_client() try: @@ -189,98 +128,84 @@ class SupabaseAuthGateway(AuthServiceGateway): status_code=401, detail="Invalid refresh token" ) from exc - async def get_user_by_email(self, email: str) -> UserByEmailResponse: - admin_client = self._get_admin_client() - normalized_email = email.lower() + async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse: + normalized_phone = _normalize_phone(phone) + if not normalized_phone: + raise HTTPException(status_code=404, detail="User not found") - now = time.monotonic() - if now >= self._user_lookup_cache_expires_at: - users = await asyncio.to_thread(_list_auth_users, admin_client) - users_by_email: dict[str, Any] = {} - for candidate in users: - candidate_email = str(getattr(candidate, "email", "")).lower() - if candidate_email: - users_by_email[candidate_email] = candidate - self._users_by_email = users_by_email - self._user_lookup_cache_expires_at = ( - now + self._user_lookup_cache_ttl_seconds - ) + await self._refresh_user_lookup_cache_if_needed() - user = self._users_by_email.get(normalized_email) + user = self._users_by_phone.get(normalized_phone) if user is None: raise HTTPException(status_code=404, detail="User not found") - return UserByEmailResponse( + user_phone = _normalize_phone(getattr(user, "phone", "")) + if not user_phone: + raise HTTPException(status_code=404, detail="User not found") + + return UserByPhoneResponse( id=str(getattr(user, "id", "")), - email=str(getattr(user, "email", "")), + phone=user_phone, created_at=str(getattr(user, "created_at", "")), - email_confirmed_at=( - str(getattr(user, "email_confirmed_at", "")) - if getattr(user, "email_confirmed_at", None) + phone_confirmed_at=( + str(getattr(user, "phone_confirmed_at", "")) + if getattr(user, "phone_confirmed_at", None) else None ), ) - async def request_password_reset(self, request: PasswordResetRequest) -> None: - client = self._get_client() - try: - reset_email = cast(Any, client.auth.reset_password_email) - email = _coerce_reset_email(request.email) - if request.redirect_to: - options: dict[str, str] = { - "redirect_to": _validate_redirect_url(request.redirect_to) - } - await asyncio.to_thread(reset_email, email, options=options) - else: - await asyncio.to_thread(reset_email, email) - except AuthError as exc: - logger.warning( - "Password reset request failed", - error_type=type(exc).__name__, - ) - if _is_auth_upstream_unavailable(exc): - raise HTTPException( - status_code=503, - detail=AUTH_UNAVAILABLE_DETAIL, - ) from exc + async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]: + normalized_query = _normalize_phone_search_query(query) + if not normalized_query: + return [] + + await self._refresh_user_lookup_cache_if_needed() + if normalized_query.startswith("+"): + matched_user = self._users_by_phone.get(normalized_query) + if matched_user is None: + return [] + user_id = str(getattr(matched_user, "id", "")) + return [user_id] if user_id else [] + + digits = _digits_only(normalized_query) + if not digits: + return [] + + matched_records: list[tuple[str, str]] = [] + for cached_phone, candidate in self._users_by_phone.items(): + candidate_digits = _digits_only(cached_phone) + if not candidate_digits.endswith(digits): + continue + user_id = str(getattr(candidate, "id", "")) + if user_id: + matched_records.append((cached_phone, user_id)) + + if not matched_records: + return [] + + unique_ids: list[str] = [] + for _, user_id in sorted(matched_records, key=lambda item: item[0]): + if user_id in unique_ids: + continue + unique_ids.append(user_id) + if len(unique_ids) >= max(1, limit): + break + return unique_ids + + async def _refresh_user_lookup_cache_if_needed(self) -> None: + now = time.monotonic() + if now < self._user_lookup_cache_expires_at: + return - async def confirm_password_reset( - self, request: PasswordResetConfirmRequest - ) -> None: - client = self._get_client() admin_client = self._get_admin_client() - verify_payload: dict[str, Any] = { - "type": "recovery", - "email": request.email, - "token": request.token, - } - try: - verify_otp = cast(Any, client.auth.verify_otp) - response = await asyncio.to_thread(verify_otp, verify_payload) - session = getattr(response, "session", None) - user = getattr(response, "user", None) - user_id = str(getattr(user, "id", "")) if user is not None else "" - if session is None or not user_id: - raise HTTPException( - status_code=401, detail="Invalid or expired verification code" - ) - await asyncio.to_thread( - admin_client.auth.admin.update_user_by_id, - user_id, - {"password": request.new_password}, - ) - except AuthError as exc: - logger.warning( - "Password reset confirm failed", error_type=type(exc).__name__ - ) - if _is_auth_upstream_unavailable(exc): - raise HTTPException( - status_code=503, - detail=AUTH_UNAVAILABLE_DETAIL, - ) from exc - raise HTTPException( - status_code=401, detail="Invalid or expired verification code" - ) from exc + users = await asyncio.to_thread(_list_auth_users, admin_client) + users_by_phone: dict[str, Any] = {} + for candidate in users: + candidate_phone = _normalize_phone(getattr(candidate, "phone", "")) + if candidate_phone: + users_by_phone[candidate_phone] = candidate + self._users_by_phone = users_by_phone + self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds def _is_auth_upstream_unavailable(exc: AuthError) -> bool: @@ -312,55 +237,24 @@ def _is_auth_upstream_unavailable(exc: AuthError) -> bool: return any(token in code or token in message for token in indicators) -def _validate_redirect_url(url: str) -> str: - parsed = urlparse(url) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise HTTPException(status_code=422, detail="Invalid redirect URL") - - origin = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}" - allowed_origins = { - _normalize_origin(candidate) - for candidate in config.cors.allow_origins - if _is_http_origin(candidate) - } - if origin not in allowed_origins: - raise HTTPException(status_code=422, detail="Invalid redirect URL") - return url - - -def _normalize_origin(value: str) -> str: - parsed = urlparse(value) - return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}" - - -def _is_http_origin(value: str) -> bool: - parsed = urlparse(value) - return parsed.scheme in {"http", "https"} and bool(parsed.netloc) - - -def _coerce_reset_email(value: object) -> str: - if isinstance(value, str): - return value - - if isinstance(value, Mapping): - nested = value.get("email") or value.get("value") - if isinstance(nested, str): - return nested - - raise HTTPException(status_code=422, detail="Invalid email") - - def _map_auth_response(response: object, failure_message: str) -> SessionResponse: session = getattr(response, "session", None) user = getattr(response, "user", None) if session is None or user is None: raise HTTPException(status_code=401, detail=failure_message) - email = getattr(user, "email", None) - if not email: + phone = _normalize_phone(getattr(user, "phone", None)) + if not phone: raise HTTPException(status_code=401, detail=failure_message) - auth_user = AuthUser(id=str(user.id), email=str(email)) + try: + auth_user = AuthUser(id=str(user.id), phone=str(phone)) + except ValidationError as exc: + logger.warning( + "Auth response returned invalid phone format", + error_type=type(exc).__name__, + ) + raise HTTPException(status_code=401, detail=failure_message) from exc return SessionResponse( access_token=str(session.access_token), refresh_token=str(session.refresh_token), @@ -389,3 +283,37 @@ def _list_auth_users(client: Any) -> list[Any]: page += 1 return users + + +def _normalize_phone(raw_phone: object) -> str | None: + phone = str(raw_phone).strip() + for separator in (" ", "-", "(", ")"): + phone = phone.replace(separator, "") + if not phone: + return None + if phone.startswith("00") and len(phone) > 2: + return f"+{phone[2:]}" + if phone.startswith("+"): + return phone + if phone.isdigit(): + return f"+{phone}" + return None + + +def _normalize_phone_search_query(raw_query: str) -> str | None: + query = raw_query.strip() + for separator in (" ", "-", "(", ")"): + query = query.replace(separator, "") + if not query: + return None + if query.startswith("00") and len(query) > 2: + return f"+{query[2:]}" + if query.startswith("+"): + return query + if query.isdigit(): + return query + return None + + +def _digits_only(value: str) -> str: + return "".join(ch for ch in value if ch.isdigit()) diff --git a/backend/src/v1/auth/router.py b/backend/src/v1/auth/router.py index 5a04767..2cad943 100644 --- a/backend/src/v1/auth/router.py +++ b/backend/src/v1/auth/router.py @@ -1,20 +1,16 @@ from __future__ import annotations from fastapi import APIRouter, Depends, Request, Response -from fastapi import HTTPException +from core.config.settings import config from v1.auth.rate_limit import enforce_rate_limit from v1.auth.dependencies import get_auth_service from v1.auth.schemas import ( - PasswordResetConfirmRequest, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, SessionDeleteRequest, SessionRefreshRequest, SessionResponse, - VerificationCreateRequest, - VerificationCreateResponse, - VerificationResendRequest, - VerificationVerifyRequest, ) from v1.auth.service import AuthService @@ -22,80 +18,49 @@ from v1.auth.service import AuthService router = APIRouter(prefix="/auth", tags=["auth"]) -@router.post( - "/verifications", response_model=VerificationCreateResponse, status_code=202 -) -async def create_verification( - payload: VerificationCreateRequest, - service: AuthService = Depends(get_auth_service), -) -> VerificationCreateResponse: - await enforce_rate_limit( - scope="signup_start", - identifier=payload.email, - limit=5, - window_seconds=60, - ) - return await service.create_verification(payload) - - -@router.post("/verify", response_model=SessionResponse) -async def verify( - payload: VerificationVerifyRequest, - request: Request, - service: AuthService = Depends(get_auth_service), -) -> SessionResponse | Response: - scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm" - limit = 10 - window_seconds = 600 - await enforce_rate_limit( - scope=scope, - identifier=f"{payload.email.lower()}:{_client_ip(request)}", - limit=limit, - window_seconds=window_seconds, - ) - if payload.type == "signup": - return await service.verify_verification(payload) - if payload.new_password is None: - raise HTTPException(status_code=422, detail="Invalid request") - await service.confirm_password_reset( - PasswordResetConfirmRequest( - email=payload.email, - token=payload.token, - new_password=payload.new_password, - ) - ) - return Response(status_code=204) - - -@router.post("/resend", status_code=204) -async def resend( - payload: VerificationResendRequest, +@router.post("/otp/send", status_code=204) +async def send_otp( + payload: OtpSendRequest, request: Request, service: AuthService = Depends(get_auth_service), ) -> Response: - scope = "signup_resend" if payload.type == "signup" else "password_reset_request" + client_ip = _client_ip(request) await enforce_rate_limit( - scope=scope, - identifier=f"{payload.email.lower()}:{_client_ip(request)}", - limit=5, + scope="otp_send_phone", + identifier=payload.phone, + limit=3, window_seconds=60, ) - await service.resend_verification(payload) + await enforce_rate_limit( + scope="otp_send_ip", + identifier=client_ip, + limit=20, + window_seconds=60, + ) + await service.send_otp(payload) return Response(status_code=204) -@router.post("/sessions", response_model=SessionResponse) -async def create_session( - payload: SessionCreateRequest, +@router.post("/phone-session", response_model=SessionResponse) +async def create_phone_session( + payload: PhoneSessionCreateRequest, + request: Request, service: AuthService = Depends(get_auth_service), ) -> SessionResponse: + client_ip = _client_ip(request) await enforce_rate_limit( - scope="login", - identifier=payload.email, - limit=10, - window_seconds=60, + scope="phone_session_phone", + identifier=payload.phone, + limit=6, + window_seconds=300, ) - return await service.create_session(payload) + await enforce_rate_limit( + scope="phone_session_ip", + identifier=client_ip, + limit=20, + window_seconds=300, + ) + return await service.create_phone_session(payload) @router.post("/sessions/refresh", response_model=SessionResponse) @@ -130,13 +95,23 @@ async def delete_session( def _client_ip(request: Request) -> str: - forwarded_for = request.headers.get("x-forwarded-for", "") - if forwarded_for: - first = forwarded_for.split(",")[0].strip() - if first: - return first - real_ip = request.headers.get("x-real-ip", "").strip() - if real_ip: - return real_ip host = request.client.host if request.client else "" - return host or "unknown" + if not host: + return "unknown" + + if _should_trust_proxy_headers(host): + forwarded_for = request.headers.get("x-forwarded-for", "") + if forwarded_for: + first = forwarded_for.split(",")[0].strip() + if first: + return first + real_ip = request.headers.get("x-real-ip", "").strip() + if real_ip: + return real_ip + + return host + + +def _should_trust_proxy_headers(host: str) -> bool: + trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips} + return host in trusted_proxies diff --git a/backend/src/v1/auth/schemas.py b/backend/src/v1/auth/schemas.py index e0f0524..8542180 100644 --- a/backend/src/v1/auth/schemas.py +++ b/backend/src/v1/auth/schemas.py @@ -1,49 +1,22 @@ from __future__ import annotations -from typing import Literal - -from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field SUPABASE_PASSWORD_MIN_LENGTH = 6 -OtpType = Literal["signup", "recovery"] +SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$" -class VerificationCreateRequest(BaseModel): +class OtpSendRequest(BaseModel): model_config = ConfigDict(extra="forbid") - username: str = Field(min_length=3, max_length=30) - email: EmailStr - password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH) - redirect_to: str | None = None - invite_code: str | None = None + phone: str = Field(pattern=SUPABASE_PHONE_PATTERN) -class VerificationResendRequest(BaseModel): - email: EmailStr - type: OtpType = "signup" - redirect_to: str | None = None +class PhoneSessionCreateRequest(BaseModel): + model_config = ConfigDict(extra="forbid") - -class VerificationVerifyRequest(BaseModel): - type: OtpType = "signup" - email: EmailStr + phone: str = Field(pattern=SUPABASE_PHONE_PATTERN) token: str = Field(pattern=r"^\d{6}$") - new_password: str | None = Field( - default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH - ) - - @model_validator(mode="after") - def validate_type_payload(self) -> "VerificationVerifyRequest": - if self.type == "recovery" and self.new_password is None: - raise ValueError("new_password is required when type is recovery") - if self.type == "signup" and self.new_password is not None: - raise ValueError("new_password is only allowed when type is recovery") - return self - - -class SessionCreateRequest(BaseModel): - email: EmailStr - password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH) class SessionRefreshRequest(BaseModel): @@ -56,7 +29,7 @@ class SessionDeleteRequest(BaseModel): class AuthUser(BaseModel): id: str - email: EmailStr + phone: str = Field(pattern=SUPABASE_PHONE_PATTERN) class SessionResponse(BaseModel): @@ -67,23 +40,12 @@ class SessionResponse(BaseModel): user: AuthUser -class UserByEmailResponse(BaseModel): +class UserByPhoneResponse(BaseModel): id: str - email: EmailStr + phone: str = Field(pattern=SUPABASE_PHONE_PATTERN) created_at: str - email_confirmed_at: str | None = None + phone_confirmed_at: str | None = None -class VerificationCreateResponse(BaseModel): - email: EmailStr - - -class PasswordResetRequest(BaseModel): - email: EmailStr - redirect_to: str | None = None - - -class PasswordResetConfirmRequest(BaseModel): - email: EmailStr - token: str = Field(pattern=r"^\d{6}$") - new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH) +class OtpSendResponse(BaseModel): + phone: str = Field(pattern=SUPABASE_PHONE_PATTERN) diff --git a/backend/src/v1/auth/service.py b/backend/src/v1/auth/service.py index 8082784..8fb0803 100644 --- a/backend/src/v1/auth/service.py +++ b/backend/src/v1/auth/service.py @@ -1,52 +1,30 @@ from __future__ import annotations -import re from typing import Protocol from v1.auth.schemas import ( - PasswordResetConfirmRequest, - PasswordResetRequest, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, SessionRefreshRequest, SessionResponse, - VerificationCreateRequest, - VerificationCreateResponse, - VerificationResendRequest, - VerificationVerifyRequest, ) class AuthServiceGateway(Protocol): - async def create_verification( - self, request: VerificationCreateRequest - ) -> VerificationCreateResponse: + async def send_otp(self, request: OtpSendRequest) -> None: raise NotImplementedError - async def verify_verification( - self, request: VerificationVerifyRequest + async def create_phone_session( + self, request: PhoneSessionCreateRequest ) -> SessionResponse: raise NotImplementedError - async def resend_verification(self, request: VerificationResendRequest) -> None: - raise NotImplementedError - - async def create_session(self, request: SessionCreateRequest) -> SessionResponse: - raise NotImplementedError - async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: raise NotImplementedError async def delete_session(self, refresh_token: str | None) -> None: raise NotImplementedError - async def request_password_reset(self, request: PasswordResetRequest) -> None: - raise NotImplementedError - - async def confirm_password_reset( - self, request: PasswordResetConfirmRequest - ) -> None: - raise NotImplementedError - class AuthService: _gateway: AuthServiceGateway @@ -54,50 +32,16 @@ class AuthService: def __init__(self, gateway: AuthServiceGateway) -> None: self._gateway = gateway - async def create_verification( - self, request: VerificationCreateRequest - ) -> VerificationCreateResponse: - normalized_invite_code = _normalize_invite_code(request.invite_code) - normalized_request = request.model_copy( - update={"invite_code": normalized_invite_code} - ) - return await self._gateway.create_verification(normalized_request) + async def send_otp(self, request: OtpSendRequest) -> None: + await self._gateway.send_otp(request) - async def verify_verification( - self, request: VerificationVerifyRequest + async def create_phone_session( + self, request: PhoneSessionCreateRequest ) -> SessionResponse: - return await self._gateway.verify_verification(request) - - async def resend_verification(self, request: VerificationResendRequest) -> None: - await self._gateway.resend_verification(request) - - async def create_session(self, request: SessionCreateRequest) -> SessionResponse: - return await self._gateway.create_session(request) + return await self._gateway.create_phone_session(request) async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: return await self._gateway.refresh_session(request) async def delete_session(self, refresh_token: str | None) -> None: await self._gateway.delete_session(refresh_token) - - async def request_password_reset(self, request: PasswordResetRequest) -> None: - await self._gateway.request_password_reset(request) - - async def confirm_password_reset( - self, request: PasswordResetConfirmRequest - ) -> None: - await self._gateway.confirm_password_reset(request) - - -_INVITE_CODE_PATTERN = re.compile(r"^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{4}$") - - -def _normalize_invite_code(value: str | None) -> str | None: - if value is None: - return None - - normalized = value.strip().upper() - if not normalized: - return None - - return normalized if _INVITE_CODE_PATTERN.fullmatch(normalized) else None diff --git a/backend/src/v1/schedule_items/schemas.py b/backend/src/v1/schedule_items/schemas.py index c4df121..769bf44 100644 --- a/backend/src/v1/schedule_items/schemas.py +++ b/backend/src/v1/schedule_items/schemas.py @@ -5,7 +5,7 @@ from typing import ClassVar from uuid import UUID from zoneinfo import ZoneInfo, ZoneInfoNotFoundError -from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from schemas.inbox.messages import ( CalendarContent, @@ -154,7 +154,11 @@ _PERMISSION_EDIT = 4 class ScheduleItemShareRequest(BaseModel): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - email: EmailStr = Field(..., description="Email of user to share with") + phone: str = Field( + ..., + pattern=r"^\+861[3-9]\d{9}$", + description="Phone of user to share with", + ) permission_view: bool = Field(True, description="Grant view permission") permission_edit: bool = Field(False, description="Grant edit permission") permission_invite: bool = Field(False, description="Grant invite permission") diff --git a/backend/src/v1/schedule_items/service.py b/backend/src/v1/schedule_items/service.py index e2c4b28..8336659 100644 --- a/backend/src/v1/schedule_items/service.py +++ b/backend/src/v1/schedule_items/service.py @@ -31,19 +31,19 @@ from v1.schedule_items.schemas import ( if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession - from v1.auth.schemas import UserByEmailResponse + from v1.auth.schemas import UserByPhoneResponse logger = get_logger("v1.schedule_items.service") -class AuthByEmailGateway(Protocol): - async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ... +class AuthByPhoneGateway(Protocol): + async def get_user_by_phone(self, phone: str) -> "UserByPhoneResponse": ... class ScheduleItemService(BaseService): _repository: ScheduleItemRepository _session: AsyncSession - _auth_gateway: AuthByEmailGateway + _auth_gateway: AuthByPhoneGateway _inbox_repository: InboxMessageRepository def __init__( @@ -51,7 +51,7 @@ class ScheduleItemService(BaseService): repository: ScheduleItemRepository, session: AsyncSession, current_user: CurrentUser | None, - auth_gateway: AuthByEmailGateway | None = None, + auth_gateway: AuthByPhoneGateway | None = None, inbox_repository: InboxMessageRepository | None = None, ) -> None: super().__init__(current_user=current_user) @@ -329,7 +329,7 @@ class ScheduleItemService(BaseService): detail=f"You can only share with permissions up to {inviter_permission}", ) - target_user = await self._auth_gateway.get_user_by_email(request.email) + target_user = await self._auth_gateway.get_user_by_phone(request.phone) recipient_id = UUID(target_user.id) existing = await self._repository.get_subscription(item_id, recipient_id) @@ -404,7 +404,7 @@ class ScheduleItemService(BaseService): except ValueError: await self._session.rollback() logger.exception( - "Auth lookup returned invalid user id", email=request.email + "Auth lookup returned invalid user id", phone=request.phone ) raise HTTPException(status_code=503, detail="Auth lookup unavailable") diff --git a/backend/src/v1/users/dependencies.py b/backend/src/v1/users/dependencies.py index 53f4a4b..aeba2e9 100644 --- a/backend/src/v1/users/dependencies.py +++ b/backend/src/v1/users/dependencies.py @@ -76,11 +76,11 @@ async def _verify_user_with_supabase(token: str) -> CurrentUser | None: parsed_id = UUID(user_id) except ValueError: return None - email = getattr(user, "email", None) + phone = getattr(user, "phone", None) role = getattr(user, "role", None) return CurrentUser( id=parsed_id, - email=email if isinstance(email, str) else None, + phone=phone if isinstance(phone, str) else None, role=role if isinstance(role, str) else None, ) @@ -125,9 +125,9 @@ async def get_current_user( raise HTTPException(status_code=401, detail="Unauthorized") logger.debug("JWT validation successful", user_id=str(user_id)) - email = payload.get("email") if isinstance(payload.get("email"), str) else None + phone = payload.get("phone") if isinstance(payload.get("phone"), str) else None role = payload.get("role") if isinstance(payload.get("role"), str) else None - return CurrentUser(id=user_id, email=email, role=role) + return CurrentUser(id=user_id, phone=phone, role=role) async def get_user_repository( diff --git a/backend/src/v1/users/repository.py b/backend/src/v1/users/repository.py index f6dc983..924857c 100644 --- a/backend/src/v1/users/repository.py +++ b/backend/src/v1/users/repository.py @@ -38,7 +38,7 @@ class UserRepository(Protocol): ... async def search_users(self, query: str, limit: int = 20) -> list[Profile]: - """Search users by username (ilike) or email (exact match).""" + """Search users by username (ilike) or phone (exact match).""" ... diff --git a/backend/src/v1/users/service.py b/backend/src/v1/users/service.py index 406ed0d..70d5840 100644 --- a/backend/src/v1/users/service.py +++ b/backend/src/v1/users/service.py @@ -21,19 +21,22 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession from schemas.user.context import UserContext - from v1.auth.schemas import UserByEmailResponse logger = get_logger("v1.users.service") -_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$") +_PHONE_QUERY_PATTERN = re.compile(r"^[+()\-\s\d]{4,32}$") class AuthLookupGateway(Protocol): - async def get_user_id_by_email(self, email: str) -> str | None: ... + async def search_user_ids_by_phone( + self, query: str, limit: int = 20 + ) -> list[str]: ... -class AuthByEmailGateway(Protocol): - async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ... +class AuthByPhoneGateway(Protocol): + async def search_user_ids_by_phone( + self, query: str, limit: int = 20 + ) -> list[str]: ... class UserContextInvalidator(Protocol): @@ -41,15 +44,14 @@ class UserContextInvalidator(Protocol): class AuthLookupAdapter: - def __init__(self, gateway: AuthByEmailGateway) -> None: + def __init__(self, gateway: AuthByPhoneGateway) -> None: self._gateway = gateway - async def get_user_id_by_email(self, email: str) -> str | None: + async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]: try: - response = await self._gateway.get_user_by_email(email) - return response.id + return await self._gateway.search_user_ids_by_phone(query, limit=limit) except HTTPException: - return None + return [] class UserService(BaseService): @@ -92,11 +94,11 @@ class UserService(BaseService): if user is None: raise HTTPException(status_code=404, detail="User not found") - email = self._current_user.email if self._current_user else None + phone = self._current_user.phone if self._current_user else None return UserContext( id=str(user.id), username=user.username, - email=email, + phone=phone, avatar_url=user.avatar_url, bio=user.bio, settings=parse_profile_settings(user.settings), @@ -152,11 +154,11 @@ class UserService(BaseService): error=str(exc), ) - email = self._current_user.email if self._current_user else None + phone = self._current_user.phone if self._current_user else None return UserContext( id=str(user.id), username=user.username, - email=email, + phone=phone, avatar_url=user.avatar_url, bio=user.bio, settings=parse_profile_settings(user.settings), @@ -181,36 +183,59 @@ class UserService(BaseService): async def search_users(self, request: UserSearchRequest) -> list[UserContext]: query = request.query.strip() - if _EMAIL_PATTERN.match(query): - return await self._search_by_email(query) + if _looks_like_phone_query(query): + phone_results = await self._search_by_phone(query) + if not query.isdigit(): + return phone_results + username_results = await self._search_by_username(query) + if not phone_results: + return username_results + merged_by_id = {result.id: result for result in phone_results} + for result in username_results: + merged_by_id.setdefault(result.id, result) + return list(merged_by_id.values()) return await self._search_by_username(query) - async def _search_by_email(self, email: str) -> list[UserContext]: + async def _search_by_phone(self, phone: str) -> list[UserContext]: if self._auth_gateway is None: raise HTTPException(status_code=503, detail="Auth lookup unavailable") - user_id_str = await self._auth_gateway.get_user_id_by_email(email) - if user_id_str is None: + user_id_values = await self._auth_gateway.search_user_ids_by_phone( + phone, limit=20 + ) + if not user_id_values: + return [] + + user_ids: list[UUID] = [] + for raw_id in user_id_values: + try: + user_ids.append(UUID(raw_id)) + except ValueError: + continue + if not user_ids: return [] try: - user = await self._repository.get_by_user_id(UUID(user_id_str)) + users_by_id = await self._repository.get_by_user_ids(user_ids) except SQLAlchemyError: raise HTTPException(status_code=503, detail="User store unavailable") - if user is None: - return [] - - return [ - UserContext( - id=str(user.id), - username=user.username, - avatar_url=user.avatar_url, - bio=user.bio, - settings=parse_profile_settings(user.settings), + results: list[UserContext] = [] + for user_id in user_ids: + user = users_by_id.get(user_id) + if user is None: + continue + results.append( + UserContext( + id=str(user.id), + username=user.username, + avatar_url=user.avatar_url, + bio=user.bio, + settings=parse_profile_settings(user.settings), + ) ) - ] + return results async def _search_by_username(self, query: str) -> list[UserContext]: try: @@ -228,3 +253,10 @@ class UserService(BaseService): ) for user in users ] + + +def _looks_like_phone_query(query: str) -> bool: + if not _PHONE_QUERY_PATTERN.fullmatch(query): + return False + digits_count = sum(char.isdigit() for char in query) + return digits_count >= 4 diff --git a/backend/tests/e2e/test_auth_flow.py b/backend/tests/e2e/test_auth_flow.py index fb9a840..5c6a31a 100644 --- a/backend/tests/e2e/test_auth_flow.py +++ b/backend/tests/e2e/test_auth_flow.py @@ -12,41 +12,24 @@ from app import app from v1.auth.dependencies import get_auth_service from v1.auth.schemas import ( AuthUser, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, SessionRefreshRequest, SessionResponse, - VerificationCreateRequest, - VerificationCreateResponse, - VerificationResendRequest, - VerificationVerifyRequest, ) from v1.auth.service import AuthService class FakeE2EAuthService(AuthService): def __init__(self) -> None: - self._user = AuthUser(id="user-1", email="user@example.com") + self._user = AuthUser(id="user-1", phone="+8613812345678") - async def create_verification( - self, request: VerificationCreateRequest - ) -> VerificationCreateResponse: - return VerificationCreateResponse(email=request.email) - - async def verify_verification( - self, request: VerificationVerifyRequest - ) -> SessionResponse: - return SessionResponse( - access_token="access-1", - refresh_token="refresh-1", - expires_in=3600, - token_type="bearer", - user=self._user, - ) - - async def resend_verification(self, request: VerificationResendRequest) -> None: + async def send_otp(self, request: OtpSendRequest) -> None: return None - async def create_session(self, request: SessionCreateRequest) -> SessionResponse: + async def create_phone_session( + self, request: PhoneSessionCreateRequest + ) -> SessionResponse: return SessionResponse( access_token="access-2", refresh_token="refresh-2", @@ -105,41 +88,25 @@ def test_auth_flow_e2e() -> None: base_url=f"http://{host}:{port}" ) try: - verification = request_context.post( - "/api/v1/auth/verifications", - data=json.dumps( - { - "username": "demo", - "email": "user@example.com", - "password": "secret123", - } - ), + send_code = request_context.post( + "/api/v1/auth/otp/send", + data=json.dumps({"phone": "+8613812345678"}), headers={"Content-Type": "application/json"}, ) - assert verification.status == 202 + assert send_code.status == 204 - verify = request_context.post( - "/api/v1/auth/verify", + login_or_register = request_context.post( + "/api/v1/auth/phone-session", data=json.dumps( { - "email": "user@example.com", + "phone": "+8613812345678", "token": "123456", } ), headers={"Content-Type": "application/json"}, ) - assert verify.status == 200 - assert verify.json()["access_token"] == "access-1" - - login = request_context.post( - "/api/v1/auth/sessions", - data=json.dumps( - {"email": "user@example.com", "password": "secret123"} - ), - headers={"Content-Type": "application/json"}, - ) - assert login.status == 200 - assert login.json()["access_token"] == "access-2" + assert login_or_register.status == 200 + assert login_or_register.json()["access_token"] == "access-2" refresh = request_context.post( "/api/v1/auth/sessions/refresh", diff --git a/backend/tests/e2e/test_infra_health_e2e.py b/backend/tests/e2e/test_infra_health_e2e.py index a438446..a7396bc 100644 --- a/backend/tests/e2e/test_infra_health_e2e.py +++ b/backend/tests/e2e/test_infra_health_e2e.py @@ -4,11 +4,16 @@ import socket import threading import time +import pytest from playwright.sync_api import sync_playwright import uvicorn from app import app -from v1.infra.dependencies import get_redis_service + +pytest.skip( + "infra health endpoint removed from v1 API", + allow_module_level=True, +) class _FakeService: @@ -52,8 +57,6 @@ def _start_server(host: str, port: int): def test_infra_health_e2e() -> None: - app.dependency_overrides[get_redis_service] = lambda: _FakeService() - host = "127.0.0.1" port = _find_free_port() server, thread = _start_server(host, port) diff --git a/backend/tests/e2e/test_profile_flow.py b/backend/tests/e2e/test_profile_flow.py index 608f196..0dec6fe 100644 --- a/backend/tests/e2e/test_profile_flow.py +++ b/backend/tests/e2e/test_profile_flow.py @@ -11,21 +11,22 @@ import uvicorn from app import app from core.auth.models import CurrentUser +from schemas.user.context import UserContext from v1.users.dependencies import get_current_user, get_user_service -from v1.users.schemas import UserResponse, UserUpdateRequest +from v1.users.schemas import UserUpdateRequest class FakeUserService: """Fake service for E2E testing.""" - def __init__(self, user: UserResponse) -> None: + def __init__(self, user: UserContext) -> None: self._user = user - async def get_me(self) -> UserResponse: + async def get_me(self) -> UserContext: return self._user - async def update_me(self, update: UserUpdateRequest) -> UserResponse: - return UserResponse( + async def update_me(self, update: UserUpdateRequest) -> UserContext: + return UserContext( id=self._user.id, username=( update.username if update.username is not None else self._user.username @@ -38,6 +39,7 @@ class FakeUserService: bio=update.bio if update.bio is not None else self._user.bio, ) + def _find_free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind(("127.0.0.1", 0)) @@ -65,7 +67,7 @@ def _start_server(host: str, port: int): def test_profile_flow_e2e() -> None: user_id = UUID("00000000-0000-0000-0000-000000000001") - user = UserResponse( + user = UserContext( id=str(user_id), username="demo", avatar_url=None, diff --git a/backend/tests/integration/test_auth_routes.py b/backend/tests/integration/test_auth_routes.py index f2ee314..9069ca7 100644 --- a/backend/tests/integration/test_auth_routes.py +++ b/backend/tests/integration/test_auth_routes.py @@ -11,16 +11,10 @@ from v1.auth.dependencies import get_auth_service from v1.auth.rate_limit import reset_rate_limit_state from v1.auth.schemas import ( AuthUser, - PasswordResetConfirmRequest, - PasswordResetRequest, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, SessionRefreshRequest, SessionResponse, - UserByEmailResponse, - VerificationCreateRequest, - VerificationCreateResponse, - VerificationResendRequest, - VerificationVerifyRequest, ) from v1.auth.service import AuthService @@ -30,58 +24,39 @@ def reset_auth_rate_limit_state() -> None: reset_rate_limit_state() +@pytest.fixture(autouse=True) +def force_in_memory_rate_limit(monkeypatch: pytest.MonkeyPatch) -> None: + async def _raise_redis_unavailable() -> None: + raise RuntimeError("redis unavailable in integration tests") + + monkeypatch.setattr( + "v1.auth.rate_limit.get_or_init_redis_client", + _raise_redis_unavailable, + ) + + class FakeAuthService(AuthService): def __init__(self, token_response: SessionResponse) -> None: self._token_response = token_response - async def create_verification( - self, request: VerificationCreateRequest - ) -> VerificationCreateResponse: - if request.email == "exists@example.com": - raise HTTPException(status_code=422, detail="Invalid signup request") - return VerificationCreateResponse(email=request.email) + async def send_otp(self, request: OtpSendRequest) -> None: + if request.phone == "+8613811111111": + raise HTTPException(status_code=401, detail="Invalid verification code") + return None - async def verify_verification( - self, request: VerificationVerifyRequest + async def create_phone_session( + self, request: PhoneSessionCreateRequest ) -> SessionResponse: if request.token == "000000": raise HTTPException(status_code=401, detail="Invalid verification code") return self._token_response - async def resend_verification(self, request: VerificationResendRequest) -> None: - return None - - async def create_session(self, request: SessionCreateRequest) -> SessionResponse: - raise HTTPException(status_code=401, detail="Invalid credentials") - async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: raise HTTPException(status_code=401, detail="Invalid refresh token") async def delete_session(self, refresh_token: str | None) -> None: return None - async def get_user_by_email(self, email: str) -> UserByEmailResponse: - if email == "missing@example.com": - raise HTTPException(status_code=404, detail="User not found") - return UserByEmailResponse( - id="user-1", - email=email, - created_at="2026-02-24T00:00:00Z", - email_confirmed_at=None, - ) - - async def request_password_reset(self, request: PasswordResetRequest) -> None: - return None - - async def confirm_password_reset( - self, request: PasswordResetConfirmRequest - ) -> None: - if request.token == "000000": - raise HTTPException( - status_code=401, detail="Invalid or expired verification code" - ) - return None - def _override_auth_service(service: AuthService) -> Callable[[], AuthService]: def _get_service() -> AuthService: @@ -90,761 +65,126 @@ def _override_auth_service(service: AuthService) -> Callable[[], AuthService]: return _get_service -def test_signup_start_returns_pending_response() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( +def _token_response() -> SessionResponse: + user = AuthUser(id="user-1", phone="+8613812345678") + return SessionResponse( access_token="access", refresh_token="refresh", expires_in=3600, token_type="bearer", user=user, ) + + +def test_send_otp_returns_204() -> None: app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) + FakeAuthService(_token_response()) ) client = TestClient(app) try: response = client.post( - "/api/v1/auth/verifications", - json={ - "username": "demo", - "email": "user@example.com", - "password": "secret123", - }, + "/api/v1/auth/otp/send", + json={"phone": "+8613812345678"}, ) - assert response.status_code == 202 - assert response.json() == {"email": "user@example.com"} + assert response.status_code == 204 finally: app.dependency_overrides = {} -def test_signup_verify_returns_token_response() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) +def test_phone_session_returns_token_response() -> None: app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) + FakeAuthService(_token_response()) ) client = TestClient(app) try: response = client.post( - "/api/v1/auth/verify", - json={"email": "user@example.com", "token": "123456"}, + "/api/v1/auth/phone-session", + json={"phone": "+8613812345678", "token": "123456"}, ) assert response.status_code == 200 body = response.json() assert body["access_token"] == "access" assert body["refresh_token"] == "refresh" - assert body["user"]["email"] == "user@example.com" + assert body["user"]["phone"] == "+8613812345678" finally: app.dependency_overrides = {} -def test_signup_resend_returns_generic_message() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) +def test_phone_session_invalid_token_returns_problem_details() -> None: app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) + FakeAuthService(_token_response()) ) client = TestClient(app) try: response = client.post( - "/api/v1/auth/resend", - json={"type": "recovery", "email": "user@example.com"}, - ) - assert response.status_code == 204 - assert response.content == b"" - finally: - app.dependency_overrides = {} - - -def test_signup_verify_invalid_token_returns_problem_details() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verify", - json={"email": "user@example.com", "token": "000000"}, + "/api/v1/auth/phone-session", + json={"phone": "+8613812345678", "token": "000000"}, ) assert response.status_code == 401 assert response.headers["content-type"].startswith("application/problem+json") - body = response.json() - assert body["title"] == "Unauthorized" - assert body["status"] == 401 - assert body["detail"] == "Invalid verification code" finally: app.dependency_overrides = {} -def test_signup_start_existing_email_returns_problem_details() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) +def test_legacy_routes_are_removed() -> None: app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) + FakeAuthService(_token_response()) ) client = TestClient(app) try: - response = client.post( - "/api/v1/auth/verifications", - json={ - "username": "demo", - "email": "exists@example.com", - "password": "secret123", - }, - ) - assert response.status_code == 422 - assert response.headers["content-type"].startswith("application/problem+json") - body = response.json() - assert body["title"] == "Unprocessable Entity" - assert body["status"] == 422 - assert body["detail"] == "Invalid signup request" + assert client.post("/api/v1/auth/verifications", json={}).status_code == 404 + assert client.post("/api/v1/auth/verify", json={}).status_code == 404 + assert client.post("/api/v1/auth/resend", json={}).status_code == 404 + assert client.post("/api/v1/auth/sessions", json={}).status_code == 405 finally: app.dependency_overrides = {} -def test_signup_verify_rate_limited_after_too_many_attempts() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) +def test_send_otp_phone_rate_limited_after_too_many_attempts() -> None: app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) + FakeAuthService(_token_response()) ) client = TestClient(app) try: - for _ in range(10): + for _ in range(3): ok = client.post( - "/api/v1/auth/verify", - json={"email": "user@example.com", "token": "123456"}, - ) - assert ok.status_code == 200 - - blocked = client.post( - "/api/v1/auth/verify", - json={"email": "user@example.com", "token": "123456"}, - ) - assert blocked.status_code == 429 - assert blocked.headers["content-type"].startswith("application/problem+json") - finally: - app.dependency_overrides = {} - - -def test_signup_resend_rate_limited_after_too_many_attempts() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - for _ in range(5): - ok = client.post( - "/api/v1/auth/resend", - json={"email": "user@example.com"}, + "/api/v1/auth/otp/send", + json={"phone": "+8613812345678"}, ) assert ok.status_code == 204 blocked = client.post( - "/api/v1/auth/resend", - json={"email": "user@example.com"}, + "/api/v1/auth/otp/send", + json={"phone": "+8613812345678"}, ) assert blocked.status_code == 429 - assert blocked.headers["content-type"].startswith("application/problem+json") finally: app.dependency_overrides = {} -def test_signup_start_rate_limited_after_too_many_attempts() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) +def test_phone_session_rate_limited_after_too_many_attempts() -> None: app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) + FakeAuthService(_token_response()) ) client = TestClient(app) try: - for _ in range(5): - ok = client.post( - "/api/v1/auth/verifications", - json={ - "username": "demo", - "email": "user@example.com", - "password": "secret123", - }, - ) - assert ok.status_code == 202 - - blocked = client.post( - "/api/v1/auth/verifications", - json={ - "username": "demo", - "email": "user@example.com", - "password": "secret123", - }, - ) - assert blocked.status_code == 429 - assert blocked.headers["content-type"].startswith("application/problem+json") - finally: - app.dependency_overrides = {} - - -def test_login_invalid_returns_problem_details() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/sessions", - json={"email": "user@example.com", "password": "wrongpw"}, - ) - assert response.status_code == 401 - assert response.headers["content-type"].startswith("application/problem+json") - body = response.json() - assert body["title"] == "Unauthorized" - assert body["status"] == 401 - assert body["detail"] == "Invalid credentials" - finally: - app.dependency_overrides = {} - - -def test_refresh_invalid_returns_problem_details() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/sessions/refresh", - json={"refresh_token": "invalid"}, - ) - assert response.status_code == 401 - assert response.headers["content-type"].startswith("application/problem+json") - body = response.json() - assert body["title"] == "Unauthorized" - assert body["status"] == 401 - assert body["detail"] == "Invalid refresh token" - finally: - app.dependency_overrides = {} - - -def test_logout_returns_no_content() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.request( - "DELETE", - "/api/v1/auth/sessions", - json={"refresh_token": "refresh"}, - ) - assert response.status_code == 204 - assert response.content == b"" - finally: - app.dependency_overrides = {} - - -def test_login_rate_limited_after_too_many_attempts() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - for _ in range(10): + for _ in range(6): blocked = client.post( - "/api/v1/auth/sessions", - json={"email": "user@example.com", "password": "wrongpw"}, + "/api/v1/auth/phone-session", + json={"phone": "+8613812345678", "token": "000000"}, ) assert blocked.status_code == 401 blocked = client.post( - "/api/v1/auth/sessions", - json={"email": "user@example.com", "password": "wrongpw"}, - ) - assert blocked.status_code == 429 - assert blocked.headers["content-type"].startswith("application/problem+json") - body = blocked.json() - assert body["detail"] == "Too many requests" - finally: - app.dependency_overrides = {} - - -def test_refresh_rate_limited_after_too_many_attempts() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - for _ in range(10): - blocked = client.post( - "/api/v1/auth/sessions/refresh", - json={"refresh_token": "invalid"}, - ) - assert blocked.status_code == 401 - - blocked = client.post( - "/api/v1/auth/sessions/refresh", - json={"refresh_token": "invalid"}, - ) - assert blocked.status_code == 429 - assert blocked.headers["content-type"].startswith("application/problem+json") - body = blocked.json() - assert body["detail"] == "Too many requests" - finally: - app.dependency_overrides = {} - - -def test_refresh_rate_limit_not_bypassed_by_changing_refresh_token() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - for index in range(10): - blocked = client.post( - "/api/v1/auth/sessions/refresh", - json={"refresh_token": f"invalid-{index}"}, - ) - assert blocked.status_code == 401 - - blocked = client.post( - "/api/v1/auth/sessions/refresh", - json={"refresh_token": "invalid-extra"}, + "/api/v1/auth/phone-session", + json={"phone": "+8613812345678", "token": "000000"}, ) assert blocked.status_code == 429 finally: app.dependency_overrides = {} - - -def test_logout_rate_limited_after_too_many_attempts() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - for _ in range(10): - ok = client.request( - "DELETE", - "/api/v1/auth/sessions", - json={"refresh_token": "refresh"}, - ) - assert ok.status_code == 204 - - blocked = client.request( - "DELETE", - "/api/v1/auth/sessions", - json={"refresh_token": "refresh"}, - ) - assert blocked.status_code == 429 - assert blocked.headers["content-type"].startswith("application/problem+json") - body = blocked.json() - assert body["detail"] == "Too many requests" - finally: - app.dependency_overrides = {} - - -def test_logout_rate_limit_not_bypassed_by_changing_refresh_token() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - for index in range(10): - ok = client.request( - "DELETE", - "/api/v1/auth/sessions", - json={"refresh_token": f"refresh-{index}"}, - ) - assert ok.status_code == 204 - - blocked = client.request( - "DELETE", - "/api/v1/auth/sessions", - json={"refresh_token": "refresh-extra"}, - ) - assert blocked.status_code == 429 - finally: - app.dependency_overrides = {} - - -def test_signup_start_validation_error_returns_problem_details() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post("/api/v1/auth/verifications", json={}) - assert response.status_code == 422 - assert response.headers["content-type"].startswith("application/problem+json") - body = response.json() - assert body["title"] == "Unprocessable Entity" - assert body["status"] == 422 - assert body["detail"] == "Invalid request" - finally: - app.dependency_overrides = {} - - -def test_signup_start_missing_username_returns_problem_details() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verifications", - json={"email": "user@example.com", "password": "secret123"}, - ) - assert response.status_code == 422 - assert response.headers["content-type"].startswith("application/problem+json") - body = response.json() - assert body["title"] == "Unprocessable Entity" - assert body["status"] == 422 - assert body["detail"] == "Invalid request" - finally: - app.dependency_overrides = {} - - -def test_password_reset_request_returns_204() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/resend", - json={"email": "user@example.com"}, - ) - assert response.status_code == 204 - finally: - app.dependency_overrides = {} - - -def test_password_reset_confirm_returns_204() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verify", - json={ - "type": "recovery", - "email": "user@example.com", - "token": "123456", - "new_password": "newpassword123", - }, - ) - assert response.status_code == 204 - finally: - app.dependency_overrides = {} - - -def test_password_reset_confirm_invalid_token_returns_401() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verify", - json={ - "type": "recovery", - "email": "user@example.com", - "token": "000000", - "new_password": "newpassword123", - }, - ) - assert response.status_code == 401 - assert response.headers["content-type"].startswith("application/problem+json") - body = response.json() - assert body["title"] == "Unauthorized" - assert body["status"] == 401 - finally: - app.dependency_overrides = {} - - -def test_password_reset_confirm_weak_password_returns_422() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verify", - json={ - "type": "recovery", - "email": "user@example.com", - "token": "123456", - "new_password": "123", - }, - ) - assert response.status_code == 422 - assert response.headers["content-type"].startswith("application/problem+json") - finally: - app.dependency_overrides = {} - - -class TestInviteCodeSignup: - def test_signup_with_valid_invite_code_returns_202(self) -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verifications", - json={ - "username": "demo", - "email": "user@example.com", - "password": "secret123", - "invite_code": "A2B3", - }, - ) - assert response.status_code == 202 - assert response.json() == {"email": "user@example.com"} - finally: - app.dependency_overrides = {} - - def test_signup_with_invalid_invite_code_length_returns_202(self) -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verifications", - json={ - "username": "demo", - "email": "user@example.com", - "password": "secret123", - "invite_code": "ABC123", - }, - ) - assert response.status_code == 202 - assert response.json() == {"email": "user@example.com"} - finally: - app.dependency_overrides = {} - - def test_signup_with_invalid_invite_code_chars_returns_202(self) -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - app.dependency_overrides[get_auth_service] = _override_auth_service( - FakeAuthService(token_response) - ) - - client = TestClient(app) - try: - response = client.post( - "/api/v1/auth/verifications", - json={ - "username": "demo", - "email": "user@example.com", - "password": "secret123", - "invite_code": "ABCD1234", - }, - ) - assert response.status_code == 202 - assert response.json() == {"email": "user@example.com"} - finally: - app.dependency_overrides = {} diff --git a/backend/tests/integration/test_friendship_routes.py b/backend/tests/integration/test_friendship_routes.py index 64cfe01..8da9241 100644 --- a/backend/tests/integration/test_friendship_routes.py +++ b/backend/tests/integration/test_friendship_routes.py @@ -9,12 +9,12 @@ from fastapi.testclient import TestClient from app import app from core.auth.models import CurrentUser +from schemas.user.context import UserContext from v1.friendships.dependencies import get_friendship_service from v1.friendships.schemas import ( FriendRequestCreate, FriendRequestResponse, FriendResponse, - UserBasicInfo, ) from v1.friendships.service import FriendshipService from v1.users.dependencies import get_current_user @@ -31,9 +31,9 @@ class FakeFriendshipService(FriendshipService): async def send_request(self, request: FriendRequestCreate) -> FriendRequestResponse: return FriendRequestResponse( id=UUID("11111111-1111-1111-1111-111111111111"), - sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None), - recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None), - content=request.content, + sender=UserContext(id="user-1", username="sender", avatar_url=None), + recipient=UserContext(id="user-2", username="recipient", avatar_url=None), + content={"text": request.content} if request.content else None, status="pending", created_at=datetime.now(timezone.utc), ) @@ -41,9 +41,9 @@ class FakeFriendshipService(FriendshipService): async def accept_request(self, friendship_id: UUID) -> FriendRequestResponse: return FriendRequestResponse( id=friendship_id, - sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None), - recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None), - content="Hello!", + sender=UserContext(id="user-1", username="sender", avatar_url=None), + recipient=UserContext(id="user-2", username="recipient", avatar_url=None), + content={"text": "Hello!"}, status="accepted", created_at=datetime.now(timezone.utc), ) @@ -51,9 +51,9 @@ class FakeFriendshipService(FriendshipService): async def decline_request(self, friendship_id: UUID) -> FriendRequestResponse: return FriendRequestResponse( id=friendship_id, - sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None), - recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None), - content="Hello!", + sender=UserContext(id="user-1", username="sender", avatar_url=None), + recipient=UserContext(id="user-2", username="recipient", avatar_url=None), + content={"text": "Hello!"}, status="rejected", created_at=datetime.now(timezone.utc), ) @@ -61,9 +61,9 @@ class FakeFriendshipService(FriendshipService): async def cancel_request(self, friendship_id: UUID) -> FriendRequestResponse: return FriendRequestResponse( id=friendship_id, - sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None), - recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None), - content="Hello!", + sender=UserContext(id="user-1", username="sender", avatar_url=None), + recipient=UserContext(id="user-2", username="recipient", avatar_url=None), + content={"text": "Hello!"}, status="canceled", created_at=datetime.now(timezone.utc), ) @@ -72,11 +72,11 @@ class FakeFriendshipService(FriendshipService): return [ FriendRequestResponse( id=UUID("11111111-1111-1111-1111-111111111111"), - sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None), - recipient=UserBasicInfo( + sender=UserContext(id="user-1", username="sender", avatar_url=None), + recipient=UserContext( id="user-2", username="recipient", avatar_url=None ), - content="Hello!", + content={"text": "Hello!"}, status="pending", created_at=datetime.now(timezone.utc), ) @@ -86,10 +86,8 @@ class FakeFriendshipService(FriendshipService): return [ FriendRequestResponse( id=UUID("22222222-2222-2222-2222-222222222222"), - sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None), - recipient=UserBasicInfo( - id="user-3", username="target", avatar_url=None - ), + sender=UserContext(id="user-1", username="sender", avatar_url=None), + recipient=UserContext(id="user-3", username="target", avatar_url=None), content=None, status="pending", created_at=datetime.now(timezone.utc), @@ -100,7 +98,7 @@ class FakeFriendshipService(FriendshipService): return [ FriendResponse( id=UUID("33333333-3333-3333-3333-333333333333"), - friend=UserBasicInfo(id="user-2", username="friend", avatar_url=None), + friend=UserContext(id="user-2", username="friend", avatar_url=None), status="active", created_at=datetime.now(timezone.utc), accepted_at=datetime.now(timezone.utc), @@ -110,7 +108,7 @@ class FakeFriendshipService(FriendshipService): async def remove_friend(self, friend_id: UUID) -> FriendResponse: return FriendResponse( id=UUID("33333333-3333-3333-3333-333333333333"), - friend=UserBasicInfo(id=str(friend_id), username="friend", avatar_url=None), + friend=UserContext(id=str(friend_id), username="friend", avatar_url=None), status="active", created_at=datetime.now(timezone.utc), accepted_at=datetime.now(timezone.utc), @@ -129,7 +127,7 @@ def _override_friendship_service( def _get_fake_current_user() -> CurrentUser: return CurrentUser( id=UUID("00000000-0000-0000-0000-000000000001"), - email="test@example.com", + phone="+8613812345678", ) diff --git a/backend/tests/integration/test_schedule_share_routes.py b/backend/tests/integration/test_schedule_share_routes.py index 6a42d30..aa7d123 100644 --- a/backend/tests/integration/test_schedule_share_routes.py +++ b/backend/tests/integration/test_schedule_share_routes.py @@ -52,7 +52,7 @@ def test_share_schedule_item_returns_200() -> None: response = client.post( f"/api/v1/schedule-items/{item_id}/share", json={ - "email": "friend@example.com", + "phone": "+8613810000000", "permission_view": True, "permission_edit": False, "permission_invite": True, @@ -62,7 +62,7 @@ def test_share_schedule_item_returns_200() -> None: body = response.json() assert body["message"] == "Calendar invitation sent" assert service.last_share_request is not None - assert service.last_share_request.email == "friend@example.com" + assert service.last_share_request.phone == "+8613810000000" assert service.last_share_request.permission_invite is True finally: app.dependency_overrides = {} diff --git a/backend/tests/integration/test_users_routes.py b/backend/tests/integration/test_users_routes.py index bc6df1c..20aa42b 100644 --- a/backend/tests/integration/test_users_routes.py +++ b/backend/tests/integration/test_users_routes.py @@ -8,30 +8,31 @@ from fastapi.testclient import TestClient from app import app from core.auth.models import CurrentUser +from schemas.user.context import UserContext from v1.users.dependencies import get_current_user, get_user_service -from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest +from v1.users.schemas import UserSearchRequest, UserUpdateRequest from v1.users.service import UserService class FakeUserService: """Fake service for integration testing.""" - def __init__(self, user: UserResponse) -> None: + def __init__(self, user: UserContext) -> None: self._user = user - self._search_results: list[UserResponse] = [] + self._search_results: list[UserContext] = [] - def set_search_results(self, results: list[UserResponse]) -> None: + def set_search_results(self, results: list[UserContext]) -> None: self._search_results = results - async def get_me(self) -> UserResponse: + async def get_me(self) -> UserContext: if self._user.id is None: raise HTTPException(status_code=404, detail="User not found") return self._user - async def update_me(self, update: UserUpdateRequest) -> UserResponse: + async def update_me(self, update: UserUpdateRequest) -> UserContext: if self._user.id is None: raise HTTPException(status_code=404, detail="User not found") - return UserResponse( + return UserContext( id=self._user.id, username=( update.username if update.username is not None else self._user.username @@ -44,7 +45,7 @@ class FakeUserService: bio=update.bio if update.bio is not None else self._user.bio, ) - async def search_users(self, request: UserSearchRequest) -> list[UserResponse]: + async def search_users(self, request: UserSearchRequest) -> list[UserContext]: if request.query: return self._search_results if self._search_results else [self._user] return [] @@ -68,7 +69,7 @@ def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]: def test_get_me_returns_user() -> None: user_id = UUID("00000000-0000-0000-0000-000000000001") - user = UserResponse( + user = UserContext( id=str(user_id), username="demo", avatar_url=None, @@ -91,7 +92,7 @@ def test_get_me_returns_user() -> None: def test_patch_me_updates_user() -> None: user_id = UUID("00000000-0000-0000-0000-000000000001") - user = UserResponse( + user = UserContext( id=str(user_id), username="demo", avatar_url=None, @@ -117,7 +118,7 @@ def test_patch_me_updates_user() -> None: def test_patch_me_validation_error_returns_problem_details() -> None: user_id = UUID("00000000-0000-0000-0000-000000000001") - user = UserResponse( + user = UserContext( id=str(user_id), username="demo", avatar_url=None, @@ -142,7 +143,7 @@ def test_patch_me_validation_error_returns_problem_details() -> None: def test_search_users_returns_list() -> None: user_id = UUID("00000000-0000-0000-0000-000000000001") - user = UserResponse( + user = UserContext( id=str(user_id), username="demo", avatar_url=None, @@ -167,7 +168,7 @@ def test_search_users_returns_list() -> None: def test_search_users_empty_query_returns_422() -> None: user_id = UUID("00000000-0000-0000-0000-000000000001") - user = UserResponse( + user = UserContext( id=str(user_id), username="demo", avatar_url=None, diff --git a/backend/tests/integration/v1/agent/test_routes.py b/backend/tests/integration/v1/agent/test_routes.py index b39aea5..d915cc2 100644 --- a/backend/tests/integration/v1/agent/test_routes.py +++ b/backend/tests/integration/v1/agent/test_routes.py @@ -115,6 +115,34 @@ class _FailingStreamAgentService(_FakeAgentService): raise RuntimeError("redis timeout") +class _TerminalStreamAgentService(_FakeAgentService): + def __init__(self) -> None: + super().__init__() + self.stream_calls = 0 + + async def stream_events( + self, + *, + thread_id: str, + last_event_id: str | None, + current_user: CurrentUser, + ) -> list[dict[str, object]]: + del thread_id, last_event_id, current_user + self.stream_calls += 1 + if self.stream_calls == 1: + return [ + { + "id": "9-0", + "event": { + "type": "RUN_FINISHED", + "threadId": "00000000-0000-0000-0000-000000000001", + "runId": "run-1", + }, + } + ] + return [] + + def test_run_requires_auth_and_returns_202_task_id() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() client = TestClient(app) @@ -129,13 +157,13 @@ def test_run_requires_auth_and_returns_202_task_id() -> None: "messages": [{"id": "u1", "role": "user", "content": "hello"}], "tools": [], "context": [], - "forwardedProps": {}, + "forwardedProps": {"agent_type": "worker"}, }, ) assert unauthorized.status_code == 401 app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) authorized = client.post( "/api/v1/agent/runs", @@ -146,7 +174,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None: "messages": [{"id": "u1", "role": "user", "content": "hello"}], "tools": [], "context": [], - "forwardedProps": {}, + "forwardedProps": {"agent_type": "worker"}, }, ) assert authorized.status_code == 202 @@ -161,7 +189,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None: def test_stream_reads_from_last_event_id() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) original_acquire = agent_router._acquire_sse_slot @@ -197,7 +225,7 @@ def test_stream_reads_from_last_event_id() -> None: def test_stream_handles_stream_backend_errors_without_connection_crash() -> None: app.dependency_overrides[get_agent_service] = lambda: _FailingStreamAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) original_acquire = agent_router._acquire_sse_slot @@ -226,10 +254,45 @@ def test_stream_handles_stream_backend_errors_without_connection_crash() -> None app.dependency_overrides = {} +def test_stream_stops_after_terminal_run_event() -> None: + service = _TerminalStreamAgentService() + app.dependency_overrides[get_agent_service] = lambda: service + app.dependency_overrides[get_current_user] = lambda: CurrentUser( + id=uuid4(), phone="+8613812345678" + ) + client = TestClient(app) + original_acquire = agent_router._acquire_sse_slot + original_release = agent_router._release_sse_slot + + async def _allow_slot(*, user_id: str) -> bool: + del user_id + return True + + async def _noop_release(*, user_id: str) -> None: + del user_id + return None + + agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment] + agent_router._release_sse_slot = _noop_release # type: ignore[assignment] + + try: + response = client.get( + "/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=3" + ) + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + assert "event: RUN_FINISHED" in response.text + assert service.stream_calls == 1 + finally: + agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment] + agent_router._release_sse_slot = original_release # type: ignore[assignment] + app.dependency_overrides = {} + + def test_stream_rejects_invalid_last_event_id() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) @@ -255,7 +318,7 @@ def test_history_returns_state_snapshot() -> None: assert unauthorized.status_code == 401 app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) authorized = client.get( "/api/v1/agent/history", @@ -276,7 +339,7 @@ def test_history_returns_state_snapshot() -> None: def test_user_history_returns_latest_snapshot() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) try: @@ -292,7 +355,7 @@ def test_user_history_returns_latest_snapshot() -> None: def test_run_rejects_oversized_user_text_payload() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) @@ -312,7 +375,7 @@ def test_run_rejects_oversized_user_text_payload() -> None: ], "tools": [], "context": [], - "forwardedProps": {}, + "forwardedProps": {"agent_type": "worker"}, }, ) assert response.status_code == 422 @@ -323,7 +386,7 @@ def test_run_rejects_oversized_user_text_payload() -> None: def test_run_rejects_client_supplied_history_messages() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) @@ -340,7 +403,7 @@ def test_run_rejects_client_supplied_history_messages() -> None: ], "tools": [], "context": [], - "forwardedProps": {}, + "forwardedProps": {"agent_type": "worker"}, }, ) assert response.status_code == 422 @@ -351,7 +414,7 @@ def test_run_rejects_client_supplied_history_messages() -> None: def test_upload_attachment_returns_reference() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) @@ -376,7 +439,7 @@ def test_upload_attachment_returns_reference() -> None: def test_create_attachment_signed_url_returns_url() -> None: app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService() app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) @@ -399,7 +462,7 @@ def test_create_attachment_signed_url_returns_url() -> None: def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None: app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) async def mock_transcribe_file(file_path: str, filename: str) -> str: @@ -434,7 +497,7 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None: def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None: app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4) @@ -457,7 +520,7 @@ def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None: def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None: app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) @@ -478,7 +541,7 @@ def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None: def test_asr_transcribe_rejects_invalid_wav_payload(monkeypatch) -> None: app.dependency_overrides[get_current_user] = lambda: CurrentUser( - id=uuid4(), email="user@example.com" + id=uuid4(), phone="+8613812345678" ) client = TestClient(app) diff --git a/backend/tests/integration/v1/agent/test_sse_flow_live.py b/backend/tests/integration/v1/agent/test_sse_flow_live.py index 471cf86..ca03785 100644 --- a/backend/tests/integration/v1/agent/test_sse_flow_live.py +++ b/backend/tests/integration/v1/agent/test_sse_flow_live.py @@ -20,16 +20,16 @@ FIXTURE_IMAGE_PATH = ( async def _live_access_token(client: httpx.AsyncClient) -> str: - email = os.getenv("AGENT_LIVE_EMAIL") + phone = os.getenv("AGENT_LIVE_PHONE") password = os.getenv("AGENT_LIVE_PASSWORD") - if not email or not password: + if not phone or not password: pytest.fail( - "AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_EMAIL and AGENT_LIVE_PASSWORD" + "AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_PHONE and AGENT_LIVE_PASSWORD" ) response = await client.post( f"{BASE_URL}/api/v1/auth/sessions", - json={"email": email, "password": password}, + json={"phone": phone, "password": password}, ) response_text = response.text.strip().replace("\n", " ") truncated_text = response_text[:200] @@ -67,7 +67,7 @@ async def test_agent_sse_closed_loop_live() -> None: ], "tools": [], "context": [], - "forwardedProps": {}, + "forwardedProps": {"agent_type": "worker"}, }, ) assert run_resp.status_code == 202 @@ -143,7 +143,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None: ], "tools": [], "context": [], - "forwardedProps": {}, + "forwardedProps": {"agent_type": "worker"}, }, ) assert run_resp.status_code == 202 diff --git a/backend/tests/unit/v1/agent/test_owner_guard.py b/backend/tests/unit/v1/agent/test_owner_guard.py index 4c5af2b..b1d994c 100644 --- a/backend/tests/unit/v1/agent/test_owner_guard.py +++ b/backend/tests/unit/v1/agent/test_owner_guard.py @@ -10,7 +10,7 @@ from v1.agent.service import ensure_session_owner def test_owner_guard_denies_non_owner() -> None: - user = CurrentUser(id=uuid4(), email="self@example.com") + user = CurrentUser(id=uuid4(), phone="self@example.com") with pytest.raises(HTTPException): ensure_session_owner(owner_id="other-user", current_user=user) diff --git a/backend/tests/unit/v1/agent/test_repository.py b/backend/tests/unit/v1/agent/test_repository.py index c79ba08..9fccf95 100644 --- a/backend/tests/unit/v1/agent/test_repository.py +++ b/backend/tests/unit/v1/agent/test_repository.py @@ -7,6 +7,8 @@ from uuid import uuid4 import pytest from models.agent_chat_message import AgentChatMessageRole +from sqlalchemy import select +from models.agent_chat_message import AgentChatMessage from v1.agent.repository import AgentRepository @@ -79,6 +81,7 @@ async def test_persist_user_message_sets_session_title_when_empty() -> None: session_id=session_id, content=" 请帮我安排明天下午开会 ", metadata=None, + visibility_mask=1, ) assert session_row.title == "请帮我安排明天下午开会" @@ -101,6 +104,7 @@ async def test_persist_user_message_keeps_existing_session_title() -> None: session_id=session_id, content="新的消息内容", metadata=None, + visibility_mask=1, ) assert session_row.title == "已有标题" @@ -164,3 +168,13 @@ async def test_get_history_day_uses_target_day_queries_only() -> None: messages = payload["messages"] assert isinstance(messages, list) assert len(messages) == 1 + + +def test_apply_visibility_filter_adds_bitwise_expression() -> None: + repository = AgentRepository(session=SimpleNamespace()) # type: ignore[arg-type] + stmt = select(AgentChatMessage) + + filtered = repository._apply_visibility_filter(stmt=stmt, visibility_mask=1) + + assert "visibility_mask" in str(filtered) + assert "&" in str(filtered) diff --git a/backend/tests/unit/v1/agent/test_service.py b/backend/tests/unit/v1/agent/test_service.py index 3d6f266..9565322 100644 --- a/backend/tests/unit/v1/agent/test_service.py +++ b/backend/tests/unit/v1/agent/test_service.py @@ -20,6 +20,7 @@ class _FakeRepository: def __init__(self) -> None: self.committed = False self.persisted_user_messages: list[dict[str, object]] = [] + self.created_session_calls = 0 async def get_session_owner(self, *, session_id: str) -> str: if session_id == "00000000-0000-0000-0000-000000000001": @@ -30,6 +31,7 @@ class _FakeRepository: self, *, user_id: str, session_id: str | None = None ) -> str: del user_id + self.created_session_calls += 1 return session_id or "00000000-0000-0000-0000-000000000999" async def commit(self) -> None: @@ -39,9 +41,13 @@ class _FakeRepository: return None async def get_history_day( - self, *, session_id: str, before: date | None + self, + *, + session_id: str, + before: date | None, + visibility_mask: int | None = None, ) -> dict[str, object] | None: - del session_id, before + del session_id, before, visibility_mask return None async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: @@ -54,15 +60,42 @@ class _FakeRepository: session_id: str, content: str, metadata: AgentChatMessageMetadata | None, + visibility_mask: int, ) -> None: self.persisted_user_messages.append( { "session_id": session_id, "content": content, "metadata": metadata, + "visibility_mask": visibility_mask, } ) + async def get_system_agent_config( + self, *, agent_type: str + ) -> dict[str, object] | None: + normalized = agent_type.strip().lower() + mapping = { + "router": 16, + "worker": 17, + "memory": 18, + } + bit = mapping.get(normalized) + if bit is None: + return None + return { + "agent_type": normalized, + "status": "active", + "config": { + "temperature": 0.7, + "max_tokens": None, + "timeout_seconds": 30, + "visibility_consumer_bit": bit, + "context_messages": {"mode": "number", "count": 20}, + "enabled_tools": [], + }, + } + class _FakeQueue: def __init__(self) -> None: @@ -122,11 +155,11 @@ class _FakeAttachmentStorage: def _user() -> CurrentUser: return CurrentUser( id=UUID("00000000-0000-0000-0000-000000000001"), - email="user@example.com", + phone="+8613812345678", ) -def _build_run_input(*, urls: list[str]) -> RunAgentInput: +def _build_run_input(*, urls: list[str], agent_type: str = "worker") -> RunAgentInput: content: list[dict[str, str]] = [{"type": "text", "text": "hello"}] for url in urls: content.append({"type": "binary", "mimeType": "image/png", "url": url}) @@ -144,7 +177,7 @@ def _build_run_input(*, urls: list[str]) -> RunAgentInput: ], "tools": [], "context": [], - "forwardedProps": {}, + "forwardedProps": {"agent_type": agent_type}, } ) @@ -222,6 +255,68 @@ async def test_enqueue_run_persists_attachment_and_queue_without_user_token( assert run_input["runId"] == "run-1" +@pytest.mark.asyncio +async def test_enqueue_run_rejects_unknown_agent_type(monkeypatch) -> None: + monkeypatch.setattr( + agent_service_module.config.storage, "bucket", "agent-test-bucket" + ) + service = AgentService( + repository=_FakeRepository(), + queue=_FakeQueue(), + stream=_FakeStream(), + attachment_storage=_FakeAttachmentStorage(), + ) + base_url = str(config.supabase.url).rstrip("/") + safe_path = quote( + "agent-inputs/00000000-0000-0000-0000-000000000001/" + "00000000-0000-0000-0000-000000000001/uploads/a.png" + ) + run_input = _build_run_input( + urls=[ + f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1" + ], + agent_type="planner", + ) + + with pytest.raises(HTTPException) as exc_info: + await service.enqueue_run(run_input=run_input, current_user=_user()) + + assert exc_info.value.status_code == 422 + + +@pytest.mark.asyncio +async def test_enqueue_run_rejects_memory_mode_for_api(monkeypatch) -> None: + monkeypatch.setattr( + agent_service_module.config.storage, "bucket", "agent-test-bucket" + ) + repository = _FakeRepository() + service = AgentService( + repository=repository, + queue=_FakeQueue(), + stream=_FakeStream(), + attachment_storage=_FakeAttachmentStorage(), + ) + base_url = str(config.supabase.url).rstrip("/") + safe_path = quote( + "agent-inputs/00000000-0000-0000-0000-000000000001/" + "00000000-0000-0000-0000-000000000001/uploads/a.png" + ) + run_input = _build_run_input( + urls=[ + f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1" + ], + agent_type="memory", + ) + + with pytest.raises(HTTPException) as exc_info: + await service.enqueue_run(run_input=run_input, current_user=_user()) + + assert exc_info.value.status_code == 422 + assert exc_info.value.detail == "memory mode is automation-only" + assert repository.created_session_calls == 0 + assert repository.persisted_user_messages == [] + + @pytest.mark.asyncio async def test_create_attachment_signed_url_returns_url(monkeypatch) -> None: monkeypatch.setattr( @@ -317,9 +412,13 @@ async def test_enqueue_run_rejects_too_many_attachments(monkeypatch) -> None: async def test_get_history_snapshot_filters_out_tool_messages() -> None: class _HistoryRepository(_FakeRepository): async def get_history_day( - self, *, session_id: str, before: date | None + self, + *, + session_id: str, + before: date | None, + visibility_mask: int | None = None, ) -> dict[str, object] | None: - del session_id, before + del session_id, before, visibility_mask return { "day": "2026-03-17", "hasMore": False, diff --git a/backend/tests/unit/v1/auth/test_auth_gateway.py b/backend/tests/unit/v1/auth/test_auth_gateway.py index 8c1e356..8721baf 100644 --- a/backend/tests/unit/v1/auth/test_auth_gateway.py +++ b/backend/tests/unit/v1/auth/test_auth_gateway.py @@ -8,13 +8,9 @@ from fastapi import HTTPException from v1.auth.gateway import SupabaseAuthGateway from v1.auth.schemas import ( - PasswordResetConfirmRequest, - PasswordResetRequest, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, SessionRefreshRequest, - VerificationCreateRequest, - VerificationVerifyRequest, - VerificationResendRequest, ) @@ -35,314 +31,83 @@ class TestSupabaseAuthGateway: return SupabaseAuthGateway(), mock_client, mock_admin_client @pytest.mark.asyncio - async def test_request_password_reset_calls_email_with_string( + async def test_send_otp_sets_should_create_user( self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: sut, mock_client, _ = gateway - mock_reset_email = MagicMock() - mock_client.auth.reset_password_email = mock_reset_email + mock_sign_in_with_otp = MagicMock() + mock_client.auth.sign_in_with_otp = mock_sign_in_with_otp - request = PasswordResetRequest(email="test@example.com") - await sut.request_password_reset(request) + await sut.send_otp(OtpSendRequest(phone="+8613812345678")) - mock_reset_email.assert_called_once_with("test@example.com") - - @pytest.mark.asyncio - async def test_create_verification_maps_timeout_error_to_503( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - from supabase import AuthError - - mock_client.auth.sign_up = MagicMock( - side_effect=AuthError("request_timeout", None) - ) - - with pytest.raises(HTTPException) as exc_info: - await sut.create_verification( - VerificationCreateRequest( - username="tester", - email="test@example.com", - password="secret123", - ) - ) - - assert exc_info.value.status_code == 503 - assert exc_info.value.detail == "Auth service temporarily unavailable" - - @pytest.mark.asyncio - async def test_request_password_reset_with_redirect( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - mock_reset_email = MagicMock() - mock_client.auth.reset_password_email = mock_reset_email - - request = PasswordResetRequest( - email="test@example.com", - redirect_to="http://localhost:3000/reset-password", - ) - await sut.request_password_reset(request) - - mock_reset_email.assert_called_once_with( - "test@example.com", - options={"redirect_to": "http://localhost:3000/reset-password"}, - ) - - @pytest.mark.asyncio - async def test_create_verification_rejects_untrusted_redirect_url( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, _, _ = gateway - - with pytest.raises(HTTPException) as exc_info: - await sut.create_verification( - VerificationCreateRequest( - username="tester", - email="test@example.com", - password="secret123", - redirect_to="https://evil.example.com/callback", - ) - ) - - assert exc_info.value.status_code == 422 - assert exc_info.value.detail == "Invalid redirect URL" - - @pytest.mark.asyncio - async def test_request_password_reset_rejects_untrusted_redirect_url( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, _, _ = gateway - - with pytest.raises(HTTPException) as exc_info: - await sut.request_password_reset( - PasswordResetRequest( - email="test@example.com", - redirect_to="https://evil.example.com/reset", - ) - ) - - assert exc_info.value.status_code == 422 - assert exc_info.value.detail == "Invalid redirect URL" - - @pytest.mark.asyncio - async def test_request_password_reset_swallows_auth_error( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - from supabase import AuthError - - mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None)) - mock_client.auth.reset_password_email = mock_reset_email - - request = PasswordResetRequest(email="test@example.com") - - result = await sut.request_password_reset(request) - - mock_reset_email.assert_called_once() - assert result is None - - @pytest.mark.asyncio - async def test_request_password_reset_extracts_email_from_mapping( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - mock_reset_email = MagicMock() - mock_client.auth.reset_password_email = mock_reset_email - - request = PasswordResetRequest.model_construct( - email={"email": "test@example.com"}, - redirect_to=None, - ) - - await sut.request_password_reset(request) - - mock_reset_email.assert_called_once_with("test@example.com") - - @pytest.mark.asyncio - async def test_request_password_reset_rejects_invalid_email_shape( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, _, _ = gateway - request = PasswordResetRequest.model_construct( - email={"unexpected": "value"}, - redirect_to=None, - ) - - with pytest.raises(HTTPException) as exc_info: - await sut.request_password_reset(request) - - assert exc_info.value.status_code == 422 - assert exc_info.value.detail == "Invalid email" - - @pytest.mark.asyncio - async def test_confirm_password_reset_updates_password_by_user_id( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, mock_admin_client = gateway - verify_response = SimpleNamespace( - session=SimpleNamespace(access_token="access"), - user=SimpleNamespace(id="user-1"), - ) - mock_verify_otp = MagicMock(return_value=verify_response) - mock_client.auth.verify_otp = mock_verify_otp - - mock_update_user_by_id = MagicMock() - mock_admin_client.auth.admin = SimpleNamespace( - update_user_by_id=mock_update_user_by_id - ) - - request = PasswordResetConfirmRequest( - email="test@example.com", - token="123456", - new_password="newpassword123", - ) - - await sut.confirm_password_reset(request) - - mock_verify_otp.assert_called_once_with( + mock_sign_in_with_otp.assert_called_once_with( { - "type": "recovery", - "email": "test@example.com", - "token": "123456", + "phone": "+8613812345678", + "options": {"should_create_user": True}, } ) - mock_update_user_by_id.assert_called_once_with( - "user-1", - {"password": "newpassword123"}, - ) @pytest.mark.asyncio - async def test_confirm_password_reset_raises_when_user_id_missing( + async def test_create_phone_session_uses_verify_otp( self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: sut, mock_client, _ = gateway verify_response = SimpleNamespace( - session=SimpleNamespace(access_token="access"), - user=SimpleNamespace(id=""), + session=SimpleNamespace( + access_token="access", + refresh_token="refresh", + expires_in=3600, + token_type="bearer", + ), + user=SimpleNamespace(id="user-1", phone="+8613812345678"), ) mock_client.auth.verify_otp = MagicMock(return_value=verify_response) - request = PasswordResetConfirmRequest( - email="test@example.com", - token="123456", - new_password="newpassword123", + response = await sut.create_phone_session( + PhoneSessionCreateRequest(phone="+8613812345678", token="123456") + ) + + assert response.user.id == "user-1" + assert response.access_token == "access" + + @pytest.mark.asyncio + async def test_create_phone_session_normalizes_phone_without_plus_prefix( + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] + ) -> None: + sut, mock_client, _ = gateway + verify_response = SimpleNamespace( + session=SimpleNamespace( + access_token="access", + refresh_token="refresh", + expires_in=3600, + token_type="bearer", + ), + user=SimpleNamespace(id="user-1", phone="14155552671"), + ) + mock_client.auth.verify_otp = MagicMock(return_value=verify_response) + + response = await sut.create_phone_session( + PhoneSessionCreateRequest(phone="+14155552671", token="123456") + ) + + assert response.user.phone == "+14155552671" + + @pytest.mark.asyncio + async def test_refresh_session_maps_invalid_token( + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] + ) -> None: + sut, mock_client, _ = gateway + mock_client.auth.refresh_session = MagicMock( + return_value=SimpleNamespace(session=None, user=None) ) with pytest.raises(HTTPException) as exc_info: - await sut.confirm_password_reset(request) + await sut.refresh_session(SessionRefreshRequest(refresh_token="bad")) assert exc_info.value.status_code == 401 - assert exc_info.value.detail == "Invalid or expired verification code" @pytest.mark.asyncio - async def test_recovery_resend_calls_reset_password_email( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - mock_reset_email = MagicMock() - mock_client.auth.reset_password_email = mock_reset_email - - await sut.resend_verification( - VerificationResendRequest( - type="recovery", - email="test@example.com", - redirect_to="http://localhost:3000/reset-password", - ) - ) - - mock_reset_email.assert_called_once_with( - "test@example.com", - options={"redirect_to": "http://localhost:3000/reset-password"}, - ) - - @pytest.mark.asyncio - async def test_verify_verification_maps_internal_error_to_503( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - from supabase import AuthError - - mock_client.auth.verify_otp = MagicMock( - side_effect=AuthError("internal_server_error", None) - ) - - with pytest.raises(HTTPException) as exc_info: - await sut.verify_verification( - VerificationVerifyRequest( - type="signup", - email="test@example.com", - token="123456", - ) - ) - - assert exc_info.value.status_code == 503 - assert exc_info.value.detail == "Auth service temporarily unavailable" - - @pytest.mark.asyncio - async def test_create_session_maps_internal_error_to_503( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - from supabase import AuthError - - mock_client.auth.sign_in_with_password = MagicMock( - side_effect=AuthError("internal_server_error", None) - ) - - with pytest.raises(HTTPException) as exc_info: - await sut.create_session( - SessionCreateRequest( - email="test@example.com", - password="secret123", - ) - ) - - assert exc_info.value.status_code == 503 - assert exc_info.value.detail == "Auth service temporarily unavailable" - - @pytest.mark.asyncio - async def test_refresh_session_maps_bad_gateway_to_503( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - from supabase import AuthError - - mock_client.auth.refresh_session = MagicMock( - side_effect=AuthError("bad_gateway", None) - ) - - with pytest.raises(HTTPException) as exc_info: - await sut.refresh_session(SessionRefreshRequest(refresh_token="rt")) - - assert exc_info.value.status_code == 503 - assert exc_info.value.detail == "Auth service temporarily unavailable" - - @pytest.mark.asyncio - async def test_confirm_password_reset_maps_service_unavailable_to_503( - self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] - ) -> None: - sut, mock_client, _ = gateway - from supabase import AuthError - - mock_client.auth.verify_otp = MagicMock( - side_effect=AuthError("service_unavailable", None) - ) - - with pytest.raises(HTTPException) as exc_info: - await sut.confirm_password_reset( - PasswordResetConfirmRequest( - email="test@example.com", - token="123456", - new_password="newpassword123", - ) - ) - - assert exc_info.value.status_code == 503 - assert exc_info.value.detail == "Auth service temporarily unavailable" - - @pytest.mark.asyncio - async def test_get_user_by_email_uses_in_memory_cache( + async def test_get_user_by_phone_uses_in_memory_cache( self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock], monkeypatch: pytest.MonkeyPatch, @@ -350,9 +115,9 @@ class TestSupabaseAuthGateway: sut, _, _ = gateway user = SimpleNamespace( id="user-1", - email="cached@example.com", + phone="+8613811112222", created_at="2026-03-16T00:00:00Z", - email_confirmed_at=None, + phone_confirmed_at=None, ) list_calls = {"count": 0} @@ -362,9 +127,39 @@ class TestSupabaseAuthGateway: monkeypatch.setattr("v1.auth.gateway._list_auth_users", _fake_list_auth_users) - first = await sut.get_user_by_email("cached@example.com") - second = await sut.get_user_by_email("CACHED@example.com") + first = await sut.get_user_by_phone("+8613811112222") + second = await sut.get_user_by_phone("+8613811112222") assert first.id == "user-1" - assert second.email == "cached@example.com" + assert second.phone == "+8613811112222" assert list_calls["count"] == 1 + + @pytest.mark.asyncio + async def test_search_user_ids_by_phone_supports_suffix_query( + self, + gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock], + monkeypatch: pytest.MonkeyPatch, + ) -> None: + sut, _, _ = gateway + users = [ + SimpleNamespace( + id="user-cn", + phone="+8613811112222", + created_at="2026-03-16T00:00:00Z", + phone_confirmed_at=None, + ), + SimpleNamespace( + id="user-us", + phone="+14155552671", + created_at="2026-03-16T00:00:00Z", + phone_confirmed_at=None, + ), + ] + + monkeypatch.setattr("v1.auth.gateway._list_auth_users", lambda _client: users) + + matched_cn = await sut.search_user_ids_by_phone("13811112222") + matched_us = await sut.search_user_ids_by_phone("4155552671") + + assert matched_cn == ["user-cn"] + assert matched_us == ["user-us"] diff --git a/backend/tests/unit/v1/auth/test_auth_models.py b/backend/tests/unit/v1/auth/test_auth_models.py index a290329..81cab02 100644 --- a/backend/tests/unit/v1/auth/test_auth_models.py +++ b/backend/tests/unit/v1/auth/test_auth_models.py @@ -5,72 +5,28 @@ from pydantic import ValidationError from v1.auth.schemas import ( AuthUser, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, + SessionDeleteRequest, SessionRefreshRequest, SessionResponse, - VerificationCreateRequest, - VerificationVerifyRequest, - VerificationResendRequest, ) -def test_signup_requires_valid_email() -> None: +def test_send_otp_requires_valid_phone() -> None: with pytest.raises(ValidationError): - VerificationCreateRequest( - username="demo", email="not-an-email", password="secret123" - ) + OtpSendRequest(phone="13812345678") -def test_signup_requires_username() -> None: +def test_send_otp_accepts_e164_phone() -> None: + request = OtpSendRequest(phone="+14155552671") + + assert request.phone == "+14155552671" + + +def test_phone_session_requires_six_digit_token() -> None: with pytest.raises(ValidationError): - VerificationCreateRequest.model_validate( - {"email": "user@example.com", "password": "secret123"} - ) - - -def test_signup_allows_any_invite_code_input() -> None: - request = VerificationCreateRequest( - username="demo", - email="user@example.com", - password="secret123", - invite_code="abc123", - ) - - assert request.invite_code == "abc123" - - -def test_signup_verify_requires_six_digit_token() -> None: - with pytest.raises(ValidationError): - VerificationVerifyRequest(email="user@example.com", token="abc123") - - -def test_signup_verify_disallows_new_password() -> None: - with pytest.raises(ValidationError): - VerificationVerifyRequest( - type="signup", - email="user@example.com", - token="123456", - new_password="secret123", - ) - - -def test_recovery_verify_requires_new_password() -> None: - with pytest.raises(ValidationError): - VerificationVerifyRequest( - type="recovery", - email="user@example.com", - token="123456", - ) - - -def test_signup_resend_requires_valid_email() -> None: - with pytest.raises(ValidationError): - VerificationResendRequest(email="invalid") - - -def test_login_requires_valid_email() -> None: - with pytest.raises(ValidationError): - SessionCreateRequest(email="invalid", password="secret123") + PhoneSessionCreateRequest(phone="+8613812345678", token="abc123") def test_refresh_requires_token() -> None: @@ -78,8 +34,13 @@ def test_refresh_requires_token() -> None: SessionRefreshRequest(refresh_token="") +def test_logout_requires_token() -> None: + with pytest.raises(ValidationError): + SessionDeleteRequest(refresh_token="") + + def test_session_response_maps_user() -> None: - user = AuthUser(id="user-1", email="user@example.com") + user = AuthUser(id="user-1", phone="+14155552671") response = SessionResponse( access_token="access", refresh_token="refresh", @@ -89,4 +50,4 @@ def test_session_response_maps_user() -> None: ) assert response.user.id == "user-1" - assert response.user.email == "user@example.com" + assert response.user.phone == "+14155552671" diff --git a/backend/tests/unit/v1/auth/test_auth_service.py b/backend/tests/unit/v1/auth/test_auth_service.py index 85976da..7daeae2 100644 --- a/backend/tests/unit/v1/auth/test_auth_service.py +++ b/backend/tests/unit/v1/auth/test_auth_service.py @@ -2,19 +2,12 @@ from __future__ import annotations import pytest -import v1.auth.gateway as auth_gateway_module from v1.auth.schemas import ( AuthUser, - PasswordResetConfirmRequest, - PasswordResetRequest, - SessionCreateRequest, + OtpSendRequest, + PhoneSessionCreateRequest, SessionRefreshRequest, SessionResponse, - UserByEmailResponse, - VerificationCreateRequest, - VerificationCreateResponse, - VerificationResendRequest, - VerificationVerifyRequest, ) from v1.auth.service import AuthService, AuthServiceGateway @@ -22,23 +15,16 @@ from v1.auth.service import AuthService, AuthServiceGateway class FakeGateway(AuthServiceGateway): def __init__(self, response: SessionResponse) -> None: self._response = response - self.last_create_verification_request: VerificationCreateRequest | None = None + self.last_send_otp_request: OtpSendRequest | None = None + self.last_phone_session_request: PhoneSessionCreateRequest | None = None - async def create_verification( - self, request: VerificationCreateRequest - ) -> VerificationCreateResponse: - self.last_create_verification_request = request - return VerificationCreateResponse(email=request.email) + async def send_otp(self, request: OtpSendRequest) -> None: + self.last_send_otp_request = request - async def verify_verification( - self, request: VerificationVerifyRequest + async def create_phone_session( + self, request: PhoneSessionCreateRequest ) -> SessionResponse: - return self._response - - async def resend_verification(self, request: VerificationResendRequest) -> None: - return None - - async def create_session(self, request: SessionCreateRequest) -> SessionResponse: + self.last_phone_session_request = request return self._response async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: @@ -47,85 +33,10 @@ class FakeGateway(AuthServiceGateway): async def delete_session(self, refresh_token: str | None) -> None: return None - async def get_user_by_email(self, email: str) -> UserByEmailResponse: - raise NotImplementedError - - async def request_password_reset(self, request: PasswordResetRequest) -> None: - raise NotImplementedError - - async def confirm_password_reset( - self, request: PasswordResetConfirmRequest - ) -> None: - raise NotImplementedError - - -class LogoutAssertingGateway(AuthServiceGateway): - def __init__(self, expected_refresh_token: str) -> None: - self._expected_refresh_token = expected_refresh_token - - async def create_verification( - self, request: VerificationCreateRequest - ) -> VerificationCreateResponse: - raise NotImplementedError - - async def verify_verification( - self, request: VerificationVerifyRequest - ) -> SessionResponse: - raise NotImplementedError - - async def resend_verification(self, request: VerificationResendRequest) -> None: - raise NotImplementedError - - async def create_session(self, request: SessionCreateRequest) -> SessionResponse: - raise NotImplementedError - - async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: - raise NotImplementedError - - async def delete_session(self, refresh_token: str | None) -> None: - assert refresh_token == self._expected_refresh_token - - async def get_user_by_email(self, email: str) -> UserByEmailResponse: - raise NotImplementedError - - async def request_password_reset(self, request: PasswordResetRequest) -> None: - raise NotImplementedError - - async def confirm_password_reset( - self, request: PasswordResetConfirmRequest - ) -> None: - raise NotImplementedError - @pytest.mark.asyncio -async def test_logout_forwards_refresh_token() -> None: - service = AuthService(gateway=LogoutAssertingGateway("refresh-token")) - - await service.delete_session("refresh-token") - - -@pytest.mark.asyncio -async def test_signup_resend_returns_none() -> None: - user = AuthUser(id="user-1", email="user@example.com") - token_response = SessionResponse( - access_token="access", - refresh_token="refresh", - expires_in=3600, - token_type="bearer", - user=user, - ) - service = AuthService(gateway=FakeGateway(token_response)) - - result = await service.resend_verification( - VerificationResendRequest(email="user@example.com") - ) - - assert result is None - - -@pytest.mark.asyncio -async def test_create_verification_ignores_invalid_invite_code() -> None: - user = AuthUser(id="user-1", email="user@example.com") +async def test_send_otp_forwards_payload() -> None: + user = AuthUser(id="user-1", phone="+8613812345678") token_response = SessionResponse( access_token="access", refresh_token="refresh", @@ -136,22 +47,15 @@ async def test_create_verification_ignores_invalid_invite_code() -> None: gateway = FakeGateway(token_response) service = AuthService(gateway=gateway) - await service.create_verification( - VerificationCreateRequest( - username="demo", - email="user@example.com", - password="secret123", - invite_code="bad-code", - ) - ) + await service.send_otp(OtpSendRequest(phone="+8613812345678")) - assert gateway.last_create_verification_request is not None - assert gateway.last_create_verification_request.invite_code is None + assert gateway.last_send_otp_request is not None + assert gateway.last_send_otp_request.phone == "+8613812345678" @pytest.mark.asyncio -async def test_create_verification_normalizes_valid_invite_code() -> None: - user = AuthUser(id="user-1", email="user@example.com") +async def test_create_phone_session_forwards_payload() -> None: + user = AuthUser(id="user-1", phone="+8613812345678") token_response = SessionResponse( access_token="access", refresh_token="refresh", @@ -162,59 +66,28 @@ async def test_create_verification_normalizes_valid_invite_code() -> None: gateway = FakeGateway(token_response) service = AuthService(gateway=gateway) - await service.create_verification( - VerificationCreateRequest( - username="demo", - email="user@example.com", - password="secret123", - invite_code="a2b3", - ) + response = await service.create_phone_session( + PhoneSessionCreateRequest(phone="+8613812345678", token="123456") ) - assert gateway.last_create_verification_request is not None - assert gateway.last_create_verification_request.invite_code == "A2B3" + assert gateway.last_phone_session_request is not None + assert gateway.last_phone_session_request.token == "123456" + assert response.user.phone == "+8613812345678" @pytest.mark.asyncio -async def test_supabase_signup_passes_username_in_metadata( - monkeypatch: pytest.MonkeyPatch, -) -> None: - captured_payload: dict[str, object] = {} - - class FakeSupabaseAuth: - def sign_up(self, payload: dict[str, object]) -> object: - captured_payload.update(payload) - - class _User: - id = "user-1" - email = "user@example.com" - - class _Session: - access_token = "access" - refresh_token = "refresh" - expires_in = 3600 - token_type = "bearer" - - class _Response: - user = _User() - session = None - - return _Response() - - class FakeClient: - auth = FakeSupabaseAuth() - - monkeypatch.setattr( - auth_gateway_module.supabase_service, "get_client", lambda: FakeClient() +async def test_refresh_session_forwards_payload() -> None: + user = AuthUser(id="user-1", phone="+8613812345678") + token_response = SessionResponse( + access_token="access", + refresh_token="refresh", + expires_in=3600, + token_type="bearer", + user=user, ) + gateway = FakeGateway(token_response) + service = AuthService(gateway=gateway) - gateway = auth_gateway_module.SupabaseAuthGateway() - await gateway.create_verification( - VerificationCreateRequest( - username="demo", - email="user@example.com", - password="secret123", - ) - ) + response = await service.refresh_session(SessionRefreshRequest(refresh_token="rt")) - assert captured_payload["data"] == {"username": "demo"} + assert response.access_token == "access" diff --git a/backend/tests/unit/v1/friendships/test_schemas.py b/backend/tests/unit/v1/friendships/test_schemas.py index 940f987..df1b0a2 100644 --- a/backend/tests/unit/v1/friendships/test_schemas.py +++ b/backend/tests/unit/v1/friendships/test_schemas.py @@ -1,11 +1,13 @@ from __future__ import annotations -import pytest from datetime import datetime from uuid import uuid4 +import pytest +from pydantic import ValidationError +from schemas.user.context import UserContext + from v1.friendships.schemas import ( - UserBasicInfo, FriendRequestCreate, FriendRequestResponse, FriendResponse, @@ -13,16 +15,16 @@ from v1.friendships.schemas import ( ) -def test_user_basic_info_maps_fields() -> None: - user = UserBasicInfo(id="user-1", username="alice", avatar_url=None) +def test_user_context_maps_fields() -> None: + user = UserContext(id="user-1", username="alice", avatar_url=None) assert user.id == "user-1" assert user.username == "alice" assert user.avatar_url is None -def test_user_basic_info_with_avatar() -> None: - user = UserBasicInfo( +def test_user_context_with_avatar() -> None: + user = UserContext( id="user-2", username="bob", avatar_url="https://example.com/avatar.png" ) @@ -49,13 +51,13 @@ def test_friend_request_create_without_content() -> None: def test_friend_request_create_content_max_length() -> None: target_id = uuid4() - with pytest.raises(Exception): + with pytest.raises(ValidationError): FriendRequestCreate(target_user_id=target_id, content="x" * 201) def test_friend_request_response_maps_fields() -> None: - sender = UserBasicInfo(id="user-1", username="alice", avatar_url=None) - recipient = UserBasicInfo(id="user-2", username="bob", avatar_url=None) + sender = UserContext(id="user-1", username="alice", avatar_url=None) + recipient = UserContext(id="user-2", username="bob", avatar_url=None) request_id = uuid4() created = datetime(2026, 1, 15, 10, 30, 0) @@ -63,7 +65,7 @@ def test_friend_request_response_maps_fields() -> None: id=request_id, sender=sender, recipient=recipient, - content="Hello!", + content={"text": "Hello!"}, status="pending", created_at=created, ) @@ -76,7 +78,7 @@ def test_friend_request_response_maps_fields() -> None: def test_friend_response_maps_fields() -> None: - friend_user = UserBasicInfo(id="user-2", username="bob", avatar_url=None) + friend_user = UserContext(id="user-2", username="bob", avatar_url=None) request_id = uuid4() created = datetime(2026, 1, 15, 10, 30, 0) accepted = datetime(2026, 1, 16, 12, 0, 0) @@ -96,7 +98,7 @@ def test_friend_response_maps_fields() -> None: def test_friend_response_accepted_at_optional() -> None: - friend_user = UserBasicInfo(id="user-2", username="bob", avatar_url=None) + friend_user = UserContext(id="user-2", username="bob", avatar_url=None) request_id = uuid4() created = datetime(2026, 1, 15, 10, 30, 0) diff --git a/backend/tests/unit/v1/schedule_items/test_share.py b/backend/tests/unit/v1/schedule_items/test_share.py index d6065a4..c2dd026 100644 --- a/backend/tests/unit/v1/schedule_items/test_share.py +++ b/backend/tests/unit/v1/schedule_items/test_share.py @@ -12,7 +12,7 @@ from sqlalchemy.exc import SQLAlchemyError from core.auth.models import CurrentUser from models.inbox_messages import InboxMessage, InboxMessageType from models.schedule_items import ScheduleItem -from v1.auth.schemas import UserByEmailResponse +from v1.auth.schemas import UserByPhoneResponse from v1.schedule_items.repository import ScheduleItemRepository from v1.schedule_items.schemas import ScheduleItemShareRequest from v1.schedule_items.service import ScheduleItemService @@ -20,18 +20,18 @@ from v1.schedule_items.service import ScheduleItemService def test_share_request_schema() -> None: request = ScheduleItemShareRequest( - email="friend@example.com", + phone="+8613810000000", permission_view=True, permission_edit=True, permission_invite=False, ) - assert request.email == "friend@example.com" + assert request.phone == "+8613810000000" assert request.permission_view is True def test_permission_bits_calculation() -> None: request = ScheduleItemShareRequest( - email="friend@example.com", + phone="+8613810000000", permission_view=True, permission_edit=True, permission_invite=False, @@ -71,12 +71,12 @@ class ShareRepo: class AuthGatewayStub: - async def get_user_by_email(self, email: str) -> UserByEmailResponse: - return UserByEmailResponse( + async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse: + return UserByPhoneResponse( id="00000000-0000-0000-0000-000000000222", - email=email, + phone=phone, created_at="2026-02-28T10:00:00Z", - email_confirmed_at=None, + phone_confirmed_at=None, ) @@ -119,12 +119,12 @@ class InboxRepoStub: class AuthGatewayInvalidIdStub: - async def get_user_by_email(self, email: str) -> UserByEmailResponse: - return UserByEmailResponse( + async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse: + return UserByPhoneResponse( id="not-a-uuid", - email=email, + phone=phone, created_at="2026-02-28T10:00:00Z", - email_confirmed_at=None, + phone_confirmed_at=None, ) @@ -148,7 +148,7 @@ async def test_share_forbidden_when_not_owner() -> None: await service.share( item_id, ScheduleItemShareRequest( - email="friend@example.com", + phone="+8613810000000", permission_view=True, permission_edit=False, permission_invite=False, @@ -178,7 +178,7 @@ async def test_share_success_creates_calendar_invitation_message() -> None: result = await service.share( item_id, ScheduleItemShareRequest( - email="friend@example.com", + phone="+8613810000000", permission_view=True, permission_edit=True, permission_invite=False, @@ -211,7 +211,7 @@ async def test_share_returns_not_found_when_item_missing() -> None: await service.share( uuid4(), ScheduleItemShareRequest( - email="friend@example.com", + phone="+8613810000000", permission_view=True, permission_edit=False, permission_invite=False, @@ -241,7 +241,7 @@ async def test_share_invalid_auth_user_id_returns_503() -> None: await service.share( item_id, ScheduleItemShareRequest( - email="friend@example.com", + phone="+8613810000000", permission_view=True, permission_edit=False, permission_invite=False, @@ -274,7 +274,7 @@ async def test_share_sqlalchemy_error_rolls_back() -> None: await service.share( item_id, ScheduleItemShareRequest( - email="friend@example.com", + phone="+8613810000000", permission_view=True, permission_edit=False, permission_invite=False, diff --git a/backend/tests/unit/v1/users/test_dependencies.py b/backend/tests/unit/v1/users/test_dependencies.py index d2328e7..9a35012 100644 --- a/backend/tests/unit/v1/users/test_dependencies.py +++ b/backend/tests/unit/v1/users/test_dependencies.py @@ -22,7 +22,7 @@ async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) - del token return deps.CurrentUser( id=UUID("e8845a17-282b-4a63-8025-194a06235958"), - email="dagronl@126.com", + phone="dagronl@126.com", role="authenticated", ) @@ -31,7 +31,7 @@ async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) - user = await deps.get_current_user(authorization="Bearer valid-token") assert str(user.id) == "e8845a17-282b-4a63-8025-194a06235958" - assert user.email == "dagronl@126.com" + assert user.phone == "dagronl@126.com" @pytest.mark.asyncio diff --git a/backend/tests/unit/v1/users/test_user_service.py b/backend/tests/unit/v1/users/test_user_service.py index 989e1fd..c616c79 100644 --- a/backend/tests/unit/v1/users/test_user_service.py +++ b/backend/tests/unit/v1/users/test_user_service.py @@ -6,7 +6,7 @@ from uuid import uuid4 import pytest from core.auth.models import CurrentUser -from v1.users.schemas import UserUpdateRequest +from v1.users.schemas import UserSearchRequest, UserUpdateRequest from v1.users.service import UserService @@ -16,6 +16,7 @@ class _FakeProfile: username: str avatar_url: str | None bio: str | None + settings: dict | None = None class _FakeRepository: @@ -51,6 +52,37 @@ class _FakeSession: self.rollback_called += 1 +class _FakeSearchRepository: + def __init__(self, profiles: list[_FakeProfile]) -> None: + self._profiles_by_id = {profile.id: profile for profile in profiles} + + async def get_by_user_ids( + self, user_ids: list[object] + ) -> dict[object, _FakeProfile]: + return { + user_id: self._profiles_by_id[user_id] + for user_id in user_ids + if user_id in self._profiles_by_id + } + + async def search_users(self, query: str, limit: int = 20) -> list[_FakeProfile]: + _ = limit + return [ + profile + for profile in self._profiles_by_id.values() + if query.lower() in profile.username.lower() + ] + + +class _FakeAuthLookup: + def __init__(self, mapping: dict[str, list[str]]) -> None: + self.mapping = mapping + + async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]: + _ = limit + return self.mapping.get(query, []) + + class _FakeUserContextCache: def __init__(self, *, should_fail: bool = False) -> None: self.should_fail = should_fail @@ -72,7 +104,7 @@ async def test_update_me_invalidates_user_context_cache() -> None: session = _FakeSession() cache = _FakeUserContextCache() service = UserService( - repository=repo, + repository=repo, # type: ignore[arg-type] session=session, # type: ignore[arg-type] current_user=CurrentUser(id=user_id), user_context_cache=cache, # type: ignore[arg-type] @@ -94,7 +126,7 @@ async def test_update_me_succeeds_when_cache_invalidation_fails() -> None: session = _FakeSession() cache = _FakeUserContextCache(should_fail=True) service = UserService( - repository=repo, + repository=repo, # type: ignore[arg-type] session=session, # type: ignore[arg-type] current_user=CurrentUser(id=user_id), user_context_cache=cache, # type: ignore[arg-type] @@ -105,3 +137,59 @@ async def test_update_me_succeeds_when_cache_invalidation_fails() -> None: assert result.username == "new-name" assert session.commit_called == 1 assert cache.invalidated_user_ids == [user_id] + + +@pytest.mark.asyncio +async def test_search_users_supports_phone_without_country_code() -> None: + user_id = uuid4() + repo = _FakeSearchRepository( + [ + _FakeProfile( + id=user_id, + username="alice", + avatar_url=None, + bio=None, + ) + ] + ) + session = _FakeSession() + auth_lookup = _FakeAuthLookup({"13812345678": [str(user_id)]}) + service = UserService( + repository=repo, # type: ignore[arg-type] + session=session, # type: ignore[arg-type] + current_user=CurrentUser(id=user_id), + auth_gateway=auth_lookup, # type: ignore[arg-type] + ) + + results = await service.search_users(UserSearchRequest(query="13812345678")) + + assert len(results) == 1 + assert results[0].id == str(user_id) + + +@pytest.mark.asyncio +async def test_search_users_preserves_numeric_username_lookup() -> None: + user_id = uuid4() + repo = _FakeSearchRepository( + [ + _FakeProfile( + id=user_id, + username="20260319", + avatar_url=None, + bio=None, + ) + ] + ) + session = _FakeSession() + auth_lookup = _FakeAuthLookup({}) + service = UserService( + repository=repo, # type: ignore[arg-type] + session=session, # type: ignore[arg-type] + current_user=CurrentUser(id=user_id), + auth_gateway=auth_lookup, # type: ignore[arg-type] + ) + + results = await service.search_users(UserSearchRequest(query="20260319")) + + assert len(results) == 1 + assert results[0].username == "20260319"