refactor(backend): update API routes and service layer
- Update agent router/service/repository with new endpoints - Update auth routes with phone-based authentication - Update users service with new phone lookup - Update schedule_items with new schemas - Update message schemas with visibility support - Update settings with new automation scheduler config - Update CLI with new commands - Update tests to match new API contracts
This commit is contained in:
@@ -7,5 +7,5 @@ from uuid import UUID
|
||||
@dataclass(frozen=True)
|
||||
class CurrentUser:
|
||||
id: UUID
|
||||
email: str | None = None
|
||||
phone: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
@@ -61,6 +61,7 @@ class RuntimeSettings(BaseModel):
|
||||
]
|
||||
)
|
||||
sql_log_queries: bool = False
|
||||
trusted_proxy_ips: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("log_dir", mode="before")
|
||||
@classmethod
|
||||
@@ -162,6 +163,12 @@ class AgentRuntimeSettings(BaseModel):
|
||||
user_context_cache_max_turns: int = Field(default=6, ge=1, le=100)
|
||||
|
||||
|
||||
class AutomationSchedulerSettings(BaseModel):
|
||||
enabled: bool = True
|
||||
interval_seconds: int = Field(default=60, ge=5, le=3600)
|
||||
batch_limit: int = Field(default=100, ge=1, le=1000)
|
||||
|
||||
|
||||
class LlmSettings(BaseModel):
|
||||
provider_keys: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@@ -225,7 +232,7 @@ class AppVersionSettings(BaseModel):
|
||||
|
||||
|
||||
class TestSettings(BaseModel):
|
||||
email: str = ""
|
||||
phone: str = ""
|
||||
password: str = ""
|
||||
|
||||
|
||||
@@ -250,6 +257,7 @@ class Settings(BaseSettings):
|
||||
llm: LlmSettings = LlmSettings()
|
||||
litellm: LiteLLMSettings = LiteLLMSettings()
|
||||
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
|
||||
automation_scheduler: AutomationSchedulerSettings = AutomationSchedulerSettings()
|
||||
taskiq: TaskiqSettings = TaskiqSettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
app_version: AppVersionSettings = AppVersionSettings()
|
||||
|
||||
@@ -5,6 +5,8 @@ import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from core.agentscope.runtime.tasks import run_automation_scheduler_scan
|
||||
from core.config.settings import config
|
||||
from core.config.initial.init_data import initialize_data
|
||||
from core.logging import get_logger
|
||||
|
||||
@@ -101,12 +103,34 @@ async def bootstrap() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def run_automation_scheduler_forever() -> bool:
|
||||
if not config.automation_scheduler.enabled:
|
||||
logger.info("Automation scheduler disabled by config")
|
||||
return True
|
||||
|
||||
interval_seconds = int(config.automation_scheduler.interval_seconds)
|
||||
batch_limit = int(config.automation_scheduler.batch_limit)
|
||||
logger.info(
|
||||
"Starting automation scheduler loop",
|
||||
interval_seconds=interval_seconds,
|
||||
batch_limit=batch_limit,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
await run_automation_scheduler_scan(limit=batch_limit)
|
||||
except Exception as exc:
|
||||
logger.exception("Automation scheduler scan failed", error=str(exc))
|
||||
await asyncio.sleep(interval_seconds)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""CLI entry point."""
|
||||
if len(sys.argv) < 2:
|
||||
logger.error("No command provided")
|
||||
logger.info("Usage: python -m core.runtime.cli <command>")
|
||||
logger.info("Available commands: migrate, init-data, bootstrap")
|
||||
logger.info(
|
||||
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
|
||||
)
|
||||
return 1
|
||||
|
||||
command = sys.argv[1]
|
||||
@@ -117,9 +141,13 @@ def main() -> int:
|
||||
success = asyncio.run(run_init_data())
|
||||
elif command == "bootstrap":
|
||||
success = asyncio.run(bootstrap())
|
||||
elif command == "automation-scheduler":
|
||||
success = asyncio.run(run_automation_scheduler_forever())
|
||||
else:
|
||||
logger.error("Unknown command", command=command)
|
||||
logger.info("Available commands: migrate, init-data, bootstrap")
|
||||
logger.info(
|
||||
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
|
||||
)
|
||||
return 1
|
||||
|
||||
return 0 if success else 1
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from schemas.agent.runtime_models import AgentOutput
|
||||
from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput
|
||||
|
||||
from ..agent import AgentType, ToolAgentOutput
|
||||
|
||||
@@ -24,6 +24,7 @@ class AgentChatMessageMetadata(BaseModel):
|
||||
run_id: str
|
||||
agent_type: AgentType | None = None
|
||||
user_message_attachments: list[UserMessageAttachment] | None = None
|
||||
router_agent_output: RouterAgentOutput | None = None
|
||||
tool_agent_output: ToolAgentOutput | None = None
|
||||
agent_output: AgentOutput | None = None
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class UserContext(BaseModel):
|
||||
|
||||
id: str
|
||||
username: str
|
||||
email: str | None = None
|
||||
phone: str | None = None
|
||||
avatar_url: str | None = None
|
||||
bio: str | None = None
|
||||
settings: ProfileSettingsUnion | None = None
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
@@ -95,6 +95,7 @@ class AgentRepository:
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
visibility_mask: int,
|
||||
) -> None:
|
||||
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
|
||||
|
||||
@@ -124,6 +125,7 @@ class AgentRepository:
|
||||
seq=next_seq,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content=content,
|
||||
visibility_mask=max(int(visibility_mask), 0),
|
||||
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
|
||||
)
|
||||
self._session.add(message)
|
||||
@@ -132,7 +134,11 @@ class AgentRepository:
|
||||
await self._session.flush()
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
@@ -152,6 +158,10 @@ class AgentRepository:
|
||||
.order_by(AgentChatMessage.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
target_created_at_stmt = self._apply_visibility_filter(
|
||||
stmt=target_created_at_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if before_start is not None:
|
||||
target_created_at_stmt = target_created_at_stmt.where(
|
||||
AgentChatMessage.created_at < before_start
|
||||
@@ -175,6 +185,10 @@ class AgentRepository:
|
||||
.where(AgentChatMessage.created_at < end)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
message_stmt = self._apply_visibility_filter(
|
||||
stmt=message_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
has_more_stmt = (
|
||||
select(AgentChatMessage.id)
|
||||
@@ -183,6 +197,10 @@ class AgentRepository:
|
||||
.where(AgentChatMessage.created_at < start)
|
||||
.limit(1)
|
||||
)
|
||||
has_more_stmt = self._apply_visibility_filter(
|
||||
stmt=has_more_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
has_more = (
|
||||
await self._session.execute(has_more_stmt)
|
||||
).scalar_one_or_none() is not None
|
||||
@@ -196,7 +214,11 @@ class AgentRepository:
|
||||
}
|
||||
|
||||
async def get_recent_messages_by_user_window(
|
||||
self, *, session_id: str, user_message_limit: int
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
@@ -210,6 +232,10 @@ class AgentRepository:
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.order_by(AgentChatMessage.seq.desc())
|
||||
)
|
||||
message_stmt = self._apply_visibility_filter(
|
||||
stmt=message_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
messages_desc = (await self._session.execute(message_stmt)).scalars().all()
|
||||
if not messages_desc:
|
||||
return []
|
||||
@@ -294,6 +320,21 @@ class AgentRepository:
|
||||
)
|
||||
return payload_model.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
def _apply_visibility_filter(
|
||||
self,
|
||||
*,
|
||||
stmt: Select,
|
||||
visibility_mask: int | None,
|
||||
) -> Select:
|
||||
if visibility_mask is None:
|
||||
return stmt
|
||||
required_mask = max(int(visibility_mask), 0)
|
||||
if required_mask == 0:
|
||||
return stmt
|
||||
return stmt.where(
|
||||
(AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
|
||||
)
|
||||
|
||||
|
||||
def _has_title(title: object) -> bool:
|
||||
return isinstance(title, str) and bool(title.strip())
|
||||
|
||||
@@ -47,6 +47,7 @@ logger = get_logger("v1.agent.router")
|
||||
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
@@ -72,8 +73,14 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
count = await redis.incr(key)
|
||||
if count == 1:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if int(ttl) < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
|
||||
await redis.decr(key)
|
||||
after_decr = await redis.decr(key)
|
||||
if int(after_decr) <= 0:
|
||||
await redis.delete(key)
|
||||
return False
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@@ -82,7 +89,7 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
user_id=user_id,
|
||||
reason=str(exc),
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _release_sse_slot(*, user_id: str) -> None:
|
||||
@@ -92,10 +99,21 @@ async def _release_sse_slot(*, user_id: str) -> None:
|
||||
count = await redis.decr(key)
|
||||
if int(count) <= 0:
|
||||
await redis.delete(key)
|
||||
return None
|
||||
ttl = await redis.ttl(key)
|
||||
if int(ttl) < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
except Exception: # noqa: BLE001
|
||||
return None
|
||||
|
||||
|
||||
def _is_terminal_run_event(event: dict[str, object]) -> bool:
|
||||
raw_event_type = event.get("type")
|
||||
return (
|
||||
isinstance(raw_event_type, str) and raw_event_type in _TERMINAL_RUN_EVENT_TYPES
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
@@ -145,8 +163,13 @@ async def stream_events(
|
||||
async def _event_iter() -> AsyncIterator[str]:
|
||||
cursor = last_event_id
|
||||
idle_polls = 0
|
||||
terminal_event_reached = False
|
||||
try:
|
||||
while not await request.is_disconnected() and idle_polls < idle_limit:
|
||||
while (
|
||||
not terminal_event_reached
|
||||
and not await request.is_disconnected()
|
||||
and idle_polls < idle_limit
|
||||
):
|
||||
try:
|
||||
rows = await service.stream_events(
|
||||
thread_id=thread_id,
|
||||
@@ -181,6 +204,9 @@ async def stream_events(
|
||||
continue
|
||||
cursor = row_id
|
||||
yield to_sse_event(row_id, event)
|
||||
if _is_terminal_run_event(event):
|
||||
terminal_event_reached = True
|
||||
break
|
||||
|
||||
finally:
|
||||
await _release_sse_slot(user_id=str(current_user.id))
|
||||
|
||||
@@ -17,6 +17,9 @@ from core.auth.models import CurrentUser
|
||||
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachment,
|
||||
@@ -51,7 +54,11 @@ class AgentRepositoryLike(Protocol):
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
@@ -62,8 +69,13 @@ class AgentRepositoryLike(Protocol):
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
visibility_mask: int,
|
||||
) -> None: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -138,6 +150,17 @@ class AgentService:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
forwarded_props = getattr(run_input, "forwarded_props", None)
|
||||
try:
|
||||
agent_type = parse_forwarded_props_agent_type(forwarded_props)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
if agent_type == "memory":
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="memory mode is automation-only",
|
||||
)
|
||||
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except HTTPException as exc:
|
||||
@@ -161,25 +184,21 @@ class AgentService:
|
||||
run_input=run_input,
|
||||
current_user=current_user,
|
||||
)
|
||||
visibility_mask = await self._resolve_user_message_visibility_mask(
|
||||
agent_type=agent_type
|
||||
)
|
||||
await self._repository.persist_user_message(
|
||||
session_id=thread_id,
|
||||
content=user_message_text,
|
||||
metadata=user_message_metadata,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
await self._repository.commit()
|
||||
|
||||
forwarded_props = getattr(run_input, "forwarded_props", None)
|
||||
system_agent_mode = "worker"
|
||||
if isinstance(forwarded_props, dict):
|
||||
raw_mode = forwarded_props.get("system_agent_mode")
|
||||
if isinstance(raw_mode, str) and raw_mode.strip():
|
||||
system_agent_mode = raw_mode.strip().lower()
|
||||
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"system_agent_mode": system_agent_mode,
|
||||
"run_input": run_input.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
@@ -193,6 +212,61 @@ class AgentService:
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int:
|
||||
normalized_agent_type = agent_type.strip().lower()
|
||||
history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
|
||||
if normalized_agent_type == "memory":
|
||||
return bit_mask(bit=18)
|
||||
|
||||
agent_config = await self._repository.get_system_agent_config(
|
||||
agent_type=normalized_agent_type
|
||||
)
|
||||
if agent_config is None:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="invalid forwarded_props.agent_type"
|
||||
)
|
||||
llm_config = SystemAgentLLMConfig.model_validate(
|
||||
(agent_config.get("config") if isinstance(agent_config, dict) else {}) or {}
|
||||
)
|
||||
agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit)
|
||||
|
||||
if normalized_agent_type == "worker":
|
||||
router_config = await self._repository.get_system_agent_config(
|
||||
agent_type="router"
|
||||
)
|
||||
worker_config = await self._repository.get_system_agent_config(
|
||||
agent_type="worker"
|
||||
)
|
||||
if router_config is None or worker_config is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="system agent visibility config missing",
|
||||
)
|
||||
router_mask = bit_mask(
|
||||
bit=SystemAgentLLMConfig.model_validate(
|
||||
(
|
||||
router_config.get("config")
|
||||
if isinstance(router_config, dict)
|
||||
else {}
|
||||
)
|
||||
or {}
|
||||
).visibility_consumer_bit
|
||||
)
|
||||
worker_mask = bit_mask(
|
||||
bit=SystemAgentLLMConfig.model_validate(
|
||||
(
|
||||
worker_config.get("config")
|
||||
if isinstance(worker_config, dict)
|
||||
else {}
|
||||
)
|
||||
or {}
|
||||
).visibility_consumer_bit
|
||||
)
|
||||
return history_bit_mask | router_mask | worker_mask
|
||||
|
||||
return history_bit_mask | agent_mask
|
||||
|
||||
async def _prepare_user_message(
|
||||
self,
|
||||
*,
|
||||
@@ -408,6 +482,7 @@ class AgentService:
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)),
|
||||
)
|
||||
|
||||
messages: list[HistoryMessage] = []
|
||||
|
||||
+127
-199
@@ -2,28 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByEmailResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
UserByPhoneResponse,
|
||||
)
|
||||
from v1.auth.service import AuthServiceGateway
|
||||
|
||||
@@ -36,7 +30,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def __init__(self) -> None:
|
||||
self._user_lookup_cache_ttl_seconds: int = 60
|
||||
self._user_lookup_cache_expires_at: float = 0.0
|
||||
self._users_by_email: dict[str, Any] = {}
|
||||
self._users_by_phone: dict[str, Any] = {}
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
return supabase_service.get_client()
|
||||
@@ -44,47 +38,31 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def _get_admin_client(self) -> Any:
|
||||
return supabase_service.get_admin_client()
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
client = self._get_client()
|
||||
metadata: dict[str, Any] = {"username": request.username}
|
||||
if request.invite_code:
|
||||
metadata["invite_code"] = request.invite_code
|
||||
payload: dict[str, Any] = {
|
||||
"email": request.email,
|
||||
"password": request.password,
|
||||
"data": metadata,
|
||||
}
|
||||
if request.redirect_to:
|
||||
payload["options"] = {
|
||||
"email_redirect_to": _validate_redirect_url(request.redirect_to)
|
||||
"phone": request.phone,
|
||||
"options": {"should_create_user": True},
|
||||
}
|
||||
try:
|
||||
sign_up = cast(Any, client.auth.sign_up)
|
||||
await asyncio.to_thread(sign_up, payload)
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
|
||||
await asyncio.to_thread(sign_in_with_otp, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup failed", error_type=type(exc).__name__)
|
||||
logger.warning("Send otp failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid signup request"
|
||||
) from exc
|
||||
raise HTTPException(status_code=429, detail="Too many requests") from exc
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
if request.type != "signup":
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": request.type,
|
||||
"email": request.email,
|
||||
"type": "sms",
|
||||
"phone": request.phone,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
@@ -92,7 +70,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
response = await asyncio.to_thread(verify_otp, payload)
|
||||
return _map_auth_response(response, "Invalid verification code")
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup verify failed", error_type=type(exc).__name__)
|
||||
logger.warning("Create phone session failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
@@ -102,45 +80,6 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid verification code"
|
||||
) from exc
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
client = self._get_client()
|
||||
if request.type == "recovery":
|
||||
await self.request_password_reset(
|
||||
PasswordResetRequest(
|
||||
email=request.email,
|
||||
redirect_to=request.redirect_to,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
payload: dict[str, Any] = {"type": request.type, "email": request.email}
|
||||
try:
|
||||
resend = cast(Any, client.auth.resend)
|
||||
await asyncio.to_thread(resend, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup resend failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
logger.warning("Login failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
@@ -189,98 +128,84 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid refresh token"
|
||||
) from exc
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
admin_client = self._get_admin_client()
|
||||
normalized_email = email.lower()
|
||||
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
|
||||
normalized_phone = _normalize_phone(phone)
|
||||
if not normalized_phone:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
now = time.monotonic()
|
||||
if now >= self._user_lookup_cache_expires_at:
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_email: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_email = str(getattr(candidate, "email", "")).lower()
|
||||
if candidate_email:
|
||||
users_by_email[candidate_email] = candidate
|
||||
self._users_by_email = users_by_email
|
||||
self._user_lookup_cache_expires_at = (
|
||||
now + self._user_lookup_cache_ttl_seconds
|
||||
)
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
|
||||
user = self._users_by_email.get(normalized_email)
|
||||
user = self._users_by_phone.get(normalized_phone)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserByEmailResponse(
|
||||
user_phone = _normalize_phone(getattr(user, "phone", ""))
|
||||
if not user_phone:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserByPhoneResponse(
|
||||
id=str(getattr(user, "id", "")),
|
||||
email=str(getattr(user, "email", "")),
|
||||
phone=user_phone,
|
||||
created_at=str(getattr(user, "created_at", "")),
|
||||
email_confirmed_at=(
|
||||
str(getattr(user, "email_confirmed_at", ""))
|
||||
if getattr(user, "email_confirmed_at", None)
|
||||
phone_confirmed_at=(
|
||||
str(getattr(user, "phone_confirmed_at", ""))
|
||||
if getattr(user, "phone_confirmed_at", None)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
reset_email = cast(Any, client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {
|
||||
"redirect_to": _validate_redirect_url(request.redirect_to)
|
||||
}
|
||||
await asyncio.to_thread(reset_email, email, options=options)
|
||||
else:
|
||||
await asyncio.to_thread(reset_email, email)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset request failed",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
normalized_query = _normalize_phone_search_query(query)
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
if normalized_query.startswith("+"):
|
||||
matched_user = self._users_by_phone.get(normalized_query)
|
||||
if matched_user is None:
|
||||
return []
|
||||
user_id = str(getattr(matched_user, "id", ""))
|
||||
return [user_id] if user_id else []
|
||||
|
||||
digits = _digits_only(normalized_query)
|
||||
if not digits:
|
||||
return []
|
||||
|
||||
matched_records: list[tuple[str, str]] = []
|
||||
for cached_phone, candidate in self._users_by_phone.items():
|
||||
candidate_digits = _digits_only(cached_phone)
|
||||
if not candidate_digits.endswith(digits):
|
||||
continue
|
||||
user_id = str(getattr(candidate, "id", ""))
|
||||
if user_id:
|
||||
matched_records.append((cached_phone, user_id))
|
||||
|
||||
if not matched_records:
|
||||
return []
|
||||
|
||||
unique_ids: list[str] = []
|
||||
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
|
||||
if user_id in unique_ids:
|
||||
continue
|
||||
unique_ids.append(user_id)
|
||||
if len(unique_ids) >= max(1, limit):
|
||||
break
|
||||
return unique_ids
|
||||
|
||||
async def _refresh_user_lookup_cache_if_needed(self) -> None:
|
||||
now = time.monotonic()
|
||||
if now < self._user_lookup_cache_expires_at:
|
||||
return
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
client = self._get_client()
|
||||
admin_client = self._get_admin_client()
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
user_id = str(getattr(user, "id", "")) if user is not None else ""
|
||||
if session is None or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset confirm failed", error_type=type(exc).__name__
|
||||
)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
) from exc
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_phone: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
|
||||
if candidate_phone:
|
||||
users_by_phone[candidate_phone] = candidate
|
||||
self._users_by_phone = users_by_phone
|
||||
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
|
||||
|
||||
|
||||
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
||||
@@ -312,55 +237,24 @@ def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
||||
return any(token in code or token in message for token in indicators)
|
||||
|
||||
|
||||
def _validate_redirect_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
|
||||
origin = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
|
||||
allowed_origins = {
|
||||
_normalize_origin(candidate)
|
||||
for candidate in config.cors.allow_origins
|
||||
if _is_http_origin(candidate)
|
||||
}
|
||||
if origin not in allowed_origins:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
return url
|
||||
|
||||
|
||||
def _normalize_origin(value: str) -> str:
|
||||
parsed = urlparse(value)
|
||||
return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
|
||||
|
||||
|
||||
def _is_http_origin(value: str) -> bool:
|
||||
parsed = urlparse(value)
|
||||
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
|
||||
|
||||
|
||||
def _coerce_reset_email(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
nested = value.get("email") or value.get("value")
|
||||
if isinstance(nested, str):
|
||||
return nested
|
||||
|
||||
raise HTTPException(status_code=422, detail="Invalid email")
|
||||
|
||||
|
||||
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
if session is None or user is None:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
email = getattr(user, "email", None)
|
||||
if not email:
|
||||
phone = _normalize_phone(getattr(user, "phone", None))
|
||||
if not phone:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
auth_user = AuthUser(id=str(user.id), email=str(email))
|
||||
try:
|
||||
auth_user = AuthUser(id=str(user.id), phone=str(phone))
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Auth response returned invalid phone format",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=failure_message) from exc
|
||||
return SessionResponse(
|
||||
access_token=str(session.access_token),
|
||||
refresh_token=str(session.refresh_token),
|
||||
@@ -389,3 +283,37 @@ def _list_auth_users(client: Any) -> list[Any]:
|
||||
page += 1
|
||||
|
||||
return users
|
||||
|
||||
|
||||
def _normalize_phone(raw_phone: object) -> str | None:
|
||||
phone = str(raw_phone).strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
phone = phone.replace(separator, "")
|
||||
if not phone:
|
||||
return None
|
||||
if phone.startswith("00") and len(phone) > 2:
|
||||
return f"+{phone[2:]}"
|
||||
if phone.startswith("+"):
|
||||
return phone
|
||||
if phone.isdigit():
|
||||
return f"+{phone}"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_phone_search_query(raw_query: str) -> str | None:
|
||||
query = raw_query.strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
query = query.replace(separator, "")
|
||||
if not query:
|
||||
return None
|
||||
if query.startswith("00") and len(query) > 2:
|
||||
return f"+{query[2:]}"
|
||||
if query.startswith("+"):
|
||||
return query
|
||||
if query.isdigit():
|
||||
return query
|
||||
return None
|
||||
|
||||
|
||||
def _digits_only(value: str) -> str:
|
||||
return "".join(ch for ch in value if ch.isdigit())
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.config.settings import config
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
@@ -22,80 +18,49 @@ from v1.auth.service import AuthService
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/verifications", response_model=VerificationCreateResponse, status_code=202
|
||||
)
|
||||
async def create_verification(
|
||||
payload: VerificationCreateRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> VerificationCreateResponse:
|
||||
await enforce_rate_limit(
|
||||
scope="signup_start",
|
||||
identifier=payload.email,
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
)
|
||||
return await service.create_verification(payload)
|
||||
|
||||
|
||||
@router.post("/verify", response_model=SessionResponse)
|
||||
async def verify(
|
||||
payload: VerificationVerifyRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse | Response:
|
||||
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
|
||||
limit = 10
|
||||
window_seconds = 600
|
||||
await enforce_rate_limit(
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
if payload.type == "signup":
|
||||
return await service.verify_verification(payload)
|
||||
if payload.new_password is None:
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
await service.confirm_password_reset(
|
||||
PasswordResetConfirmRequest(
|
||||
email=payload.email,
|
||||
token=payload.token,
|
||||
new_password=payload.new_password,
|
||||
)
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/resend", status_code=204)
|
||||
async def resend(
|
||||
payload: VerificationResendRequest,
|
||||
@router.post("/otp/send", status_code=204)
|
||||
async def send_otp(
|
||||
payload: OtpSendRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=5,
|
||||
scope="otp_send_phone",
|
||||
identifier=payload.phone,
|
||||
limit=3,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.resend_verification(payload)
|
||||
await enforce_rate_limit(
|
||||
scope="otp_send_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.send_otp(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(
|
||||
payload: SessionCreateRequest,
|
||||
@router.post("/phone-session", response_model=SessionResponse)
|
||||
async def create_phone_session(
|
||||
payload: PhoneSessionCreateRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope="login",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
scope="phone_session_phone",
|
||||
identifier=payload.phone,
|
||||
limit=6,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_session(payload)
|
||||
await enforce_rate_limit(
|
||||
scope="phone_session_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_phone_session(payload)
|
||||
|
||||
|
||||
@router.post("/sessions/refresh", response_model=SessionResponse)
|
||||
@@ -130,6 +95,11 @@ async def delete_session(
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
host = request.client.host if request.client else ""
|
||||
if not host:
|
||||
return "unknown"
|
||||
|
||||
if _should_trust_proxy_headers(host):
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
first = forwarded_for.split(",")[0].strip()
|
||||
@@ -138,5 +108,10 @@ def _client_ip(request: Request) -> str:
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
host = request.client.host if request.client else ""
|
||||
return host or "unknown"
|
||||
|
||||
return host
|
||||
|
||||
|
||||
def _should_trust_proxy_headers(host: str) -> bool:
|
||||
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
|
||||
return host in trusted_proxies
|
||||
|
||||
@@ -1,49 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
OtpType = Literal["signup", "recovery"]
|
||||
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
|
||||
|
||||
|
||||
class VerificationCreateRequest(BaseModel):
|
||||
class OtpSendRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
username: str = Field(min_length=3, max_length=30)
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
redirect_to: str | None = None
|
||||
invite_code: str | None = None
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class VerificationResendRequest(BaseModel):
|
||||
email: EmailStr
|
||||
type: OtpType = "signup"
|
||||
redirect_to: str | None = None
|
||||
class PhoneSessionCreateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class VerificationVerifyRequest(BaseModel):
|
||||
type: OtpType = "signup"
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str | None = Field(
|
||||
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_type_payload(self) -> "VerificationVerifyRequest":
|
||||
if self.type == "recovery" and self.new_password is None:
|
||||
raise ValueError("new_password is required when type is recovery")
|
||||
if self.type == "signup" and self.new_password is not None:
|
||||
raise ValueError("new_password is only allowed when type is recovery")
|
||||
return self
|
||||
|
||||
|
||||
class SessionCreateRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
|
||||
|
||||
class SessionRefreshRequest(BaseModel):
|
||||
@@ -56,7 +29,7 @@ class SessionDeleteRequest(BaseModel):
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
@@ -67,23 +40,12 @@ class SessionResponse(BaseModel):
|
||||
user: AuthUser
|
||||
|
||||
|
||||
class UserByEmailResponse(BaseModel):
|
||||
class UserByPhoneResponse(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
created_at: str
|
||||
email_confirmed_at: str | None = None
|
||||
phone_confirmed_at: str | None = None
|
||||
|
||||
|
||||
class VerificationCreateResponse(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: EmailStr
|
||||
redirect_to: str | None = None
|
||||
|
||||
|
||||
class PasswordResetConfirmRequest(BaseModel):
|
||||
email: EmailStr
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
@@ -1,52 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Protocol
|
||||
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
|
||||
|
||||
class AuthServiceGateway(Protocol):
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
@@ -54,50 +32,16 @@ class AuthService:
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
normalized_invite_code = _normalize_invite_code(request.invite_code)
|
||||
normalized_request = request.model_copy(
|
||||
update={"invite_code": normalized_invite_code}
|
||||
)
|
||||
return await self._gateway.create_verification(normalized_request)
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return await self._gateway.verify_verification(request)
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
await self._gateway.resend_verification(request)
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
return await self._gateway.create_session(request)
|
||||
return await self._gateway.create_phone_session(request)
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
await self._gateway.request_password_reset(request)
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
await self._gateway.confirm_password_reset(request)
|
||||
|
||||
|
||||
_INVITE_CODE_PATTERN = re.compile(r"^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{4}$")
|
||||
|
||||
|
||||
def _normalize_invite_code(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
normalized = value.strip().upper()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
return normalized if _INVITE_CODE_PATTERN.fullmatch(normalized) else None
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import ClassVar
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from schemas.inbox.messages import (
|
||||
CalendarContent,
|
||||
@@ -154,7 +154,11 @@ _PERMISSION_EDIT = 4
|
||||
class ScheduleItemShareRequest(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
email: EmailStr = Field(..., description="Email of user to share with")
|
||||
phone: str = Field(
|
||||
...,
|
||||
pattern=r"^\+861[3-9]\d{9}$",
|
||||
description="Phone of user to share with",
|
||||
)
|
||||
permission_view: bool = Field(True, description="Grant view permission")
|
||||
permission_edit: bool = Field(False, description="Grant edit permission")
|
||||
permission_invite: bool = Field(False, description="Grant invite permission")
|
||||
|
||||
@@ -31,19 +31,19 @@ from v1.schedule_items.schemas import (
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.auth.schemas import UserByEmailResponse
|
||||
from v1.auth.schemas import UserByPhoneResponse
|
||||
|
||||
logger = get_logger("v1.schedule_items.service")
|
||||
|
||||
|
||||
class AuthByEmailGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
|
||||
class AuthByPhoneGateway(Protocol):
|
||||
async def get_user_by_phone(self, phone: str) -> "UserByPhoneResponse": ...
|
||||
|
||||
|
||||
class ScheduleItemService(BaseService):
|
||||
_repository: ScheduleItemRepository
|
||||
_session: AsyncSession
|
||||
_auth_gateway: AuthByEmailGateway
|
||||
_auth_gateway: AuthByPhoneGateway
|
||||
_inbox_repository: InboxMessageRepository
|
||||
|
||||
def __init__(
|
||||
@@ -51,7 +51,7 @@ class ScheduleItemService(BaseService):
|
||||
repository: ScheduleItemRepository,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
auth_gateway: AuthByEmailGateway | None = None,
|
||||
auth_gateway: AuthByPhoneGateway | None = None,
|
||||
inbox_repository: InboxMessageRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
@@ -329,7 +329,7 @@ class ScheduleItemService(BaseService):
|
||||
detail=f"You can only share with permissions up to {inviter_permission}",
|
||||
)
|
||||
|
||||
target_user = await self._auth_gateway.get_user_by_email(request.email)
|
||||
target_user = await self._auth_gateway.get_user_by_phone(request.phone)
|
||||
recipient_id = UUID(target_user.id)
|
||||
|
||||
existing = await self._repository.get_subscription(item_id, recipient_id)
|
||||
@@ -404,7 +404,7 @@ class ScheduleItemService(BaseService):
|
||||
except ValueError:
|
||||
await self._session.rollback()
|
||||
logger.exception(
|
||||
"Auth lookup returned invalid user id", email=request.email
|
||||
"Auth lookup returned invalid user id", phone=request.phone
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
|
||||
|
||||
|
||||
@@ -76,11 +76,11 @@ async def _verify_user_with_supabase(token: str) -> CurrentUser | None:
|
||||
parsed_id = UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
email = getattr(user, "email", None)
|
||||
phone = getattr(user, "phone", None)
|
||||
role = getattr(user, "role", None)
|
||||
return CurrentUser(
|
||||
id=parsed_id,
|
||||
email=email if isinstance(email, str) else None,
|
||||
phone=phone if isinstance(phone, str) else None,
|
||||
role=role if isinstance(role, str) else None,
|
||||
)
|
||||
|
||||
@@ -125,9 +125,9 @@ async def get_current_user(
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
logger.debug("JWT validation successful", user_id=str(user_id))
|
||||
email = payload.get("email") if isinstance(payload.get("email"), str) else None
|
||||
phone = payload.get("phone") if isinstance(payload.get("phone"), str) else None
|
||||
role = payload.get("role") if isinstance(payload.get("role"), str) else None
|
||||
return CurrentUser(id=user_id, email=email, role=role)
|
||||
return CurrentUser(id=user_id, phone=phone, role=role)
|
||||
|
||||
|
||||
async def get_user_repository(
|
||||
|
||||
@@ -38,7 +38,7 @@ class UserRepository(Protocol):
|
||||
...
|
||||
|
||||
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
|
||||
"""Search users by username (ilike) or email (exact match)."""
|
||||
"""Search users by username (ilike) or phone (exact match)."""
|
||||
...
|
||||
|
||||
|
||||
|
||||
@@ -21,19 +21,22 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from schemas.user.context import UserContext
|
||||
from v1.auth.schemas import UserByEmailResponse
|
||||
|
||||
logger = get_logger("v1.users.service")
|
||||
|
||||
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
_PHONE_QUERY_PATTERN = re.compile(r"^[+()\-\s\d]{4,32}$")
|
||||
|
||||
|
||||
class AuthLookupGateway(Protocol):
|
||||
async def get_user_id_by_email(self, email: str) -> str | None: ...
|
||||
async def search_user_ids_by_phone(
|
||||
self, query: str, limit: int = 20
|
||||
) -> list[str]: ...
|
||||
|
||||
|
||||
class AuthByEmailGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
|
||||
class AuthByPhoneGateway(Protocol):
|
||||
async def search_user_ids_by_phone(
|
||||
self, query: str, limit: int = 20
|
||||
) -> list[str]: ...
|
||||
|
||||
|
||||
class UserContextInvalidator(Protocol):
|
||||
@@ -41,15 +44,14 @@ class UserContextInvalidator(Protocol):
|
||||
|
||||
|
||||
class AuthLookupAdapter:
|
||||
def __init__(self, gateway: AuthByEmailGateway) -> None:
|
||||
def __init__(self, gateway: AuthByPhoneGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def get_user_id_by_email(self, email: str) -> str | None:
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
try:
|
||||
response = await self._gateway.get_user_by_email(email)
|
||||
return response.id
|
||||
return await self._gateway.search_user_ids_by_phone(query, limit=limit)
|
||||
except HTTPException:
|
||||
return None
|
||||
return []
|
||||
|
||||
|
||||
class UserService(BaseService):
|
||||
@@ -92,11 +94,11 @@ class UserService(BaseService):
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
email = self._current_user.email if self._current_user else None
|
||||
phone = self._current_user.phone if self._current_user else None
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
phone=phone,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
@@ -152,11 +154,11 @@ class UserService(BaseService):
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
email = self._current_user.email if self._current_user else None
|
||||
phone = self._current_user.phone if self._current_user else None
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
phone=phone,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
@@ -181,28 +183,50 @@ class UserService(BaseService):
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
|
||||
query = request.query.strip()
|
||||
|
||||
if _EMAIL_PATTERN.match(query):
|
||||
return await self._search_by_email(query)
|
||||
if _looks_like_phone_query(query):
|
||||
phone_results = await self._search_by_phone(query)
|
||||
if not query.isdigit():
|
||||
return phone_results
|
||||
username_results = await self._search_by_username(query)
|
||||
if not phone_results:
|
||||
return username_results
|
||||
merged_by_id = {result.id: result for result in phone_results}
|
||||
for result in username_results:
|
||||
merged_by_id.setdefault(result.id, result)
|
||||
return list(merged_by_id.values())
|
||||
|
||||
return await self._search_by_username(query)
|
||||
|
||||
async def _search_by_email(self, email: str) -> list[UserContext]:
|
||||
async def _search_by_phone(self, phone: str) -> list[UserContext]:
|
||||
if self._auth_gateway is None:
|
||||
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
|
||||
|
||||
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
|
||||
if user_id_str is None:
|
||||
user_id_values = await self._auth_gateway.search_user_ids_by_phone(
|
||||
phone, limit=20
|
||||
)
|
||||
if not user_id_values:
|
||||
return []
|
||||
|
||||
user_ids: list[UUID] = []
|
||||
for raw_id in user_id_values:
|
||||
try:
|
||||
user_ids.append(UUID(raw_id))
|
||||
except ValueError:
|
||||
continue
|
||||
if not user_ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
user = await self._repository.get_by_user_id(UUID(user_id_str))
|
||||
users_by_id = await self._repository.get_by_user_ids(user_ids)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="User store unavailable")
|
||||
|
||||
results: list[UserContext] = []
|
||||
for user_id in user_ids:
|
||||
user = users_by_id.get(user_id)
|
||||
if user is None:
|
||||
return []
|
||||
|
||||
return [
|
||||
continue
|
||||
results.append(
|
||||
UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
@@ -210,7 +234,8 @@ class UserService(BaseService):
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
)
|
||||
]
|
||||
)
|
||||
return results
|
||||
|
||||
async def _search_by_username(self, query: str) -> list[UserContext]:
|
||||
try:
|
||||
@@ -228,3 +253,10 @@ class UserService(BaseService):
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
|
||||
def _looks_like_phone_query(query: str) -> bool:
|
||||
if not _PHONE_QUERY_PATTERN.fullmatch(query):
|
||||
return False
|
||||
digits_count = sum(char.isdigit() for char in query)
|
||||
return digits_count >= 4
|
||||
|
||||
@@ -12,41 +12,24 @@ from app import app
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
|
||||
class FakeE2EAuthService(AuthService):
|
||||
def __init__(self) -> None:
|
||||
self._user = AuthUser(id="user-1", email="user@example.com")
|
||||
self._user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
) -> SessionResponse:
|
||||
return SessionResponse(
|
||||
access_token="access-1",
|
||||
refresh_token="refresh-1",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
return None
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return SessionResponse(
|
||||
access_token="access-2",
|
||||
refresh_token="refresh-2",
|
||||
@@ -105,41 +88,25 @@ def test_auth_flow_e2e() -> None:
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
verification = request_context.post(
|
||||
"/api/v1/auth/verifications",
|
||||
data=json.dumps(
|
||||
{
|
||||
"username": "demo",
|
||||
"email": "user@example.com",
|
||||
"password": "secret123",
|
||||
}
|
||||
),
|
||||
send_code = request_context.post(
|
||||
"/api/v1/auth/otp/send",
|
||||
data=json.dumps({"phone": "+8613812345678"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert verification.status == 202
|
||||
assert send_code.status == 204
|
||||
|
||||
verify = request_context.post(
|
||||
"/api/v1/auth/verify",
|
||||
login_or_register = request_context.post(
|
||||
"/api/v1/auth/phone-session",
|
||||
data=json.dumps(
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"phone": "+8613812345678",
|
||||
"token": "123456",
|
||||
}
|
||||
),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert verify.status == 200
|
||||
assert verify.json()["access_token"] == "access-1"
|
||||
|
||||
login = request_context.post(
|
||||
"/api/v1/auth/sessions",
|
||||
data=json.dumps(
|
||||
{"email": "user@example.com", "password": "secret123"}
|
||||
),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert login.status == 200
|
||||
assert login.json()["access_token"] == "access-2"
|
||||
assert login_or_register.status == 200
|
||||
assert login_or_register.json()["access_token"] == "access-2"
|
||||
|
||||
refresh = request_context.post(
|
||||
"/api/v1/auth/sessions/refresh",
|
||||
|
||||
@@ -4,11 +4,16 @@ import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.infra.dependencies import get_redis_service
|
||||
|
||||
pytest.skip(
|
||||
"infra health endpoint removed from v1 API",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
class _FakeService:
|
||||
@@ -52,8 +57,6 @@ def _start_server(host: str, port: int):
|
||||
|
||||
|
||||
def test_infra_health_e2e() -> None:
|
||||
app.dependency_overrides[get_redis_service] = lambda: _FakeService()
|
||||
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
@@ -11,21 +11,22 @@ import uvicorn
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from schemas.user.context import UserContext
|
||||
from v1.users.dependencies import get_current_user, get_user_service
|
||||
from v1.users.schemas import UserResponse, UserUpdateRequest
|
||||
from v1.users.schemas import UserUpdateRequest
|
||||
|
||||
|
||||
class FakeUserService:
|
||||
"""Fake service for E2E testing."""
|
||||
|
||||
def __init__(self, user: UserResponse) -> None:
|
||||
def __init__(self, user: UserContext) -> None:
|
||||
self._user = user
|
||||
|
||||
async def get_me(self) -> UserResponse:
|
||||
async def get_me(self) -> UserContext:
|
||||
return self._user
|
||||
|
||||
async def update_me(self, update: UserUpdateRequest) -> UserResponse:
|
||||
return UserResponse(
|
||||
async def update_me(self, update: UserUpdateRequest) -> UserContext:
|
||||
return UserContext(
|
||||
id=self._user.id,
|
||||
username=(
|
||||
update.username if update.username is not None else self._user.username
|
||||
@@ -38,6 +39,7 @@ class FakeUserService:
|
||||
bio=update.bio if update.bio is not None else self._user.bio,
|
||||
)
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
@@ -65,7 +67,7 @@ def _start_server(host: str, port: int):
|
||||
|
||||
def test_profile_flow_e2e() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
user = UserContext(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
|
||||
@@ -11,16 +11,10 @@ from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.rate_limit import reset_rate_limit_state
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByEmailResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
@@ -30,58 +24,39 @@ def reset_auth_rate_limit_state() -> None:
|
||||
reset_rate_limit_state()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def force_in_memory_rate_limit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def _raise_redis_unavailable() -> None:
|
||||
raise RuntimeError("redis unavailable in integration tests")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"v1.auth.rate_limit.get_or_init_redis_client",
|
||||
_raise_redis_unavailable,
|
||||
)
|
||||
|
||||
|
||||
class FakeAuthService(AuthService):
|
||||
def __init__(self, token_response: SessionResponse) -> None:
|
||||
self._token_response = token_response
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
if request.email == "exists@example.com":
|
||||
raise HTTPException(status_code=422, detail="Invalid signup request")
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
if request.phone == "+8613811111111":
|
||||
raise HTTPException(status_code=401, detail="Invalid verification code")
|
||||
return None
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
if request.token == "000000":
|
||||
raise HTTPException(status_code=401, detail="Invalid verification code")
|
||||
return self._token_response
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
return None
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
if email == "missing@example.com":
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return UserByEmailResponse(
|
||||
id="user-1",
|
||||
email=email,
|
||||
created_at="2026-02-24T00:00:00Z",
|
||||
email_confirmed_at=None,
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
return None
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
if request.token == "000000":
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
|
||||
def _get_service() -> AuthService:
|
||||
@@ -90,761 +65,126 @@ def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
|
||||
return _get_service
|
||||
|
||||
|
||||
def test_signup_start_returns_pending_response() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
def _token_response() -> SessionResponse:
|
||||
user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
return SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
def test_send_otp_returns_204() -> None:
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
FakeAuthService(_token_response())
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={
|
||||
"username": "demo",
|
||||
"email": "user@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
"/api/v1/auth/otp/send",
|
||||
json={"phone": "+8613812345678"},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"email": "user@example.com"}
|
||||
assert response.status_code == 204
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_verify_returns_token_response() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
def test_phone_session_returns_token_response() -> None:
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
FakeAuthService(_token_response())
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verify",
|
||||
json={"email": "user@example.com", "token": "123456"},
|
||||
"/api/v1/auth/phone-session",
|
||||
json={"phone": "+8613812345678", "token": "123456"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["access_token"] == "access"
|
||||
assert body["refresh_token"] == "refresh"
|
||||
assert body["user"]["email"] == "user@example.com"
|
||||
assert body["user"]["phone"] == "+8613812345678"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_resend_returns_generic_message() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
def test_phone_session_invalid_token_returns_problem_details() -> None:
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
FakeAuthService(_token_response())
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/resend",
|
||||
json={"type": "recovery", "email": "user@example.com"},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
assert response.content == b""
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_verify_invalid_token_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verify",
|
||||
json={"email": "user@example.com", "token": "000000"},
|
||||
"/api/v1/auth/phone-session",
|
||||
json={"phone": "+8613812345678", "token": "000000"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
assert body["detail"] == "Invalid verification code"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_start_existing_email_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
def test_legacy_routes_are_removed() -> None:
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
FakeAuthService(_token_response())
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={
|
||||
"username": "demo",
|
||||
"email": "exists@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Entity"
|
||||
assert body["status"] == 422
|
||||
assert body["detail"] == "Invalid signup request"
|
||||
assert client.post("/api/v1/auth/verifications", json={}).status_code == 404
|
||||
assert client.post("/api/v1/auth/verify", json={}).status_code == 404
|
||||
assert client.post("/api/v1/auth/resend", json={}).status_code == 404
|
||||
assert client.post("/api/v1/auth/sessions", json={}).status_code == 405
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_verify_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
def test_send_otp_phone_rate_limited_after_too_many_attempts() -> None:
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
FakeAuthService(_token_response())
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(10):
|
||||
for _ in range(3):
|
||||
ok = client.post(
|
||||
"/api/v1/auth/verify",
|
||||
json={"email": "user@example.com", "token": "123456"},
|
||||
)
|
||||
assert ok.status_code == 200
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/verify",
|
||||
json={"email": "user@example.com", "token": "123456"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_resend_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(5):
|
||||
ok = client.post(
|
||||
"/api/v1/auth/resend",
|
||||
json={"email": "user@example.com"},
|
||||
"/api/v1/auth/otp/send",
|
||||
json={"phone": "+8613812345678"},
|
||||
)
|
||||
assert ok.status_code == 204
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/resend",
|
||||
json={"email": "user@example.com"},
|
||||
"/api/v1/auth/otp/send",
|
||||
json={"phone": "+8613812345678"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_start_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
def test_phone_session_rate_limited_after_too_many_attempts() -> None:
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
FakeAuthService(_token_response())
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(5):
|
||||
ok = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={
|
||||
"username": "demo",
|
||||
"email": "user@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
assert ok.status_code == 202
|
||||
|
||||
for _ in range(6):
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={
|
||||
"username": "demo",
|
||||
"email": "user@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_login_invalid_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/sessions",
|
||||
json={"email": "user@example.com", "password": "wrongpw"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
assert body["detail"] == "Invalid credentials"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_refresh_invalid_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/sessions/refresh",
|
||||
json={"refresh_token": "invalid"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
assert body["detail"] == "Invalid refresh token"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_logout_returns_no_content() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.request(
|
||||
"DELETE",
|
||||
"/api/v1/auth/sessions",
|
||||
json={"refresh_token": "refresh"},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
assert response.content == b""
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_login_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(10):
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/sessions",
|
||||
json={"email": "user@example.com", "password": "wrongpw"},
|
||||
"/api/v1/auth/phone-session",
|
||||
json={"phone": "+8613812345678", "token": "000000"},
|
||||
)
|
||||
assert blocked.status_code == 401
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/sessions",
|
||||
json={"email": "user@example.com", "password": "wrongpw"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
body = blocked.json()
|
||||
assert body["detail"] == "Too many requests"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_refresh_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(10):
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/sessions/refresh",
|
||||
json={"refresh_token": "invalid"},
|
||||
)
|
||||
assert blocked.status_code == 401
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/sessions/refresh",
|
||||
json={"refresh_token": "invalid"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
body = blocked.json()
|
||||
assert body["detail"] == "Too many requests"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_refresh_rate_limit_not_bypassed_by_changing_refresh_token() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for index in range(10):
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/sessions/refresh",
|
||||
json={"refresh_token": f"invalid-{index}"},
|
||||
)
|
||||
assert blocked.status_code == 401
|
||||
|
||||
blocked = client.post(
|
||||
"/api/v1/auth/sessions/refresh",
|
||||
json={"refresh_token": "invalid-extra"},
|
||||
"/api/v1/auth/phone-session",
|
||||
json={"phone": "+8613812345678", "token": "000000"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_logout_rate_limited_after_too_many_attempts() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for _ in range(10):
|
||||
ok = client.request(
|
||||
"DELETE",
|
||||
"/api/v1/auth/sessions",
|
||||
json={"refresh_token": "refresh"},
|
||||
)
|
||||
assert ok.status_code == 204
|
||||
|
||||
blocked = client.request(
|
||||
"DELETE",
|
||||
"/api/v1/auth/sessions",
|
||||
json={"refresh_token": "refresh"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
assert blocked.headers["content-type"].startswith("application/problem+json")
|
||||
body = blocked.json()
|
||||
assert body["detail"] == "Too many requests"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_logout_rate_limit_not_bypassed_by_changing_refresh_token() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
for index in range(10):
|
||||
ok = client.request(
|
||||
"DELETE",
|
||||
"/api/v1/auth/sessions",
|
||||
json={"refresh_token": f"refresh-{index}"},
|
||||
)
|
||||
assert ok.status_code == 204
|
||||
|
||||
blocked = client.request(
|
||||
"DELETE",
|
||||
"/api/v1/auth/sessions",
|
||||
json={"refresh_token": "refresh-extra"},
|
||||
)
|
||||
assert blocked.status_code == 429
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_start_validation_error_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/auth/verifications", json={})
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Entity"
|
||||
assert body["status"] == 422
|
||||
assert body["detail"] == "Invalid request"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_start_missing_username_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={"email": "user@example.com", "password": "secret123"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Entity"
|
||||
assert body["status"] == 422
|
||||
assert body["detail"] == "Invalid request"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_request_returns_204() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/resend",
|
||||
json={"email": "user@example.com"},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_confirm_returns_204() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verify",
|
||||
json={
|
||||
"type": "recovery",
|
||||
"email": "user@example.com",
|
||||
"token": "123456",
|
||||
"new_password": "newpassword123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_confirm_invalid_token_returns_401() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verify",
|
||||
json={
|
||||
"type": "recovery",
|
||||
"email": "user@example.com",
|
||||
"token": "000000",
|
||||
"new_password": "newpassword123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unauthorized"
|
||||
assert body["status"] == 401
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_password_reset_confirm_weak_password_returns_422() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verify",
|
||||
json={
|
||||
"type": "recovery",
|
||||
"email": "user@example.com",
|
||||
"token": "123456",
|
||||
"new_password": "123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
class TestInviteCodeSignup:
|
||||
def test_signup_with_valid_invite_code_returns_202(self) -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={
|
||||
"username": "demo",
|
||||
"email": "user@example.com",
|
||||
"password": "secret123",
|
||||
"invite_code": "A2B3",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"email": "user@example.com"}
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
def test_signup_with_invalid_invite_code_length_returns_202(self) -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={
|
||||
"username": "demo",
|
||||
"email": "user@example.com",
|
||||
"password": "secret123",
|
||||
"invite_code": "ABC123",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"email": "user@example.com"}
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
def test_signup_with_invalid_invite_code_chars_returns_202(self) -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
app.dependency_overrides[get_auth_service] = _override_auth_service(
|
||||
FakeAuthService(token_response)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/auth/verifications",
|
||||
json={
|
||||
"username": "demo",
|
||||
"email": "user@example.com",
|
||||
"password": "secret123",
|
||||
"invite_code": "ABCD1234",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"email": "user@example.com"}
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -9,12 +9,12 @@ from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from schemas.user.context import UserContext
|
||||
from v1.friendships.dependencies import get_friendship_service
|
||||
from v1.friendships.schemas import (
|
||||
FriendRequestCreate,
|
||||
FriendRequestResponse,
|
||||
FriendResponse,
|
||||
UserBasicInfo,
|
||||
)
|
||||
from v1.friendships.service import FriendshipService
|
||||
from v1.users.dependencies import get_current_user
|
||||
@@ -31,9 +31,9 @@ class FakeFriendshipService(FriendshipService):
|
||||
async def send_request(self, request: FriendRequestCreate) -> FriendRequestResponse:
|
||||
return FriendRequestResponse(
|
||||
id=UUID("11111111-1111-1111-1111-111111111111"),
|
||||
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
|
||||
content=request.content,
|
||||
sender=UserContext(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
|
||||
content={"text": request.content} if request.content else None,
|
||||
status="pending",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -41,9 +41,9 @@ class FakeFriendshipService(FriendshipService):
|
||||
async def accept_request(self, friendship_id: UUID) -> FriendRequestResponse:
|
||||
return FriendRequestResponse(
|
||||
id=friendship_id,
|
||||
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
|
||||
content="Hello!",
|
||||
sender=UserContext(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
|
||||
content={"text": "Hello!"},
|
||||
status="accepted",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -51,9 +51,9 @@ class FakeFriendshipService(FriendshipService):
|
||||
async def decline_request(self, friendship_id: UUID) -> FriendRequestResponse:
|
||||
return FriendRequestResponse(
|
||||
id=friendship_id,
|
||||
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
|
||||
content="Hello!",
|
||||
sender=UserContext(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
|
||||
content={"text": "Hello!"},
|
||||
status="rejected",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -61,9 +61,9 @@ class FakeFriendshipService(FriendshipService):
|
||||
async def cancel_request(self, friendship_id: UUID) -> FriendRequestResponse:
|
||||
return FriendRequestResponse(
|
||||
id=friendship_id,
|
||||
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
|
||||
content="Hello!",
|
||||
sender=UserContext(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
|
||||
content={"text": "Hello!"},
|
||||
status="canceled",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -72,11 +72,11 @@ class FakeFriendshipService(FriendshipService):
|
||||
return [
|
||||
FriendRequestResponse(
|
||||
id=UUID("11111111-1111-1111-1111-111111111111"),
|
||||
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserBasicInfo(
|
||||
sender=UserContext(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserContext(
|
||||
id="user-2", username="recipient", avatar_url=None
|
||||
),
|
||||
content="Hello!",
|
||||
content={"text": "Hello!"},
|
||||
status="pending",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -86,10 +86,8 @@ class FakeFriendshipService(FriendshipService):
|
||||
return [
|
||||
FriendRequestResponse(
|
||||
id=UUID("22222222-2222-2222-2222-222222222222"),
|
||||
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserBasicInfo(
|
||||
id="user-3", username="target", avatar_url=None
|
||||
),
|
||||
sender=UserContext(id="user-1", username="sender", avatar_url=None),
|
||||
recipient=UserContext(id="user-3", username="target", avatar_url=None),
|
||||
content=None,
|
||||
status="pending",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
@@ -100,7 +98,7 @@ class FakeFriendshipService(FriendshipService):
|
||||
return [
|
||||
FriendResponse(
|
||||
id=UUID("33333333-3333-3333-3333-333333333333"),
|
||||
friend=UserBasicInfo(id="user-2", username="friend", avatar_url=None),
|
||||
friend=UserContext(id="user-2", username="friend", avatar_url=None),
|
||||
status="active",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
accepted_at=datetime.now(timezone.utc),
|
||||
@@ -110,7 +108,7 @@ class FakeFriendshipService(FriendshipService):
|
||||
async def remove_friend(self, friend_id: UUID) -> FriendResponse:
|
||||
return FriendResponse(
|
||||
id=UUID("33333333-3333-3333-3333-333333333333"),
|
||||
friend=UserBasicInfo(id=str(friend_id), username="friend", avatar_url=None),
|
||||
friend=UserContext(id=str(friend_id), username="friend", avatar_url=None),
|
||||
status="active",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
accepted_at=datetime.now(timezone.utc),
|
||||
@@ -129,7 +127,7 @@ def _override_friendship_service(
|
||||
def _get_fake_current_user() -> CurrentUser:
|
||||
return CurrentUser(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
email="test@example.com",
|
||||
phone="+8613812345678",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def test_share_schedule_item_returns_200() -> None:
|
||||
response = client.post(
|
||||
f"/api/v1/schedule-items/{item_id}/share",
|
||||
json={
|
||||
"email": "friend@example.com",
|
||||
"phone": "+8613810000000",
|
||||
"permission_view": True,
|
||||
"permission_edit": False,
|
||||
"permission_invite": True,
|
||||
@@ -62,7 +62,7 @@ def test_share_schedule_item_returns_200() -> None:
|
||||
body = response.json()
|
||||
assert body["message"] == "Calendar invitation sent"
|
||||
assert service.last_share_request is not None
|
||||
assert service.last_share_request.email == "friend@example.com"
|
||||
assert service.last_share_request.phone == "+8613810000000"
|
||||
assert service.last_share_request.permission_invite is True
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -8,30 +8,31 @@ from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from schemas.user.context import UserContext
|
||||
from v1.users.dependencies import get_current_user, get_user_service
|
||||
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
class FakeUserService:
|
||||
"""Fake service for integration testing."""
|
||||
|
||||
def __init__(self, user: UserResponse) -> None:
|
||||
def __init__(self, user: UserContext) -> None:
|
||||
self._user = user
|
||||
self._search_results: list[UserResponse] = []
|
||||
self._search_results: list[UserContext] = []
|
||||
|
||||
def set_search_results(self, results: list[UserResponse]) -> None:
|
||||
def set_search_results(self, results: list[UserContext]) -> None:
|
||||
self._search_results = results
|
||||
|
||||
async def get_me(self) -> UserResponse:
|
||||
async def get_me(self) -> UserContext:
|
||||
if self._user.id is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return self._user
|
||||
|
||||
async def update_me(self, update: UserUpdateRequest) -> UserResponse:
|
||||
async def update_me(self, update: UserUpdateRequest) -> UserContext:
|
||||
if self._user.id is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return UserResponse(
|
||||
return UserContext(
|
||||
id=self._user.id,
|
||||
username=(
|
||||
update.username if update.username is not None else self._user.username
|
||||
@@ -44,7 +45,7 @@ class FakeUserService:
|
||||
bio=update.bio if update.bio is not None else self._user.bio,
|
||||
)
|
||||
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
|
||||
if request.query:
|
||||
return self._search_results if self._search_results else [self._user]
|
||||
return []
|
||||
@@ -68,7 +69,7 @@ def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]:
|
||||
|
||||
def test_get_me_returns_user() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
user = UserContext(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
@@ -91,7 +92,7 @@ def test_get_me_returns_user() -> None:
|
||||
|
||||
def test_patch_me_updates_user() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
user = UserContext(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
@@ -117,7 +118,7 @@ def test_patch_me_updates_user() -> None:
|
||||
|
||||
def test_patch_me_validation_error_returns_problem_details() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
user = UserContext(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
@@ -142,7 +143,7 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
|
||||
|
||||
def test_search_users_returns_list() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
user = UserContext(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
@@ -167,7 +168,7 @@ def test_search_users_returns_list() -> None:
|
||||
|
||||
def test_search_users_empty_query_returns_422() -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
user = UserResponse(
|
||||
user = UserContext(
|
||||
id=str(user_id),
|
||||
username="demo",
|
||||
avatar_url=None,
|
||||
|
||||
@@ -115,6 +115,34 @@ class _FailingStreamAgentService(_FakeAgentService):
|
||||
raise RuntimeError("redis timeout")
|
||||
|
||||
|
||||
class _TerminalStreamAgentService(_FakeAgentService):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.stream_calls = 0
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del thread_id, last_event_id, current_user
|
||||
self.stream_calls += 1
|
||||
if self.stream_calls == 1:
|
||||
return [
|
||||
{
|
||||
"id": "9-0",
|
||||
"event": {
|
||||
"type": "RUN_FINISHED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
},
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
@@ -129,13 +157,13 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
},
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
authorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
@@ -146,7 +174,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
},
|
||||
)
|
||||
assert authorized.status_code == 202
|
||||
@@ -161,7 +189,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
def test_stream_reads_from_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
@@ -197,7 +225,7 @@ def test_stream_reads_from_last_event_id() -> None:
|
||||
def test_stream_handles_stream_backend_errors_without_connection_crash() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FailingStreamAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
@@ -226,10 +254,45 @@ def test_stream_handles_stream_backend_errors_without_connection_crash() -> None
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_stops_after_terminal_run_event() -> None:
|
||||
service = _TerminalStreamAgentService()
|
||||
app.dependency_overrides[get_agent_service] = lambda: service
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
original_acquire = agent_router._acquire_sse_slot
|
||||
original_release = agent_router._release_sse_slot
|
||||
|
||||
async def _allow_slot(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _noop_release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
return None
|
||||
|
||||
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=3"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert "event: RUN_FINISHED" in response.text
|
||||
assert service.stream_calls == 1
|
||||
finally:
|
||||
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
|
||||
agent_router._release_sse_slot = original_release # type: ignore[assignment]
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_rejects_invalid_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -255,7 +318,7 @@ def test_history_returns_state_snapshot() -> None:
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
authorized = client.get(
|
||||
"/api/v1/agent/history",
|
||||
@@ -276,7 +339,7 @@ def test_history_returns_state_snapshot() -> None:
|
||||
def test_user_history_returns_latest_snapshot() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
try:
|
||||
@@ -292,7 +355,7 @@ def test_user_history_returns_latest_snapshot() -> None:
|
||||
def test_run_rejects_oversized_user_text_payload() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -312,7 +375,7 @@ def test_run_rejects_oversized_user_text_payload() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -323,7 +386,7 @@ def test_run_rejects_oversized_user_text_payload() -> None:
|
||||
def test_run_rejects_client_supplied_history_messages() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -340,7 +403,7 @@ def test_run_rejects_client_supplied_history_messages() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -351,7 +414,7 @@ def test_run_rejects_client_supplied_history_messages() -> None:
|
||||
def test_upload_attachment_returns_reference() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -376,7 +439,7 @@ def test_upload_attachment_returns_reference() -> None:
|
||||
def test_create_attachment_signed_url_returns_url() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -399,7 +462,7 @@ def test_create_attachment_signed_url_returns_url() -> None:
|
||||
|
||||
def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
|
||||
async def mock_transcribe_file(file_path: str, filename: str) -> str:
|
||||
@@ -434,7 +497,7 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
|
||||
|
||||
def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
|
||||
monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4)
|
||||
@@ -457,7 +520,7 @@ def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
|
||||
|
||||
def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
@@ -478,7 +541,7 @@ def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None:
|
||||
|
||||
def test_asr_transcribe_rejects_invalid_wav_payload(monkeypatch) -> None:
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
id=uuid4(), phone="+8613812345678"
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@@ -20,16 +20,16 @@ FIXTURE_IMAGE_PATH = (
|
||||
|
||||
|
||||
async def _live_access_token(client: httpx.AsyncClient) -> str:
|
||||
email = os.getenv("AGENT_LIVE_EMAIL")
|
||||
phone = os.getenv("AGENT_LIVE_PHONE")
|
||||
password = os.getenv("AGENT_LIVE_PASSWORD")
|
||||
if not email or not password:
|
||||
if not phone or not password:
|
||||
pytest.fail(
|
||||
"AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_EMAIL and AGENT_LIVE_PASSWORD"
|
||||
"AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_PHONE and AGENT_LIVE_PASSWORD"
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"{BASE_URL}/api/v1/auth/sessions",
|
||||
json={"email": email, "password": password},
|
||||
json={"phone": phone, "password": password},
|
||||
)
|
||||
response_text = response.text.strip().replace("\n", " ")
|
||||
truncated_text = response_text[:200]
|
||||
@@ -67,7 +67,7 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
@@ -143,7 +143,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
@@ -10,7 +10,7 @@ from v1.agent.service import ensure_session_owner
|
||||
|
||||
|
||||
def test_owner_guard_denies_non_owner() -> None:
|
||||
user = CurrentUser(id=uuid4(), email="self@example.com")
|
||||
user = CurrentUser(id=uuid4(), phone="self@example.com")
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
ensure_session_owner(owner_id="other-user", current_user=user)
|
||||
|
||||
@@ -7,6 +7,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from sqlalchemy import select
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from v1.agent.repository import AgentRepository
|
||||
|
||||
|
||||
@@ -79,6 +81,7 @@ async def test_persist_user_message_sets_session_title_when_empty() -> None:
|
||||
session_id=session_id,
|
||||
content=" 请帮我安排明天下午开会 ",
|
||||
metadata=None,
|
||||
visibility_mask=1,
|
||||
)
|
||||
|
||||
assert session_row.title == "请帮我安排明天下午开会"
|
||||
@@ -101,6 +104,7 @@ async def test_persist_user_message_keeps_existing_session_title() -> None:
|
||||
session_id=session_id,
|
||||
content="新的消息内容",
|
||||
metadata=None,
|
||||
visibility_mask=1,
|
||||
)
|
||||
|
||||
assert session_row.title == "已有标题"
|
||||
@@ -164,3 +168,13 @@ async def test_get_history_day_uses_target_day_queries_only() -> None:
|
||||
messages = payload["messages"]
|
||||
assert isinstance(messages, list)
|
||||
assert len(messages) == 1
|
||||
|
||||
|
||||
def test_apply_visibility_filter_adds_bitwise_expression() -> None:
|
||||
repository = AgentRepository(session=SimpleNamespace()) # type: ignore[arg-type]
|
||||
stmt = select(AgentChatMessage)
|
||||
|
||||
filtered = repository._apply_visibility_filter(stmt=stmt, visibility_mask=1)
|
||||
|
||||
assert "visibility_mask" in str(filtered)
|
||||
assert "&" in str(filtered)
|
||||
|
||||
@@ -20,6 +20,7 @@ class _FakeRepository:
|
||||
def __init__(self) -> None:
|
||||
self.committed = False
|
||||
self.persisted_user_messages: list[dict[str, object]] = []
|
||||
self.created_session_calls = 0
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
if session_id == "00000000-0000-0000-0000-000000000001":
|
||||
@@ -30,6 +31,7 @@ class _FakeRepository:
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id
|
||||
self.created_session_calls += 1
|
||||
return session_id or "00000000-0000-0000-0000-000000000999"
|
||||
|
||||
async def commit(self) -> None:
|
||||
@@ -39,9 +41,13 @@ class _FakeRepository:
|
||||
return None
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None:
|
||||
del session_id, before
|
||||
del session_id, before, visibility_mask
|
||||
return None
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
@@ -54,15 +60,42 @@ class _FakeRepository:
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
visibility_mask: int,
|
||||
) -> None:
|
||||
self.persisted_user_messages.append(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
"visibility_mask": visibility_mask,
|
||||
}
|
||||
)
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None:
|
||||
normalized = agent_type.strip().lower()
|
||||
mapping = {
|
||||
"router": 16,
|
||||
"worker": 17,
|
||||
"memory": 18,
|
||||
}
|
||||
bit = mapping.get(normalized)
|
||||
if bit is None:
|
||||
return None
|
||||
return {
|
||||
"agent_type": normalized,
|
||||
"status": "active",
|
||||
"config": {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": None,
|
||||
"timeout_seconds": 30,
|
||||
"visibility_consumer_bit": bit,
|
||||
"context_messages": {"mode": "number", "count": 20},
|
||||
"enabled_tools": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
@@ -122,11 +155,11 @@ class _FakeAttachmentStorage:
|
||||
def _user() -> CurrentUser:
|
||||
return CurrentUser(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
email="user@example.com",
|
||||
phone="+8613812345678",
|
||||
)
|
||||
|
||||
|
||||
def _build_run_input(*, urls: list[str]) -> RunAgentInput:
|
||||
def _build_run_input(*, urls: list[str], agent_type: str = "worker") -> RunAgentInput:
|
||||
content: list[dict[str, str]] = [{"type": "text", "text": "hello"}]
|
||||
for url in urls:
|
||||
content.append({"type": "binary", "mimeType": "image/png", "url": url})
|
||||
@@ -144,7 +177,7 @@ def _build_run_input(*, urls: list[str]) -> RunAgentInput:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"forwardedProps": {"agent_type": agent_type},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -222,6 +255,68 @@ async def test_enqueue_run_persists_attachment_and_queue_without_user_token(
|
||||
assert run_input["runId"] == "run-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_rejects_unknown_agent_type(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
base_url = str(config.supabase.url).rstrip("/")
|
||||
safe_path = quote(
|
||||
"agent-inputs/00000000-0000-0000-0000-000000000001/"
|
||||
"00000000-0000-0000-0000-000000000001/uploads/a.png"
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
urls=[
|
||||
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
|
||||
],
|
||||
agent_type="planner",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_rejects_memory_mode_for_api(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
base_url = str(config.supabase.url).rstrip("/")
|
||||
safe_path = quote(
|
||||
"agent-inputs/00000000-0000-0000-0000-000000000001/"
|
||||
"00000000-0000-0000-0000-000000000001/uploads/a.png"
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
urls=[
|
||||
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
|
||||
],
|
||||
agent_type="memory",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "memory mode is automation-only"
|
||||
assert repository.created_session_calls == 0
|
||||
assert repository.persisted_user_messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_attachment_signed_url_returns_url(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
@@ -317,9 +412,13 @@ async def test_enqueue_run_rejects_too_many_attachments(monkeypatch) -> None:
|
||||
async def test_get_history_snapshot_filters_out_tool_messages() -> None:
|
||||
class _HistoryRepository(_FakeRepository):
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None:
|
||||
del session_id, before
|
||||
del session_id, before, visibility_mask
|
||||
return {
|
||||
"day": "2026-03-17",
|
||||
"hasMore": False,
|
||||
|
||||
@@ -8,13 +8,9 @@ from fastapi import HTTPException
|
||||
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
VerificationCreateRequest,
|
||||
VerificationVerifyRequest,
|
||||
VerificationResendRequest,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,314 +31,83 @@ class TestSupabaseAuthGateway:
|
||||
return SupabaseAuthGateway(), mock_client, mock_admin_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_calls_email_with_string(
|
||||
async def test_send_otp_sets_should_create_user(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
mock_sign_in_with_otp = MagicMock()
|
||||
mock_client.auth.sign_in_with_otp = mock_sign_in_with_otp
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
await sut.request_password_reset(request)
|
||||
await sut.send_otp(OtpSendRequest(phone="+8613812345678"))
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_verification_maps_timeout_error_to_503(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_client.auth.sign_up = MagicMock(
|
||||
side_effect=AuthError("request_timeout", None)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.create_verification(
|
||||
VerificationCreateRequest(
|
||||
username="tester",
|
||||
email="test@example.com",
|
||||
password="secret123",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "Auth service temporarily unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_with_redirect(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(
|
||||
email="test@example.com",
|
||||
redirect_to="http://localhost:3000/reset-password",
|
||||
)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with(
|
||||
"test@example.com",
|
||||
options={"redirect_to": "http://localhost:3000/reset-password"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_verification_rejects_untrusted_redirect_url(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.create_verification(
|
||||
VerificationCreateRequest(
|
||||
username="tester",
|
||||
email="test@example.com",
|
||||
password="secret123",
|
||||
redirect_to="https://evil.example.com/callback",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid redirect URL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_rejects_untrusted_redirect_url(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.request_password_reset(
|
||||
PasswordResetRequest(
|
||||
email="test@example.com",
|
||||
redirect_to="https://evil.example.com/reset",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid redirect URL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_swallows_auth_error(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
|
||||
result = await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_extracts_email_from_mapping(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"email": "test@example.com"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_rejects_invalid_email_shape(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"unexpected": "value"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid email"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_updates_password_by_user_id(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, mock_admin_client = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id="user-1"),
|
||||
)
|
||||
mock_verify_otp = MagicMock(return_value=verify_response)
|
||||
mock_client.auth.verify_otp = mock_verify_otp
|
||||
|
||||
mock_update_user_by_id = MagicMock()
|
||||
mock_admin_client.auth.admin = SimpleNamespace(
|
||||
update_user_by_id=mock_update_user_by_id
|
||||
)
|
||||
|
||||
request = PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
token="123456",
|
||||
new_password="newpassword123",
|
||||
)
|
||||
|
||||
await sut.confirm_password_reset(request)
|
||||
|
||||
mock_verify_otp.assert_called_once_with(
|
||||
mock_sign_in_with_otp.assert_called_once_with(
|
||||
{
|
||||
"type": "recovery",
|
||||
"email": "test@example.com",
|
||||
"token": "123456",
|
||||
"phone": "+8613812345678",
|
||||
"options": {"should_create_user": True},
|
||||
}
|
||||
)
|
||||
mock_update_user_by_id.assert_called_once_with(
|
||||
"user-1",
|
||||
{"password": "newpassword123"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_raises_when_user_id_missing(
|
||||
async def test_create_phone_session_uses_verify_otp(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id=""),
|
||||
session=SimpleNamespace(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
),
|
||||
user=SimpleNamespace(id="user-1", phone="+8613812345678"),
|
||||
)
|
||||
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
|
||||
request = PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
token="123456",
|
||||
new_password="newpassword123",
|
||||
response = await sut.create_phone_session(
|
||||
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
|
||||
)
|
||||
|
||||
assert response.user.id == "user-1"
|
||||
assert response.access_token == "access"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_phone_session_normalizes_phone_without_plus_prefix(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
),
|
||||
user=SimpleNamespace(id="user-1", phone="14155552671"),
|
||||
)
|
||||
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
|
||||
response = await sut.create_phone_session(
|
||||
PhoneSessionCreateRequest(phone="+14155552671", token="123456")
|
||||
)
|
||||
|
||||
assert response.user.phone == "+14155552671"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_session_maps_invalid_token(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_client.auth.refresh_session = MagicMock(
|
||||
return_value=SimpleNamespace(session=None, user=None)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.confirm_password_reset(request)
|
||||
await sut.refresh_session(SessionRefreshRequest(refresh_token="bad"))
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid or expired verification code"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_resend_calls_reset_password_email(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
await sut.resend_verification(
|
||||
VerificationResendRequest(
|
||||
type="recovery",
|
||||
email="test@example.com",
|
||||
redirect_to="http://localhost:3000/reset-password",
|
||||
)
|
||||
)
|
||||
|
||||
mock_reset_email.assert_called_once_with(
|
||||
"test@example.com",
|
||||
options={"redirect_to": "http://localhost:3000/reset-password"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_verification_maps_internal_error_to_503(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_client.auth.verify_otp = MagicMock(
|
||||
side_effect=AuthError("internal_server_error", None)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.verify_verification(
|
||||
VerificationVerifyRequest(
|
||||
type="signup",
|
||||
email="test@example.com",
|
||||
token="123456",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "Auth service temporarily unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_maps_internal_error_to_503(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_client.auth.sign_in_with_password = MagicMock(
|
||||
side_effect=AuthError("internal_server_error", None)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.create_session(
|
||||
SessionCreateRequest(
|
||||
email="test@example.com",
|
||||
password="secret123",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "Auth service temporarily unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_session_maps_bad_gateway_to_503(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_client.auth.refresh_session = MagicMock(
|
||||
side_effect=AuthError("bad_gateway", None)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.refresh_session(SessionRefreshRequest(refresh_token="rt"))
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "Auth service temporarily unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_maps_service_unavailable_to_503(
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_client.auth.verify_otp = MagicMock(
|
||||
side_effect=AuthError("service_unavailable", None)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await sut.confirm_password_reset(
|
||||
PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
token="123456",
|
||||
new_password="newpassword123",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "Auth service temporarily unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_uses_in_memory_cache(
|
||||
async def test_get_user_by_phone_uses_in_memory_cache(
|
||||
self,
|
||||
gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -350,9 +115,9 @@ class TestSupabaseAuthGateway:
|
||||
sut, _, _ = gateway
|
||||
user = SimpleNamespace(
|
||||
id="user-1",
|
||||
email="cached@example.com",
|
||||
phone="+8613811112222",
|
||||
created_at="2026-03-16T00:00:00Z",
|
||||
email_confirmed_at=None,
|
||||
phone_confirmed_at=None,
|
||||
)
|
||||
list_calls = {"count": 0}
|
||||
|
||||
@@ -362,9 +127,39 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
monkeypatch.setattr("v1.auth.gateway._list_auth_users", _fake_list_auth_users)
|
||||
|
||||
first = await sut.get_user_by_email("cached@example.com")
|
||||
second = await sut.get_user_by_email("CACHED@example.com")
|
||||
first = await sut.get_user_by_phone("+8613811112222")
|
||||
second = await sut.get_user_by_phone("+8613811112222")
|
||||
|
||||
assert first.id == "user-1"
|
||||
assert second.email == "cached@example.com"
|
||||
assert second.phone == "+8613811112222"
|
||||
assert list_calls["count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_user_ids_by_phone_supports_suffix_query(
|
||||
self,
|
||||
gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
users = [
|
||||
SimpleNamespace(
|
||||
id="user-cn",
|
||||
phone="+8613811112222",
|
||||
created_at="2026-03-16T00:00:00Z",
|
||||
phone_confirmed_at=None,
|
||||
),
|
||||
SimpleNamespace(
|
||||
id="user-us",
|
||||
phone="+14155552671",
|
||||
created_at="2026-03-16T00:00:00Z",
|
||||
phone_confirmed_at=None,
|
||||
),
|
||||
]
|
||||
|
||||
monkeypatch.setattr("v1.auth.gateway._list_auth_users", lambda _client: users)
|
||||
|
||||
matched_cn = await sut.search_user_ids_by_phone("13811112222")
|
||||
matched_us = await sut.search_user_ids_by_phone("4155552671")
|
||||
|
||||
assert matched_cn == ["user-cn"]
|
||||
assert matched_us == ["user-us"]
|
||||
|
||||
@@ -5,72 +5,28 @@ from pydantic import ValidationError
|
||||
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationVerifyRequest,
|
||||
VerificationResendRequest,
|
||||
)
|
||||
|
||||
|
||||
def test_signup_requires_valid_email() -> None:
|
||||
def test_send_otp_requires_valid_phone() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationCreateRequest(
|
||||
username="demo", email="not-an-email", password="secret123"
|
||||
)
|
||||
OtpSendRequest(phone="13812345678")
|
||||
|
||||
|
||||
def test_signup_requires_username() -> None:
|
||||
def test_send_otp_accepts_e164_phone() -> None:
|
||||
request = OtpSendRequest(phone="+14155552671")
|
||||
|
||||
assert request.phone == "+14155552671"
|
||||
|
||||
|
||||
def test_phone_session_requires_six_digit_token() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationCreateRequest.model_validate(
|
||||
{"email": "user@example.com", "password": "secret123"}
|
||||
)
|
||||
|
||||
|
||||
def test_signup_allows_any_invite_code_input() -> None:
|
||||
request = VerificationCreateRequest(
|
||||
username="demo",
|
||||
email="user@example.com",
|
||||
password="secret123",
|
||||
invite_code="abc123",
|
||||
)
|
||||
|
||||
assert request.invite_code == "abc123"
|
||||
|
||||
|
||||
def test_signup_verify_requires_six_digit_token() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationVerifyRequest(email="user@example.com", token="abc123")
|
||||
|
||||
|
||||
def test_signup_verify_disallows_new_password() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationVerifyRequest(
|
||||
type="signup",
|
||||
email="user@example.com",
|
||||
token="123456",
|
||||
new_password="secret123",
|
||||
)
|
||||
|
||||
|
||||
def test_recovery_verify_requires_new_password() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationVerifyRequest(
|
||||
type="recovery",
|
||||
email="user@example.com",
|
||||
token="123456",
|
||||
)
|
||||
|
||||
|
||||
def test_signup_resend_requires_valid_email() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VerificationResendRequest(email="invalid")
|
||||
|
||||
|
||||
def test_login_requires_valid_email() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SessionCreateRequest(email="invalid", password="secret123")
|
||||
PhoneSessionCreateRequest(phone="+8613812345678", token="abc123")
|
||||
|
||||
|
||||
def test_refresh_requires_token() -> None:
|
||||
@@ -78,8 +34,13 @@ def test_refresh_requires_token() -> None:
|
||||
SessionRefreshRequest(refresh_token="")
|
||||
|
||||
|
||||
def test_logout_requires_token() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SessionDeleteRequest(refresh_token="")
|
||||
|
||||
|
||||
def test_session_response_maps_user() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
user = AuthUser(id="user-1", phone="+14155552671")
|
||||
response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
@@ -89,4 +50,4 @@ def test_session_response_maps_user() -> None:
|
||||
)
|
||||
|
||||
assert response.user.id == "user-1"
|
||||
assert response.user.email == "user@example.com"
|
||||
assert response.user.phone == "+14155552671"
|
||||
|
||||
@@ -2,19 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import v1.auth.gateway as auth_gateway_module
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByEmailResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService, AuthServiceGateway
|
||||
|
||||
@@ -22,23 +15,16 @@ from v1.auth.service import AuthService, AuthServiceGateway
|
||||
class FakeGateway(AuthServiceGateway):
|
||||
def __init__(self, response: SessionResponse) -> None:
|
||||
self._response = response
|
||||
self.last_create_verification_request: VerificationCreateRequest | None = None
|
||||
self.last_send_otp_request: OtpSendRequest | None = None
|
||||
self.last_phone_session_request: PhoneSessionCreateRequest | None = None
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
self.last_create_verification_request = request
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
self.last_send_otp_request = request
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return self._response
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
return None
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
self.last_phone_session_request = request
|
||||
return self._response
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
@@ -47,85 +33,10 @@ class FakeGateway(AuthServiceGateway):
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LogoutAssertingGateway(AuthServiceGateway):
|
||||
def __init__(self, expected_refresh_token: str) -> None:
|
||||
self._expected_refresh_token = expected_refresh_token
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
assert refresh_token == self._expected_refresh_token
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_forwards_refresh_token() -> None:
|
||||
service = AuthService(gateway=LogoutAssertingGateway("refresh-token"))
|
||||
|
||||
await service.delete_session("refresh-token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signup_resend_returns_none() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
service = AuthService(gateway=FakeGateway(token_response))
|
||||
|
||||
result = await service.resend_verification(
|
||||
VerificationResendRequest(email="user@example.com")
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_verification_ignores_invalid_invite_code() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
async def test_send_otp_forwards_payload() -> None:
|
||||
user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
@@ -136,22 +47,15 @@ async def test_create_verification_ignores_invalid_invite_code() -> None:
|
||||
gateway = FakeGateway(token_response)
|
||||
service = AuthService(gateway=gateway)
|
||||
|
||||
await service.create_verification(
|
||||
VerificationCreateRequest(
|
||||
username="demo",
|
||||
email="user@example.com",
|
||||
password="secret123",
|
||||
invite_code="bad-code",
|
||||
)
|
||||
)
|
||||
await service.send_otp(OtpSendRequest(phone="+8613812345678"))
|
||||
|
||||
assert gateway.last_create_verification_request is not None
|
||||
assert gateway.last_create_verification_request.invite_code is None
|
||||
assert gateway.last_send_otp_request is not None
|
||||
assert gateway.last_send_otp_request.phone == "+8613812345678"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_verification_normalizes_valid_invite_code() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
async def test_create_phone_session_forwards_payload() -> None:
|
||||
user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
@@ -162,59 +66,28 @@ async def test_create_verification_normalizes_valid_invite_code() -> None:
|
||||
gateway = FakeGateway(token_response)
|
||||
service = AuthService(gateway=gateway)
|
||||
|
||||
await service.create_verification(
|
||||
VerificationCreateRequest(
|
||||
username="demo",
|
||||
email="user@example.com",
|
||||
password="secret123",
|
||||
invite_code="a2b3",
|
||||
)
|
||||
response = await service.create_phone_session(
|
||||
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
|
||||
)
|
||||
|
||||
assert gateway.last_create_verification_request is not None
|
||||
assert gateway.last_create_verification_request.invite_code == "A2B3"
|
||||
assert gateway.last_phone_session_request is not None
|
||||
assert gateway.last_phone_session_request.token == "123456"
|
||||
assert response.user.phone == "+8613812345678"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supabase_signup_passes_username_in_metadata(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured_payload: dict[str, object] = {}
|
||||
|
||||
class FakeSupabaseAuth:
|
||||
def sign_up(self, payload: dict[str, object]) -> object:
|
||||
captured_payload.update(payload)
|
||||
|
||||
class _User:
|
||||
id = "user-1"
|
||||
email = "user@example.com"
|
||||
|
||||
class _Session:
|
||||
access_token = "access"
|
||||
refresh_token = "refresh"
|
||||
expires_in = 3600
|
||||
token_type = "bearer"
|
||||
|
||||
class _Response:
|
||||
user = _User()
|
||||
session = None
|
||||
|
||||
return _Response()
|
||||
|
||||
class FakeClient:
|
||||
auth = FakeSupabaseAuth()
|
||||
|
||||
monkeypatch.setattr(
|
||||
auth_gateway_module.supabase_service, "get_client", lambda: FakeClient()
|
||||
async def test_refresh_session_forwards_payload() -> None:
|
||||
user = AuthUser(id="user-1", phone="+8613812345678")
|
||||
token_response = SessionResponse(
|
||||
access_token="access",
|
||||
refresh_token="refresh",
|
||||
expires_in=3600,
|
||||
token_type="bearer",
|
||||
user=user,
|
||||
)
|
||||
gateway = FakeGateway(token_response)
|
||||
service = AuthService(gateway=gateway)
|
||||
|
||||
gateway = auth_gateway_module.SupabaseAuthGateway()
|
||||
await gateway.create_verification(
|
||||
VerificationCreateRequest(
|
||||
username="demo",
|
||||
email="user@example.com",
|
||||
password="secret123",
|
||||
)
|
||||
)
|
||||
response = await service.refresh_session(SessionRefreshRequest(refresh_token="rt"))
|
||||
|
||||
assert captured_payload["data"] == {"username": "demo"}
|
||||
assert response.access_token == "access"
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from schemas.user.context import UserContext
|
||||
|
||||
from v1.friendships.schemas import (
|
||||
UserBasicInfo,
|
||||
FriendRequestCreate,
|
||||
FriendRequestResponse,
|
||||
FriendResponse,
|
||||
@@ -13,16 +15,16 @@ from v1.friendships.schemas import (
|
||||
)
|
||||
|
||||
|
||||
def test_user_basic_info_maps_fields() -> None:
|
||||
user = UserBasicInfo(id="user-1", username="alice", avatar_url=None)
|
||||
def test_user_context_maps_fields() -> None:
|
||||
user = UserContext(id="user-1", username="alice", avatar_url=None)
|
||||
|
||||
assert user.id == "user-1"
|
||||
assert user.username == "alice"
|
||||
assert user.avatar_url is None
|
||||
|
||||
|
||||
def test_user_basic_info_with_avatar() -> None:
|
||||
user = UserBasicInfo(
|
||||
def test_user_context_with_avatar() -> None:
|
||||
user = UserContext(
|
||||
id="user-2", username="bob", avatar_url="https://example.com/avatar.png"
|
||||
)
|
||||
|
||||
@@ -49,13 +51,13 @@ def test_friend_request_create_without_content() -> None:
|
||||
|
||||
def test_friend_request_create_content_max_length() -> None:
|
||||
target_id = uuid4()
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(ValidationError):
|
||||
FriendRequestCreate(target_user_id=target_id, content="x" * 201)
|
||||
|
||||
|
||||
def test_friend_request_response_maps_fields() -> None:
|
||||
sender = UserBasicInfo(id="user-1", username="alice", avatar_url=None)
|
||||
recipient = UserBasicInfo(id="user-2", username="bob", avatar_url=None)
|
||||
sender = UserContext(id="user-1", username="alice", avatar_url=None)
|
||||
recipient = UserContext(id="user-2", username="bob", avatar_url=None)
|
||||
request_id = uuid4()
|
||||
created = datetime(2026, 1, 15, 10, 30, 0)
|
||||
|
||||
@@ -63,7 +65,7 @@ def test_friend_request_response_maps_fields() -> None:
|
||||
id=request_id,
|
||||
sender=sender,
|
||||
recipient=recipient,
|
||||
content="Hello!",
|
||||
content={"text": "Hello!"},
|
||||
status="pending",
|
||||
created_at=created,
|
||||
)
|
||||
@@ -76,7 +78,7 @@ def test_friend_request_response_maps_fields() -> None:
|
||||
|
||||
|
||||
def test_friend_response_maps_fields() -> None:
|
||||
friend_user = UserBasicInfo(id="user-2", username="bob", avatar_url=None)
|
||||
friend_user = UserContext(id="user-2", username="bob", avatar_url=None)
|
||||
request_id = uuid4()
|
||||
created = datetime(2026, 1, 15, 10, 30, 0)
|
||||
accepted = datetime(2026, 1, 16, 12, 0, 0)
|
||||
@@ -96,7 +98,7 @@ def test_friend_response_maps_fields() -> None:
|
||||
|
||||
|
||||
def test_friend_response_accepted_at_optional() -> None:
|
||||
friend_user = UserBasicInfo(id="user-2", username="bob", avatar_url=None)
|
||||
friend_user = UserContext(id="user-2", username="bob", avatar_url=None)
|
||||
request_id = uuid4()
|
||||
created = datetime(2026, 1, 15, 10, 30, 0)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from core.auth.models import CurrentUser
|
||||
from models.inbox_messages import InboxMessage, InboxMessageType
|
||||
from models.schedule_items import ScheduleItem
|
||||
from v1.auth.schemas import UserByEmailResponse
|
||||
from v1.auth.schemas import UserByPhoneResponse
|
||||
from v1.schedule_items.repository import ScheduleItemRepository
|
||||
from v1.schedule_items.schemas import ScheduleItemShareRequest
|
||||
from v1.schedule_items.service import ScheduleItemService
|
||||
@@ -20,18 +20,18 @@ from v1.schedule_items.service import ScheduleItemService
|
||||
|
||||
def test_share_request_schema() -> None:
|
||||
request = ScheduleItemShareRequest(
|
||||
email="friend@example.com",
|
||||
phone="+8613810000000",
|
||||
permission_view=True,
|
||||
permission_edit=True,
|
||||
permission_invite=False,
|
||||
)
|
||||
assert request.email == "friend@example.com"
|
||||
assert request.phone == "+8613810000000"
|
||||
assert request.permission_view is True
|
||||
|
||||
|
||||
def test_permission_bits_calculation() -> None:
|
||||
request = ScheduleItemShareRequest(
|
||||
email="friend@example.com",
|
||||
phone="+8613810000000",
|
||||
permission_view=True,
|
||||
permission_edit=True,
|
||||
permission_invite=False,
|
||||
@@ -71,12 +71,12 @@ class ShareRepo:
|
||||
|
||||
|
||||
class AuthGatewayStub:
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
return UserByEmailResponse(
|
||||
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
|
||||
return UserByPhoneResponse(
|
||||
id="00000000-0000-0000-0000-000000000222",
|
||||
email=email,
|
||||
phone=phone,
|
||||
created_at="2026-02-28T10:00:00Z",
|
||||
email_confirmed_at=None,
|
||||
phone_confirmed_at=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -119,12 +119,12 @@ class InboxRepoStub:
|
||||
|
||||
|
||||
class AuthGatewayInvalidIdStub:
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
return UserByEmailResponse(
|
||||
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
|
||||
return UserByPhoneResponse(
|
||||
id="not-a-uuid",
|
||||
email=email,
|
||||
phone=phone,
|
||||
created_at="2026-02-28T10:00:00Z",
|
||||
email_confirmed_at=None,
|
||||
phone_confirmed_at=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ async def test_share_forbidden_when_not_owner() -> None:
|
||||
await service.share(
|
||||
item_id,
|
||||
ScheduleItemShareRequest(
|
||||
email="friend@example.com",
|
||||
phone="+8613810000000",
|
||||
permission_view=True,
|
||||
permission_edit=False,
|
||||
permission_invite=False,
|
||||
@@ -178,7 +178,7 @@ async def test_share_success_creates_calendar_invitation_message() -> None:
|
||||
result = await service.share(
|
||||
item_id,
|
||||
ScheduleItemShareRequest(
|
||||
email="friend@example.com",
|
||||
phone="+8613810000000",
|
||||
permission_view=True,
|
||||
permission_edit=True,
|
||||
permission_invite=False,
|
||||
@@ -211,7 +211,7 @@ async def test_share_returns_not_found_when_item_missing() -> None:
|
||||
await service.share(
|
||||
uuid4(),
|
||||
ScheduleItemShareRequest(
|
||||
email="friend@example.com",
|
||||
phone="+8613810000000",
|
||||
permission_view=True,
|
||||
permission_edit=False,
|
||||
permission_invite=False,
|
||||
@@ -241,7 +241,7 @@ async def test_share_invalid_auth_user_id_returns_503() -> None:
|
||||
await service.share(
|
||||
item_id,
|
||||
ScheduleItemShareRequest(
|
||||
email="friend@example.com",
|
||||
phone="+8613810000000",
|
||||
permission_view=True,
|
||||
permission_edit=False,
|
||||
permission_invite=False,
|
||||
@@ -274,7 +274,7 @@ async def test_share_sqlalchemy_error_rolls_back() -> None:
|
||||
await service.share(
|
||||
item_id,
|
||||
ScheduleItemShareRequest(
|
||||
email="friend@example.com",
|
||||
phone="+8613810000000",
|
||||
permission_view=True,
|
||||
permission_edit=False,
|
||||
permission_invite=False,
|
||||
|
||||
@@ -22,7 +22,7 @@ async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -
|
||||
del token
|
||||
return deps.CurrentUser(
|
||||
id=UUID("e8845a17-282b-4a63-8025-194a06235958"),
|
||||
email="dagronl@126.com",
|
||||
phone="dagronl@126.com",
|
||||
role="authenticated",
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -
|
||||
user = await deps.get_current_user(authorization="Bearer valid-token")
|
||||
|
||||
assert str(user.id) == "e8845a17-282b-4a63-8025-194a06235958"
|
||||
assert user.email == "dagronl@126.com"
|
||||
assert user.phone == "dagronl@126.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -6,7 +6,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.users.schemas import UserUpdateRequest
|
||||
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ class _FakeProfile:
|
||||
username: str
|
||||
avatar_url: str | None
|
||||
bio: str | None
|
||||
settings: dict | None = None
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
@@ -51,6 +52,37 @@ class _FakeSession:
|
||||
self.rollback_called += 1
|
||||
|
||||
|
||||
class _FakeSearchRepository:
|
||||
def __init__(self, profiles: list[_FakeProfile]) -> None:
|
||||
self._profiles_by_id = {profile.id: profile for profile in profiles}
|
||||
|
||||
async def get_by_user_ids(
|
||||
self, user_ids: list[object]
|
||||
) -> dict[object, _FakeProfile]:
|
||||
return {
|
||||
user_id: self._profiles_by_id[user_id]
|
||||
for user_id in user_ids
|
||||
if user_id in self._profiles_by_id
|
||||
}
|
||||
|
||||
async def search_users(self, query: str, limit: int = 20) -> list[_FakeProfile]:
|
||||
_ = limit
|
||||
return [
|
||||
profile
|
||||
for profile in self._profiles_by_id.values()
|
||||
if query.lower() in profile.username.lower()
|
||||
]
|
||||
|
||||
|
||||
class _FakeAuthLookup:
|
||||
def __init__(self, mapping: dict[str, list[str]]) -> None:
|
||||
self.mapping = mapping
|
||||
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
_ = limit
|
||||
return self.mapping.get(query, [])
|
||||
|
||||
|
||||
class _FakeUserContextCache:
|
||||
def __init__(self, *, should_fail: bool = False) -> None:
|
||||
self.should_fail = should_fail
|
||||
@@ -72,7 +104,7 @@ async def test_update_me_invalidates_user_context_cache() -> None:
|
||||
session = _FakeSession()
|
||||
cache = _FakeUserContextCache()
|
||||
service = UserService(
|
||||
repository=repo,
|
||||
repository=repo, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=user_id),
|
||||
user_context_cache=cache, # type: ignore[arg-type]
|
||||
@@ -94,7 +126,7 @@ async def test_update_me_succeeds_when_cache_invalidation_fails() -> None:
|
||||
session = _FakeSession()
|
||||
cache = _FakeUserContextCache(should_fail=True)
|
||||
service = UserService(
|
||||
repository=repo,
|
||||
repository=repo, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=user_id),
|
||||
user_context_cache=cache, # type: ignore[arg-type]
|
||||
@@ -105,3 +137,59 @@ async def test_update_me_succeeds_when_cache_invalidation_fails() -> None:
|
||||
assert result.username == "new-name"
|
||||
assert session.commit_called == 1
|
||||
assert cache.invalidated_user_ids == [user_id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_users_supports_phone_without_country_code() -> None:
|
||||
user_id = uuid4()
|
||||
repo = _FakeSearchRepository(
|
||||
[
|
||||
_FakeProfile(
|
||||
id=user_id,
|
||||
username="alice",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
session = _FakeSession()
|
||||
auth_lookup = _FakeAuthLookup({"13812345678": [str(user_id)]})
|
||||
service = UserService(
|
||||
repository=repo, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=user_id),
|
||||
auth_gateway=auth_lookup, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
results = await service.search_users(UserSearchRequest(query="13812345678"))
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == str(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_users_preserves_numeric_username_lookup() -> None:
|
||||
user_id = uuid4()
|
||||
repo = _FakeSearchRepository(
|
||||
[
|
||||
_FakeProfile(
|
||||
id=user_id,
|
||||
username="20260319",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
session = _FakeSession()
|
||||
auth_lookup = _FakeAuthLookup({})
|
||||
service = UserService(
|
||||
repository=repo, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=user_id),
|
||||
auth_gateway=auth_lookup, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
results = await service.search_users(UserSearchRequest(query="20260319"))
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].username == "20260319"
|
||||
|
||||
Reference in New Issue
Block a user