refactor(backend): update API routes and service layer

- Update agent router/service/repository with new endpoints
- Update auth routes with phone-based authentication
- Update users service with new phone lookup
- Update schedule_items with new schemas
- Update message schemas with visibility support
- Update settings with new automation scheduler config
- Update CLI with new commands
- Update tests to match new API contracts
This commit is contained in:
qzl
2026-03-19 18:42:59 +08:00
parent 641d847008
commit f0af44d840
36 changed files with 1083 additions and 1853 deletions
+1 -1
View File
@@ -7,5 +7,5 @@ from uuid import UUID
@dataclass(frozen=True)
class CurrentUser:
id: UUID
email: str | None = None
phone: str | None = None
role: str | None = None
+9 -1
View File
@@ -61,6 +61,7 @@ class RuntimeSettings(BaseModel):
]
)
sql_log_queries: bool = False
trusted_proxy_ips: list[str] = Field(default_factory=list)
@field_validator("log_dir", mode="before")
@classmethod
@@ -162,6 +163,12 @@ class AgentRuntimeSettings(BaseModel):
user_context_cache_max_turns: int = Field(default=6, ge=1, le=100)
class AutomationSchedulerSettings(BaseModel):
enabled: bool = True
interval_seconds: int = Field(default=60, ge=5, le=3600)
batch_limit: int = Field(default=100, ge=1, le=1000)
class LlmSettings(BaseModel):
provider_keys: dict[str, str] = Field(default_factory=dict)
@@ -225,7 +232,7 @@ class AppVersionSettings(BaseModel):
class TestSettings(BaseModel):
email: str = ""
phone: str = ""
password: str = ""
@@ -250,6 +257,7 @@ class Settings(BaseSettings):
llm: LlmSettings = LlmSettings()
litellm: LiteLLMSettings = LiteLLMSettings()
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
automation_scheduler: AutomationSchedulerSettings = AutomationSchedulerSettings()
taskiq: TaskiqSettings = TaskiqSettings()
database: DatabaseSettings = DatabaseSettings()
app_version: AppVersionSettings = AppVersionSettings()
+30 -2
View File
@@ -5,6 +5,8 @@ import subprocess
import sys
from pathlib import Path
from core.agentscope.runtime.tasks import run_automation_scheduler_scan
from core.config.settings import config
from core.config.initial.init_data import initialize_data
from core.logging import get_logger
@@ -101,12 +103,34 @@ async def bootstrap() -> bool:
return True
async def run_automation_scheduler_forever() -> bool:
if not config.automation_scheduler.enabled:
logger.info("Automation scheduler disabled by config")
return True
interval_seconds = int(config.automation_scheduler.interval_seconds)
batch_limit = int(config.automation_scheduler.batch_limit)
logger.info(
"Starting automation scheduler loop",
interval_seconds=interval_seconds,
batch_limit=batch_limit,
)
while True:
try:
await run_automation_scheduler_scan(limit=batch_limit)
except Exception as exc:
logger.exception("Automation scheduler scan failed", error=str(exc))
await asyncio.sleep(interval_seconds)
def main() -> int:
"""CLI entry point."""
if len(sys.argv) < 2:
logger.error("No command provided")
logger.info("Usage: python -m core.runtime.cli <command>")
logger.info("Available commands: migrate, init-data, bootstrap")
logger.info(
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
)
return 1
command = sys.argv[1]
@@ -117,9 +141,13 @@ def main() -> int:
success = asyncio.run(run_init_data())
elif command == "bootstrap":
success = asyncio.run(bootstrap())
elif command == "automation-scheduler":
success = asyncio.run(run_automation_scheduler_forever())
else:
logger.error("Unknown command", command=command)
logger.info("Available commands: migrate, init-data, bootstrap")
logger.info(
"Available commands: migrate, init-data, bootstrap, automation-scheduler"
)
return 1
return 0 if success else 1
+2 -1
View File
@@ -6,7 +6,7 @@ from typing import Any, ClassVar
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from schemas.agent.runtime_models import AgentOutput
from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput
from ..agent import AgentType, ToolAgentOutput
@@ -24,6 +24,7 @@ class AgentChatMessageMetadata(BaseModel):
run_id: str
agent_type: AgentType | None = None
user_message_attachments: list[UserMessageAttachment] | None = None
router_agent_output: RouterAgentOutput | None = None
tool_agent_output: ToolAgentOutput | None = None
agent_output: AgentOutput | None = None
+1 -1
View File
@@ -66,7 +66,7 @@ class UserContext(BaseModel):
id: str
username: str
email: str | None = None
phone: str | None = None
avatar_url: str | None = None
bio: str | None = None
settings: ProfileSettingsUnion | None = None
+44 -3
View File
@@ -6,7 +6,7 @@ from typing import Protocol
from uuid import UUID, uuid4
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
@@ -95,6 +95,7 @@ class AgentRepository:
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None:
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
@@ -124,6 +125,7 @@ class AgentRepository:
seq=next_seq,
role=AgentChatMessageRole.USER,
content=content,
visibility_mask=max(int(visibility_mask), 0),
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
)
self._session.add(message)
@@ -132,7 +134,11 @@ class AgentRepository:
await self._session.flush()
async def get_history_day(
self, *, session_id: str, before: date | None
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None:
try:
session_uuid = UUID(session_id)
@@ -152,6 +158,10 @@ class AgentRepository:
.order_by(AgentChatMessage.created_at.desc())
.limit(1)
)
target_created_at_stmt = self._apply_visibility_filter(
stmt=target_created_at_stmt,
visibility_mask=visibility_mask,
)
if before_start is not None:
target_created_at_stmt = target_created_at_stmt.where(
AgentChatMessage.created_at < before_start
@@ -175,6 +185,10 @@ class AgentRepository:
.where(AgentChatMessage.created_at < end)
.order_by(AgentChatMessage.seq.asc())
)
message_stmt = self._apply_visibility_filter(
stmt=message_stmt,
visibility_mask=visibility_mask,
)
messages = (await self._session.execute(message_stmt)).scalars().all()
has_more_stmt = (
select(AgentChatMessage.id)
@@ -183,6 +197,10 @@ class AgentRepository:
.where(AgentChatMessage.created_at < start)
.limit(1)
)
has_more_stmt = self._apply_visibility_filter(
stmt=has_more_stmt,
visibility_mask=visibility_mask,
)
has_more = (
await self._session.execute(has_more_stmt)
).scalar_one_or_none() is not None
@@ -196,7 +214,11 @@ class AgentRepository:
}
async def get_recent_messages_by_user_window(
self, *, session_id: str, user_message_limit: int
self,
*,
session_id: str,
user_message_limit: int,
visibility_mask: int | None = None,
) -> list[dict[str, object]]:
try:
session_uuid = UUID(session_id)
@@ -210,6 +232,10 @@ class AgentRepository:
.where(AgentChatMessage.deleted_at.is_(None))
.order_by(AgentChatMessage.seq.desc())
)
message_stmt = self._apply_visibility_filter(
stmt=message_stmt,
visibility_mask=visibility_mask,
)
messages_desc = (await self._session.execute(message_stmt)).scalars().all()
if not messages_desc:
return []
@@ -294,6 +320,21 @@ class AgentRepository:
)
return payload_model.model_dump(mode="json", exclude_none=True)
def _apply_visibility_filter(
self,
*,
stmt: Select,
visibility_mask: int | None,
) -> Select:
if visibility_mask is None:
return stmt
required_mask = max(int(visibility_mask), 0)
if required_mask == 0:
return stmt
return stmt.where(
(AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
)
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
+29 -3
View File
@@ -47,6 +47,7 @@ logger = get_logger("v1.agent.router")
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
@@ -72,8 +73,14 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
count = await redis.incr(key)
if count == 1:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
else:
ttl = await redis.ttl(key)
if int(ttl) < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
await redis.decr(key)
after_decr = await redis.decr(key)
if int(after_decr) <= 0:
await redis.delete(key)
return False
return True
except Exception as exc: # noqa: BLE001
@@ -82,7 +89,7 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
user_id=user_id,
reason=str(exc),
)
return False
return True
async def _release_sse_slot(*, user_id: str) -> None:
@@ -92,10 +99,21 @@ async def _release_sse_slot(*, user_id: str) -> None:
count = await redis.decr(key)
if int(count) <= 0:
await redis.delete(key)
return None
ttl = await redis.ttl(key)
if int(ttl) < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
except Exception: # noqa: BLE001
return None
def _is_terminal_run_event(event: dict[str, object]) -> bool:
raw_event_type = event.get("type")
return (
isinstance(raw_event_type, str) and raw_event_type in _TERMINAL_RUN_EVENT_TYPES
)
@router.post(
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
@@ -145,8 +163,13 @@ async def stream_events(
async def _event_iter() -> AsyncIterator[str]:
cursor = last_event_id
idle_polls = 0
terminal_event_reached = False
try:
while not await request.is_disconnected() and idle_polls < idle_limit:
while (
not terminal_event_reached
and not await request.is_disconnected()
and idle_polls < idle_limit
):
try:
rows = await service.stream_events(
thread_id=thread_id,
@@ -181,6 +204,9 @@ async def stream_events(
continue
cursor = row_id
yield to_sse_event(row_id, event)
if _is_terminal_run_event(event):
terminal_event_reached = True
break
finally:
await _release_sse_slot(user_id=str(current_user.id))
+84 -9
View File
@@ -17,6 +17,9 @@ from core.auth.models import CurrentUser
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
UserMessageAttachment,
@@ -51,7 +54,11 @@ class AgentRepositoryLike(Protocol):
async def rollback(self) -> None: ...
async def get_history_day(
self, *, session_id: str, before: date | None
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
@@ -62,8 +69,13 @@ class AgentRepositoryLike(Protocol):
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None: ...
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -138,6 +150,17 @@ class AgentService:
created = False
thread_id = run_input.thread_id
run_id = run_input.run_id
forwarded_props = getattr(run_input, "forwarded_props", None)
try:
agent_type = parse_forwarded_props_agent_type(forwarded_props)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
if agent_type == "memory":
raise HTTPException(
status_code=422,
detail="memory mode is automation-only",
)
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
except HTTPException as exc:
@@ -161,25 +184,21 @@ class AgentService:
run_input=run_input,
current_user=current_user,
)
visibility_mask = await self._resolve_user_message_visibility_mask(
agent_type=agent_type
)
await self._repository.persist_user_message(
session_id=thread_id,
content=user_message_text,
metadata=user_message_metadata,
visibility_mask=visibility_mask,
)
await self._repository.commit()
forwarded_props = getattr(run_input, "forwarded_props", None)
system_agent_mode = "worker"
if isinstance(forwarded_props, dict):
raw_mode = forwarded_props.get("system_agent_mode")
if isinstance(raw_mode, str) and raw_mode.strip():
system_agent_mode = raw_mode.strip().lower()
task_id = await self._queue.enqueue(
command={
"command": "run",
"owner_id": str(current_user.id),
"system_agent_mode": system_agent_mode,
"run_input": run_input.model_dump(
mode="json", by_alias=True, exclude_none=True
),
@@ -193,6 +212,61 @@ class AgentService:
created=created,
)
async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int:
normalized_agent_type = agent_type.strip().lower()
history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
if normalized_agent_type == "memory":
return bit_mask(bit=18)
agent_config = await self._repository.get_system_agent_config(
agent_type=normalized_agent_type
)
if agent_config is None:
raise HTTPException(
status_code=422, detail="invalid forwarded_props.agent_type"
)
llm_config = SystemAgentLLMConfig.model_validate(
(agent_config.get("config") if isinstance(agent_config, dict) else {}) or {}
)
agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit)
if normalized_agent_type == "worker":
router_config = await self._repository.get_system_agent_config(
agent_type="router"
)
worker_config = await self._repository.get_system_agent_config(
agent_type="worker"
)
if router_config is None or worker_config is None:
raise HTTPException(
status_code=500,
detail="system agent visibility config missing",
)
router_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
router_config.get("config")
if isinstance(router_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
worker_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
worker_config.get("config")
if isinstance(worker_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
return history_bit_mask | router_mask | worker_mask
return history_bit_mask | agent_mask
async def _prepare_user_message(
self,
*,
@@ -408,6 +482,7 @@ class AgentService:
day_payload = await self._repository.get_history_day(
session_id=thread_id,
before=before,
visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)),
)
messages: list[HistoryMessage] = []
+127 -199
View File
@@ -2,28 +2,22 @@ from __future__ import annotations
import asyncio
import time
from collections.abc import Mapping
from typing import Any, cast
from urllib.parse import urlparse
from pydantic import ValidationError
from fastapi import HTTPException
from supabase import AuthError
from core.config.settings import config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
UserByPhoneResponse,
)
from v1.auth.service import AuthServiceGateway
@@ -36,7 +30,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
def __init__(self) -> None:
self._user_lookup_cache_ttl_seconds: int = 60
self._user_lookup_cache_expires_at: float = 0.0
self._users_by_email: dict[str, Any] = {}
self._users_by_phone: dict[str, Any] = {}
def _get_client(self) -> Any:
return supabase_service.get_client()
@@ -44,47 +38,31 @@ class SupabaseAuthGateway(AuthServiceGateway):
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
async def send_otp(self, request: OtpSendRequest) -> None:
client = self._get_client()
metadata: dict[str, Any] = {"username": request.username}
if request.invite_code:
metadata["invite_code"] = request.invite_code
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
"data": metadata,
}
if request.redirect_to:
payload["options"] = {
"email_redirect_to": _validate_redirect_url(request.redirect_to)
"phone": request.phone,
"options": {"should_create_user": True},
}
try:
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
return VerificationCreateResponse(email=request.email)
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
await asyncio.to_thread(sign_in_with_otp, payload)
except AuthError as exc:
logger.warning("Signup failed", error_type=type(exc).__name__)
logger.warning("Send otp failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise HTTPException(
status_code=422, detail="Invalid signup request"
) from exc
raise HTTPException(status_code=429, detail="Too many requests") from exc
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
if request.type != "signup":
raise HTTPException(status_code=422, detail="Invalid request")
client = self._get_client()
payload: dict[str, Any] = {
"type": request.type,
"email": request.email,
"type": "sms",
"phone": request.phone,
"token": request.token,
}
try:
@@ -92,7 +70,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(response, "Invalid verification code")
except AuthError as exc:
logger.warning("Signup verify failed", error_type=type(exc).__name__)
logger.warning("Create phone session failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
@@ -102,45 +80,6 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid verification code"
) from exc
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
if request.type == "recovery":
await self.request_password_reset(
PasswordResetRequest(
email=request.email,
redirect_to=request.redirect_to,
)
)
return
payload: dict[str, Any] = {"type": request.type, "email": request.email}
try:
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
except AuthError as exc:
logger.warning("Signup resend failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
@@ -189,98 +128,84 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid refresh token"
) from exc
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
admin_client = self._get_admin_client()
normalized_email = email.lower()
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
normalized_phone = _normalize_phone(phone)
if not normalized_phone:
raise HTTPException(status_code=404, detail="User not found")
now = time.monotonic()
if now >= self._user_lookup_cache_expires_at:
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_email: dict[str, Any] = {}
for candidate in users:
candidate_email = str(getattr(candidate, "email", "")).lower()
if candidate_email:
users_by_email[candidate_email] = candidate
self._users_by_email = users_by_email
self._user_lookup_cache_expires_at = (
now + self._user_lookup_cache_ttl_seconds
)
await self._refresh_user_lookup_cache_if_needed()
user = self._users_by_email.get(normalized_email)
user = self._users_by_phone.get(normalized_phone)
if user is None:
raise HTTPException(status_code=404, detail="User not found")
return UserByEmailResponse(
user_phone = _normalize_phone(getattr(user, "phone", ""))
if not user_phone:
raise HTTPException(status_code=404, detail="User not found")
return UserByPhoneResponse(
id=str(getattr(user, "id", "")),
email=str(getattr(user, "email", "")),
phone=user_phone,
created_at=str(getattr(user, "created_at", "")),
email_confirmed_at=(
str(getattr(user, "email_confirmed_at", ""))
if getattr(user, "email_confirmed_at", None)
phone_confirmed_at=(
str(getattr(user, "phone_confirmed_at", ""))
if getattr(user, "phone_confirmed_at", None)
else None
),
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
client = self._get_client()
try:
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {
"redirect_to": _validate_redirect_url(request.redirect_to)
}
await asyncio.to_thread(reset_email, email, options=options)
else:
await asyncio.to_thread(reset_email, email)
except AuthError as exc:
logger.warning(
"Password reset request failed",
error_type=type(exc).__name__,
)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
normalized_query = _normalize_phone_search_query(query)
if not normalized_query:
return []
await self._refresh_user_lookup_cache_if_needed()
if normalized_query.startswith("+"):
matched_user = self._users_by_phone.get(normalized_query)
if matched_user is None:
return []
user_id = str(getattr(matched_user, "id", ""))
return [user_id] if user_id else []
digits = _digits_only(normalized_query)
if not digits:
return []
matched_records: list[tuple[str, str]] = []
for cached_phone, candidate in self._users_by_phone.items():
candidate_digits = _digits_only(cached_phone)
if not candidate_digits.endswith(digits):
continue
user_id = str(getattr(candidate, "id", ""))
if user_id:
matched_records.append((cached_phone, user_id))
if not matched_records:
return []
unique_ids: list[str] = []
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
if user_id in unique_ids:
continue
unique_ids.append(user_id)
if len(unique_ids) >= max(1, limit):
break
return unique_ids
async def _refresh_user_lookup_cache_if_needed(self) -> None:
now = time.monotonic()
if now < self._user_lookup_cache_expires_at:
return
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
client = self._get_client()
admin_client = self._get_admin_client()
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
user_id = str(getattr(user, "id", "")) if user is not None else ""
if session is None or not user_id:
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)
except AuthError as exc:
logger.warning(
"Password reset confirm failed", error_type=type(exc).__name__
)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
) from exc
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_phone: dict[str, Any] = {}
for candidate in users:
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
if candidate_phone:
users_by_phone[candidate_phone] = candidate
self._users_by_phone = users_by_phone
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
@@ -312,55 +237,24 @@ def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
return any(token in code or token in message for token in indicators)
def _validate_redirect_url(url: str) -> str:
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
origin = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
allowed_origins = {
_normalize_origin(candidate)
for candidate in config.cors.allow_origins
if _is_http_origin(candidate)
}
if origin not in allowed_origins:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
return url
def _normalize_origin(value: str) -> str:
parsed = urlparse(value)
return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
def _is_http_origin(value: str) -> bool:
parsed = urlparse(value)
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
def _coerce_reset_email(value: object) -> str:
if isinstance(value, str):
return value
if isinstance(value, Mapping):
nested = value.get("email") or value.get("value")
if isinstance(nested, str):
return nested
raise HTTPException(status_code=422, detail="Invalid email")
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
phone = _normalize_phone(getattr(user, "phone", None))
if not phone:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
try:
auth_user = AuthUser(id=str(user.id), phone=str(phone))
except ValidationError as exc:
logger.warning(
"Auth response returned invalid phone format",
error_type=type(exc).__name__,
)
raise HTTPException(status_code=401, detail=failure_message) from exc
return SessionResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
@@ -389,3 +283,37 @@ def _list_auth_users(client: Any) -> list[Any]:
page += 1
return users
def _normalize_phone(raw_phone: object) -> str | None:
phone = str(raw_phone).strip()
for separator in (" ", "-", "(", ")"):
phone = phone.replace(separator, "")
if not phone:
return None
if phone.startswith("00") and len(phone) > 2:
return f"+{phone[2:]}"
if phone.startswith("+"):
return phone
if phone.isdigit():
return f"+{phone}"
return None
def _normalize_phone_search_query(raw_query: str) -> str | None:
query = raw_query.strip()
for separator in (" ", "-", "(", ")"):
query = query.replace(separator, "")
if not query:
return None
if query.startswith("00") and len(query) > 2:
return f"+{query[2:]}"
if query.startswith("+"):
return query
if query.isdigit():
return query
return None
def _digits_only(value: str) -> str:
return "".join(ch for ch in value if ch.isdigit())
+45 -70
View File
@@ -1,20 +1,16 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request, Response
from fastapi import HTTPException
from core.config.settings import config
from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service
from v1.auth.schemas import (
PasswordResetConfirmRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
SessionResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
from v1.auth.service import AuthService
@@ -22,80 +18,49 @@ from v1.auth.service import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post(
"/verifications", response_model=VerificationCreateResponse, status_code=202
)
async def create_verification(
payload: VerificationCreateRequest,
service: AuthService = Depends(get_auth_service),
) -> VerificationCreateResponse:
await enforce_rate_limit(
scope="signup_start",
identifier=payload.email,
limit=5,
window_seconds=60,
)
return await service.create_verification(payload)
@router.post("/verify", response_model=SessionResponse)
async def verify(
payload: VerificationVerifyRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse | Response:
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
limit = 10
window_seconds = 600
await enforce_rate_limit(
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=limit,
window_seconds=window_seconds,
)
if payload.type == "signup":
return await service.verify_verification(payload)
if payload.new_password is None:
raise HTTPException(status_code=422, detail="Invalid request")
await service.confirm_password_reset(
PasswordResetConfirmRequest(
email=payload.email,
token=payload.token,
new_password=payload.new_password,
)
)
return Response(status_code=204)
@router.post("/resend", status_code=204)
async def resend(
payload: VerificationResendRequest,
@router.post("/otp/send", status_code=204)
async def send_otp(
payload: OtpSendRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
client_ip = _client_ip(request)
await enforce_rate_limit(
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=5,
scope="otp_send_phone",
identifier=payload.phone,
limit=3,
window_seconds=60,
)
await service.resend_verification(payload)
await enforce_rate_limit(
scope="otp_send_ip",
identifier=client_ip,
limit=20,
window_seconds=60,
)
await service.send_otp(payload)
return Response(status_code=204)
@router.post("/sessions", response_model=SessionResponse)
async def create_session(
payload: SessionCreateRequest,
@router.post("/phone-session", response_model=SessionResponse)
async def create_phone_session(
payload: PhoneSessionCreateRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="login",
identifier=payload.email,
limit=10,
window_seconds=60,
scope="phone_session_phone",
identifier=payload.phone,
limit=6,
window_seconds=300,
)
return await service.create_session(payload)
await enforce_rate_limit(
scope="phone_session_ip",
identifier=client_ip,
limit=20,
window_seconds=300,
)
return await service.create_phone_session(payload)
@router.post("/sessions/refresh", response_model=SessionResponse)
@@ -130,6 +95,11 @@ async def delete_session(
def _client_ip(request: Request) -> str:
host = request.client.host if request.client else ""
if not host:
return "unknown"
if _should_trust_proxy_headers(host):
forwarded_for = request.headers.get("x-forwarded-for", "")
if forwarded_for:
first = forwarded_for.split(",")[0].strip()
@@ -138,5 +108,10 @@ def _client_ip(request: Request) -> str:
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
host = request.client.host if request.client else ""
return host or "unknown"
return host
def _should_trust_proxy_headers(host: str) -> bool:
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
return host in trusted_proxies
+13 -51
View File
@@ -1,49 +1,22 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
OtpType = Literal["signup", "recovery"]
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
class VerificationCreateRequest(BaseModel):
class OtpSendRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
username: str = Field(min_length=3, max_length=30)
email: EmailStr
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
redirect_to: str | None = None
invite_code: str | None = None
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class VerificationResendRequest(BaseModel):
email: EmailStr
type: OtpType = "signup"
redirect_to: str | None = None
class PhoneSessionCreateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
class VerificationVerifyRequest(BaseModel):
type: OtpType = "signup"
email: EmailStr
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
token: str = Field(pattern=r"^\d{6}$")
new_password: str | None = Field(
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
)
@model_validator(mode="after")
def validate_type_payload(self) -> "VerificationVerifyRequest":
if self.type == "recovery" and self.new_password is None:
raise ValueError("new_password is required when type is recovery")
if self.type == "signup" and self.new_password is not None:
raise ValueError("new_password is only allowed when type is recovery")
return self
class SessionCreateRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
class SessionRefreshRequest(BaseModel):
@@ -56,7 +29,7 @@ class SessionDeleteRequest(BaseModel):
class AuthUser(BaseModel):
id: str
email: EmailStr
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class SessionResponse(BaseModel):
@@ -67,23 +40,12 @@ class SessionResponse(BaseModel):
user: AuthUser
class UserByEmailResponse(BaseModel):
class UserByPhoneResponse(BaseModel):
id: str
email: EmailStr
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
created_at: str
email_confirmed_at: str | None = None
phone_confirmed_at: str | None = None
class VerificationCreateResponse(BaseModel):
email: EmailStr
class PasswordResetRequest(BaseModel):
email: EmailStr
redirect_to: str | None = None
class PasswordResetConfirmRequest(BaseModel):
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
+10 -66
View File
@@ -1,52 +1,30 @@
from __future__ import annotations
import re
from typing import Protocol
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
class AuthServiceGateway(Protocol):
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
async def send_otp(self, request: OtpSendRequest) -> None:
raise NotImplementedError
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
raise NotImplementedError
async def resend_verification(self, request: VerificationResendRequest) -> None:
raise NotImplementedError
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
raise NotImplementedError
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
raise NotImplementedError
async def delete_session(self, refresh_token: str | None) -> None:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
class AuthService:
_gateway: AuthServiceGateway
@@ -54,50 +32,16 @@ class AuthService:
def __init__(self, gateway: AuthServiceGateway) -> None:
self._gateway = gateway
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
normalized_invite_code = _normalize_invite_code(request.invite_code)
normalized_request = request.model_copy(
update={"invite_code": normalized_invite_code}
)
return await self._gateway.create_verification(normalized_request)
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
return await self._gateway.verify_verification(request)
async def resend_verification(self, request: VerificationResendRequest) -> None:
await self._gateway.resend_verification(request)
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
return await self._gateway.create_session(request)
return await self._gateway.create_phone_session(request)
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
await self._gateway.request_password_reset(request)
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
await self._gateway.confirm_password_reset(request)
_INVITE_CODE_PATTERN = re.compile(r"^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{4}$")
def _normalize_invite_code(value: str | None) -> str | None:
if value is None:
return None
normalized = value.strip().upper()
if not normalized:
return None
return normalized if _INVITE_CODE_PATTERN.fullmatch(normalized) else None
+6 -2
View File
@@ -5,7 +5,7 @@ from typing import ClassVar
from uuid import UUID
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from schemas.inbox.messages import (
CalendarContent,
@@ -154,7 +154,11 @@ _PERMISSION_EDIT = 4
class ScheduleItemShareRequest(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
email: EmailStr = Field(..., description="Email of user to share with")
phone: str = Field(
...,
pattern=r"^\+861[3-9]\d{9}$",
description="Phone of user to share with",
)
permission_view: bool = Field(True, description="Grant view permission")
permission_edit: bool = Field(False, description="Grant edit permission")
permission_invite: bool = Field(False, description="Grant invite permission")
+7 -7
View File
@@ -31,19 +31,19 @@ from v1.schedule_items.schemas import (
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from v1.auth.schemas import UserByEmailResponse
from v1.auth.schemas import UserByPhoneResponse
logger = get_logger("v1.schedule_items.service")
class AuthByEmailGateway(Protocol):
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
class AuthByPhoneGateway(Protocol):
async def get_user_by_phone(self, phone: str) -> "UserByPhoneResponse": ...
class ScheduleItemService(BaseService):
_repository: ScheduleItemRepository
_session: AsyncSession
_auth_gateway: AuthByEmailGateway
_auth_gateway: AuthByPhoneGateway
_inbox_repository: InboxMessageRepository
def __init__(
@@ -51,7 +51,7 @@ class ScheduleItemService(BaseService):
repository: ScheduleItemRepository,
session: AsyncSession,
current_user: CurrentUser | None,
auth_gateway: AuthByEmailGateway | None = None,
auth_gateway: AuthByPhoneGateway | None = None,
inbox_repository: InboxMessageRepository | None = None,
) -> None:
super().__init__(current_user=current_user)
@@ -329,7 +329,7 @@ class ScheduleItemService(BaseService):
detail=f"You can only share with permissions up to {inviter_permission}",
)
target_user = await self._auth_gateway.get_user_by_email(request.email)
target_user = await self._auth_gateway.get_user_by_phone(request.phone)
recipient_id = UUID(target_user.id)
existing = await self._repository.get_subscription(item_id, recipient_id)
@@ -404,7 +404,7 @@ class ScheduleItemService(BaseService):
except ValueError:
await self._session.rollback()
logger.exception(
"Auth lookup returned invalid user id", email=request.email
"Auth lookup returned invalid user id", phone=request.phone
)
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
+4 -4
View File
@@ -76,11 +76,11 @@ async def _verify_user_with_supabase(token: str) -> CurrentUser | None:
parsed_id = UUID(user_id)
except ValueError:
return None
email = getattr(user, "email", None)
phone = getattr(user, "phone", None)
role = getattr(user, "role", None)
return CurrentUser(
id=parsed_id,
email=email if isinstance(email, str) else None,
phone=phone if isinstance(phone, str) else None,
role=role if isinstance(role, str) else None,
)
@@ -125,9 +125,9 @@ async def get_current_user(
raise HTTPException(status_code=401, detail="Unauthorized")
logger.debug("JWT validation successful", user_id=str(user_id))
email = payload.get("email") if isinstance(payload.get("email"), str) else None
phone = payload.get("phone") if isinstance(payload.get("phone"), str) else None
role = payload.get("role") if isinstance(payload.get("role"), str) else None
return CurrentUser(id=user_id, email=email, role=role)
return CurrentUser(id=user_id, phone=phone, role=role)
async def get_user_repository(
+1 -1
View File
@@ -38,7 +38,7 @@ class UserRepository(Protocol):
...
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
"""Search users by username (ilike) or email (exact match)."""
"""Search users by username (ilike) or phone (exact match)."""
...
+56 -24
View File
@@ -21,19 +21,22 @@ if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from schemas.user.context import UserContext
from v1.auth.schemas import UserByEmailResponse
logger = get_logger("v1.users.service")
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
_PHONE_QUERY_PATTERN = re.compile(r"^[+()\-\s\d]{4,32}$")
class AuthLookupGateway(Protocol):
async def get_user_id_by_email(self, email: str) -> str | None: ...
async def search_user_ids_by_phone(
self, query: str, limit: int = 20
) -> list[str]: ...
class AuthByEmailGateway(Protocol):
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
class AuthByPhoneGateway(Protocol):
async def search_user_ids_by_phone(
self, query: str, limit: int = 20
) -> list[str]: ...
class UserContextInvalidator(Protocol):
@@ -41,15 +44,14 @@ class UserContextInvalidator(Protocol):
class AuthLookupAdapter:
def __init__(self, gateway: AuthByEmailGateway) -> None:
def __init__(self, gateway: AuthByPhoneGateway) -> None:
self._gateway = gateway
async def get_user_id_by_email(self, email: str) -> str | None:
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
try:
response = await self._gateway.get_user_by_email(email)
return response.id
return await self._gateway.search_user_ids_by_phone(query, limit=limit)
except HTTPException:
return None
return []
class UserService(BaseService):
@@ -92,11 +94,11 @@ class UserService(BaseService):
if user is None:
raise HTTPException(status_code=404, detail="User not found")
email = self._current_user.email if self._current_user else None
phone = self._current_user.phone if self._current_user else None
return UserContext(
id=str(user.id),
username=user.username,
email=email,
phone=phone,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
@@ -152,11 +154,11 @@ class UserService(BaseService):
error=str(exc),
)
email = self._current_user.email if self._current_user else None
phone = self._current_user.phone if self._current_user else None
return UserContext(
id=str(user.id),
username=user.username,
email=email,
phone=phone,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
@@ -181,28 +183,50 @@ class UserService(BaseService):
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
query = request.query.strip()
if _EMAIL_PATTERN.match(query):
return await self._search_by_email(query)
if _looks_like_phone_query(query):
phone_results = await self._search_by_phone(query)
if not query.isdigit():
return phone_results
username_results = await self._search_by_username(query)
if not phone_results:
return username_results
merged_by_id = {result.id: result for result in phone_results}
for result in username_results:
merged_by_id.setdefault(result.id, result)
return list(merged_by_id.values())
return await self._search_by_username(query)
async def _search_by_email(self, email: str) -> list[UserContext]:
async def _search_by_phone(self, phone: str) -> list[UserContext]:
if self._auth_gateway is None:
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
if user_id_str is None:
user_id_values = await self._auth_gateway.search_user_ids_by_phone(
phone, limit=20
)
if not user_id_values:
return []
user_ids: list[UUID] = []
for raw_id in user_id_values:
try:
user_ids.append(UUID(raw_id))
except ValueError:
continue
if not user_ids:
return []
try:
user = await self._repository.get_by_user_id(UUID(user_id_str))
users_by_id = await self._repository.get_by_user_ids(user_ids)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="User store unavailable")
results: list[UserContext] = []
for user_id in user_ids:
user = users_by_id.get(user_id)
if user is None:
return []
return [
continue
results.append(
UserContext(
id=str(user.id),
username=user.username,
@@ -210,7 +234,8 @@ class UserService(BaseService):
bio=user.bio,
settings=parse_profile_settings(user.settings),
)
]
)
return results
async def _search_by_username(self, query: str) -> list[UserContext]:
try:
@@ -228,3 +253,10 @@ class UserService(BaseService):
)
for user in users
]
def _looks_like_phone_query(query: str) -> bool:
if not _PHONE_QUERY_PATTERN.fullmatch(query):
return False
digits_count = sum(char.isdigit() for char in query)
return digits_count >= 4
+16 -49
View File
@@ -12,41 +12,24 @@ from app import app
from v1.auth.dependencies import get_auth_service
from v1.auth.schemas import (
AuthUser,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
from v1.auth.service import AuthService
class FakeE2EAuthService(AuthService):
def __init__(self) -> None:
self._user = AuthUser(id="user-1", email="user@example.com")
self._user = AuthUser(id="user-1", phone="+8613812345678")
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
return VerificationCreateResponse(email=request.email)
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
return SessionResponse(
access_token="access-1",
refresh_token="refresh-1",
expires_in=3600,
token_type="bearer",
user=self._user,
)
async def resend_verification(self, request: VerificationResendRequest) -> None:
async def send_otp(self, request: OtpSendRequest) -> None:
return None
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
return SessionResponse(
access_token="access-2",
refresh_token="refresh-2",
@@ -105,41 +88,25 @@ def test_auth_flow_e2e() -> None:
base_url=f"http://{host}:{port}"
)
try:
verification = request_context.post(
"/api/v1/auth/verifications",
data=json.dumps(
{
"username": "demo",
"email": "user@example.com",
"password": "secret123",
}
),
send_code = request_context.post(
"/api/v1/auth/otp/send",
data=json.dumps({"phone": "+8613812345678"}),
headers={"Content-Type": "application/json"},
)
assert verification.status == 202
assert send_code.status == 204
verify = request_context.post(
"/api/v1/auth/verify",
login_or_register = request_context.post(
"/api/v1/auth/phone-session",
data=json.dumps(
{
"email": "user@example.com",
"phone": "+8613812345678",
"token": "123456",
}
),
headers={"Content-Type": "application/json"},
)
assert verify.status == 200
assert verify.json()["access_token"] == "access-1"
login = request_context.post(
"/api/v1/auth/sessions",
data=json.dumps(
{"email": "user@example.com", "password": "secret123"}
),
headers={"Content-Type": "application/json"},
)
assert login.status == 200
assert login.json()["access_token"] == "access-2"
assert login_or_register.status == 200
assert login_or_register.json()["access_token"] == "access-2"
refresh = request_context.post(
"/api/v1/auth/sessions/refresh",
+6 -3
View File
@@ -4,11 +4,16 @@ import socket
import threading
import time
import pytest
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
from v1.infra.dependencies import get_redis_service
pytest.skip(
"infra health endpoint removed from v1 API",
allow_module_level=True,
)
class _FakeService:
@@ -52,8 +57,6 @@ def _start_server(host: str, port: int):
def test_infra_health_e2e() -> None:
app.dependency_overrides[get_redis_service] = lambda: _FakeService()
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(host, port)
+8 -6
View File
@@ -11,21 +11,22 @@ import uvicorn
from app import app
from core.auth.models import CurrentUser
from schemas.user.context import UserContext
from v1.users.dependencies import get_current_user, get_user_service
from v1.users.schemas import UserResponse, UserUpdateRequest
from v1.users.schemas import UserUpdateRequest
class FakeUserService:
"""Fake service for E2E testing."""
def __init__(self, user: UserResponse) -> None:
def __init__(self, user: UserContext) -> None:
self._user = user
async def get_me(self) -> UserResponse:
async def get_me(self) -> UserContext:
return self._user
async def update_me(self, update: UserUpdateRequest) -> UserResponse:
return UserResponse(
async def update_me(self, update: UserUpdateRequest) -> UserContext:
return UserContext(
id=self._user.id,
username=(
update.username if update.username is not None else self._user.username
@@ -38,6 +39,7 @@ class FakeUserService:
bio=update.bio if update.bio is not None else self._user.bio,
)
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
@@ -65,7 +67,7 @@ def _start_server(host: str, port: int):
def test_profile_flow_e2e() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
user = UserContext(
id=str(user_id),
username="demo",
avatar_url=None,
+58 -718
View File
@@ -11,16 +11,10 @@ from v1.auth.dependencies import get_auth_service
from v1.auth.rate_limit import reset_rate_limit_state
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
from v1.auth.service import AuthService
@@ -30,58 +24,39 @@ def reset_auth_rate_limit_state() -> None:
reset_rate_limit_state()
@pytest.fixture(autouse=True)
def force_in_memory_rate_limit(monkeypatch: pytest.MonkeyPatch) -> None:
async def _raise_redis_unavailable() -> None:
raise RuntimeError("redis unavailable in integration tests")
monkeypatch.setattr(
"v1.auth.rate_limit.get_or_init_redis_client",
_raise_redis_unavailable,
)
class FakeAuthService(AuthService):
def __init__(self, token_response: SessionResponse) -> None:
self._token_response = token_response
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
if request.email == "exists@example.com":
raise HTTPException(status_code=422, detail="Invalid signup request")
return VerificationCreateResponse(email=request.email)
async def send_otp(self, request: OtpSendRequest) -> None:
if request.phone == "+8613811111111":
raise HTTPException(status_code=401, detail="Invalid verification code")
return None
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
if request.token == "000000":
raise HTTPException(status_code=401, detail="Invalid verification code")
return self._token_response
async def resend_verification(self, request: VerificationResendRequest) -> None:
return None
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
raise HTTPException(status_code=401, detail="Invalid credentials")
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
raise HTTPException(status_code=401, detail="Invalid refresh token")
async def delete_session(self, refresh_token: str | None) -> None:
return None
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
if email == "missing@example.com":
raise HTTPException(status_code=404, detail="User not found")
return UserByEmailResponse(
id="user-1",
email=email,
created_at="2026-02-24T00:00:00Z",
email_confirmed_at=None,
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
return None
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
if request.token == "000000":
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
)
return None
def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
def _get_service() -> AuthService:
@@ -90,761 +65,126 @@ def _override_auth_service(service: AuthService) -> Callable[[], AuthService]:
return _get_service
def test_signup_start_returns_pending_response() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
def _token_response() -> SessionResponse:
user = AuthUser(id="user-1", phone="+8613812345678")
return SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
def test_send_otp_returns_204() -> None:
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
FakeAuthService(_token_response())
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications",
json={
"username": "demo",
"email": "user@example.com",
"password": "secret123",
},
"/api/v1/auth/otp/send",
json={"phone": "+8613812345678"},
)
assert response.status_code == 202
assert response.json() == {"email": "user@example.com"}
assert response.status_code == 204
finally:
app.dependency_overrides = {}
def test_signup_verify_returns_token_response() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
def test_phone_session_returns_token_response() -> None:
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
FakeAuthService(_token_response())
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "123456"},
"/api/v1/auth/phone-session",
json={"phone": "+8613812345678", "token": "123456"},
)
assert response.status_code == 200
body = response.json()
assert body["access_token"] == "access"
assert body["refresh_token"] == "refresh"
assert body["user"]["email"] == "user@example.com"
assert body["user"]["phone"] == "+8613812345678"
finally:
app.dependency_overrides = {}
def test_signup_resend_returns_generic_message() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
def test_phone_session_invalid_token_returns_problem_details() -> None:
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
FakeAuthService(_token_response())
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/resend",
json={"type": "recovery", "email": "user@example.com"},
)
assert response.status_code == 204
assert response.content == b""
finally:
app.dependency_overrides = {}
def test_signup_verify_invalid_token_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "000000"},
"/api/v1/auth/phone-session",
json={"phone": "+8613812345678", "token": "000000"},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
assert body["detail"] == "Invalid verification code"
finally:
app.dependency_overrides = {}
def test_signup_start_existing_email_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
def test_legacy_routes_are_removed() -> None:
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
FakeAuthService(_token_response())
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications",
json={
"username": "demo",
"email": "exists@example.com",
"password": "secret123",
},
)
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Entity"
assert body["status"] == 422
assert body["detail"] == "Invalid signup request"
assert client.post("/api/v1/auth/verifications", json={}).status_code == 404
assert client.post("/api/v1/auth/verify", json={}).status_code == 404
assert client.post("/api/v1/auth/resend", json={}).status_code == 404
assert client.post("/api/v1/auth/sessions", json={}).status_code == 405
finally:
app.dependency_overrides = {}
def test_signup_verify_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
def test_send_otp_phone_rate_limited_after_too_many_attempts() -> None:
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
FakeAuthService(_token_response())
)
client = TestClient(app)
try:
for _ in range(10):
for _ in range(3):
ok = client.post(
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "123456"},
)
assert ok.status_code == 200
blocked = client.post(
"/api/v1/auth/verify",
json={"email": "user@example.com", "token": "123456"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
finally:
app.dependency_overrides = {}
def test_signup_resend_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(5):
ok = client.post(
"/api/v1/auth/resend",
json={"email": "user@example.com"},
"/api/v1/auth/otp/send",
json={"phone": "+8613812345678"},
)
assert ok.status_code == 204
blocked = client.post(
"/api/v1/auth/resend",
json={"email": "user@example.com"},
"/api/v1/auth/otp/send",
json={"phone": "+8613812345678"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
finally:
app.dependency_overrides = {}
def test_signup_start_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
def test_phone_session_rate_limited_after_too_many_attempts() -> None:
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
FakeAuthService(_token_response())
)
client = TestClient(app)
try:
for _ in range(5):
ok = client.post(
"/api/v1/auth/verifications",
json={
"username": "demo",
"email": "user@example.com",
"password": "secret123",
},
)
assert ok.status_code == 202
for _ in range(6):
blocked = client.post(
"/api/v1/auth/verifications",
json={
"username": "demo",
"email": "user@example.com",
"password": "secret123",
},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
finally:
app.dependency_overrides = {}
def test_login_invalid_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/sessions",
json={"email": "user@example.com", "password": "wrongpw"},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
assert body["detail"] == "Invalid credentials"
finally:
app.dependency_overrides = {}
def test_refresh_invalid_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/sessions/refresh",
json={"refresh_token": "invalid"},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
assert body["detail"] == "Invalid refresh token"
finally:
app.dependency_overrides = {}
def test_logout_returns_no_content() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.request(
"DELETE",
"/api/v1/auth/sessions",
json={"refresh_token": "refresh"},
)
assert response.status_code == 204
assert response.content == b""
finally:
app.dependency_overrides = {}
def test_login_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
blocked = client.post(
"/api/v1/auth/sessions",
json={"email": "user@example.com", "password": "wrongpw"},
"/api/v1/auth/phone-session",
json={"phone": "+8613812345678", "token": "000000"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/sessions",
json={"email": "user@example.com", "password": "wrongpw"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_refresh_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
blocked = client.post(
"/api/v1/auth/sessions/refresh",
json={"refresh_token": "invalid"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/sessions/refresh",
json={"refresh_token": "invalid"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_refresh_rate_limit_not_bypassed_by_changing_refresh_token() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for index in range(10):
blocked = client.post(
"/api/v1/auth/sessions/refresh",
json={"refresh_token": f"invalid-{index}"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/sessions/refresh",
json={"refresh_token": "invalid-extra"},
"/api/v1/auth/phone-session",
json={"phone": "+8613812345678", "token": "000000"},
)
assert blocked.status_code == 429
finally:
app.dependency_overrides = {}
def test_logout_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
ok = client.request(
"DELETE",
"/api/v1/auth/sessions",
json={"refresh_token": "refresh"},
)
assert ok.status_code == 204
blocked = client.request(
"DELETE",
"/api/v1/auth/sessions",
json={"refresh_token": "refresh"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_logout_rate_limit_not_bypassed_by_changing_refresh_token() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for index in range(10):
ok = client.request(
"DELETE",
"/api/v1/auth/sessions",
json={"refresh_token": f"refresh-{index}"},
)
assert ok.status_code == 204
blocked = client.request(
"DELETE",
"/api/v1/auth/sessions",
json={"refresh_token": "refresh-extra"},
)
assert blocked.status_code == 429
finally:
app.dependency_overrides = {}
def test_signup_start_validation_error_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post("/api/v1/auth/verifications", json={})
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Entity"
assert body["status"] == 422
assert body["detail"] == "Invalid request"
finally:
app.dependency_overrides = {}
def test_signup_start_missing_username_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications",
json={"email": "user@example.com", "password": "secret123"},
)
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unprocessable Entity"
assert body["status"] == 422
assert body["detail"] == "Invalid request"
finally:
app.dependency_overrides = {}
def test_password_reset_request_returns_204() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/resend",
json={"email": "user@example.com"},
)
assert response.status_code == 204
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_returns_204() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verify",
json={
"type": "recovery",
"email": "user@example.com",
"token": "123456",
"new_password": "newpassword123",
},
)
assert response.status_code == 204
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_invalid_token_returns_401() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verify",
json={
"type": "recovery",
"email": "user@example.com",
"token": "000000",
"new_password": "newpassword123",
},
)
assert response.status_code == 401
assert response.headers["content-type"].startswith("application/problem+json")
body = response.json()
assert body["title"] == "Unauthorized"
assert body["status"] == 401
finally:
app.dependency_overrides = {}
def test_password_reset_confirm_weak_password_returns_422() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verify",
json={
"type": "recovery",
"email": "user@example.com",
"token": "123456",
"new_password": "123",
},
)
assert response.status_code == 422
assert response.headers["content-type"].startswith("application/problem+json")
finally:
app.dependency_overrides = {}
class TestInviteCodeSignup:
def test_signup_with_valid_invite_code_returns_202(self) -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications",
json={
"username": "demo",
"email": "user@example.com",
"password": "secret123",
"invite_code": "A2B3",
},
)
assert response.status_code == 202
assert response.json() == {"email": "user@example.com"}
finally:
app.dependency_overrides = {}
def test_signup_with_invalid_invite_code_length_returns_202(self) -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications",
json={
"username": "demo",
"email": "user@example.com",
"password": "secret123",
"invite_code": "ABC123",
},
)
assert response.status_code == 202
assert response.json() == {"email": "user@example.com"}
finally:
app.dependency_overrides = {}
def test_signup_with_invalid_invite_code_chars_returns_202(self) -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
response = client.post(
"/api/v1/auth/verifications",
json={
"username": "demo",
"email": "user@example.com",
"password": "secret123",
"invite_code": "ABCD1234",
},
)
assert response.status_code == 202
assert response.json() == {"email": "user@example.com"}
finally:
app.dependency_overrides = {}
@@ -9,12 +9,12 @@ from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from schemas.user.context import UserContext
from v1.friendships.dependencies import get_friendship_service
from v1.friendships.schemas import (
FriendRequestCreate,
FriendRequestResponse,
FriendResponse,
UserBasicInfo,
)
from v1.friendships.service import FriendshipService
from v1.users.dependencies import get_current_user
@@ -31,9 +31,9 @@ class FakeFriendshipService(FriendshipService):
async def send_request(self, request: FriendRequestCreate) -> FriendRequestResponse:
return FriendRequestResponse(
id=UUID("11111111-1111-1111-1111-111111111111"),
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
content=request.content,
sender=UserContext(id="user-1", username="sender", avatar_url=None),
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
content={"text": request.content} if request.content else None,
status="pending",
created_at=datetime.now(timezone.utc),
)
@@ -41,9 +41,9 @@ class FakeFriendshipService(FriendshipService):
async def accept_request(self, friendship_id: UUID) -> FriendRequestResponse:
return FriendRequestResponse(
id=friendship_id,
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
content="Hello!",
sender=UserContext(id="user-1", username="sender", avatar_url=None),
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
content={"text": "Hello!"},
status="accepted",
created_at=datetime.now(timezone.utc),
)
@@ -51,9 +51,9 @@ class FakeFriendshipService(FriendshipService):
async def decline_request(self, friendship_id: UUID) -> FriendRequestResponse:
return FriendRequestResponse(
id=friendship_id,
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
content="Hello!",
sender=UserContext(id="user-1", username="sender", avatar_url=None),
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
content={"text": "Hello!"},
status="rejected",
created_at=datetime.now(timezone.utc),
)
@@ -61,9 +61,9 @@ class FakeFriendshipService(FriendshipService):
async def cancel_request(self, friendship_id: UUID) -> FriendRequestResponse:
return FriendRequestResponse(
id=friendship_id,
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
recipient=UserBasicInfo(id="user-2", username="recipient", avatar_url=None),
content="Hello!",
sender=UserContext(id="user-1", username="sender", avatar_url=None),
recipient=UserContext(id="user-2", username="recipient", avatar_url=None),
content={"text": "Hello!"},
status="canceled",
created_at=datetime.now(timezone.utc),
)
@@ -72,11 +72,11 @@ class FakeFriendshipService(FriendshipService):
return [
FriendRequestResponse(
id=UUID("11111111-1111-1111-1111-111111111111"),
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
recipient=UserBasicInfo(
sender=UserContext(id="user-1", username="sender", avatar_url=None),
recipient=UserContext(
id="user-2", username="recipient", avatar_url=None
),
content="Hello!",
content={"text": "Hello!"},
status="pending",
created_at=datetime.now(timezone.utc),
)
@@ -86,10 +86,8 @@ class FakeFriendshipService(FriendshipService):
return [
FriendRequestResponse(
id=UUID("22222222-2222-2222-2222-222222222222"),
sender=UserBasicInfo(id="user-1", username="sender", avatar_url=None),
recipient=UserBasicInfo(
id="user-3", username="target", avatar_url=None
),
sender=UserContext(id="user-1", username="sender", avatar_url=None),
recipient=UserContext(id="user-3", username="target", avatar_url=None),
content=None,
status="pending",
created_at=datetime.now(timezone.utc),
@@ -100,7 +98,7 @@ class FakeFriendshipService(FriendshipService):
return [
FriendResponse(
id=UUID("33333333-3333-3333-3333-333333333333"),
friend=UserBasicInfo(id="user-2", username="friend", avatar_url=None),
friend=UserContext(id="user-2", username="friend", avatar_url=None),
status="active",
created_at=datetime.now(timezone.utc),
accepted_at=datetime.now(timezone.utc),
@@ -110,7 +108,7 @@ class FakeFriendshipService(FriendshipService):
async def remove_friend(self, friend_id: UUID) -> FriendResponse:
return FriendResponse(
id=UUID("33333333-3333-3333-3333-333333333333"),
friend=UserBasicInfo(id=str(friend_id), username="friend", avatar_url=None),
friend=UserContext(id=str(friend_id), username="friend", avatar_url=None),
status="active",
created_at=datetime.now(timezone.utc),
accepted_at=datetime.now(timezone.utc),
@@ -129,7 +127,7 @@ def _override_friendship_service(
def _get_fake_current_user() -> CurrentUser:
return CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="test@example.com",
phone="+8613812345678",
)
@@ -52,7 +52,7 @@ def test_share_schedule_item_returns_200() -> None:
response = client.post(
f"/api/v1/schedule-items/{item_id}/share",
json={
"email": "friend@example.com",
"phone": "+8613810000000",
"permission_view": True,
"permission_edit": False,
"permission_invite": True,
@@ -62,7 +62,7 @@ def test_share_schedule_item_returns_200() -> None:
body = response.json()
assert body["message"] == "Calendar invitation sent"
assert service.last_share_request is not None
assert service.last_share_request.email == "friend@example.com"
assert service.last_share_request.phone == "+8613810000000"
assert service.last_share_request.permission_invite is True
finally:
app.dependency_overrides = {}
+14 -13
View File
@@ -8,30 +8,31 @@ from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from schemas.user.context import UserContext
from v1.users.dependencies import get_current_user, get_user_service
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
from v1.users.service import UserService
class FakeUserService:
"""Fake service for integration testing."""
def __init__(self, user: UserResponse) -> None:
def __init__(self, user: UserContext) -> None:
self._user = user
self._search_results: list[UserResponse] = []
self._search_results: list[UserContext] = []
def set_search_results(self, results: list[UserResponse]) -> None:
def set_search_results(self, results: list[UserContext]) -> None:
self._search_results = results
async def get_me(self) -> UserResponse:
async def get_me(self) -> UserContext:
if self._user.id is None:
raise HTTPException(status_code=404, detail="User not found")
return self._user
async def update_me(self, update: UserUpdateRequest) -> UserResponse:
async def update_me(self, update: UserUpdateRequest) -> UserContext:
if self._user.id is None:
raise HTTPException(status_code=404, detail="User not found")
return UserResponse(
return UserContext(
id=self._user.id,
username=(
update.username if update.username is not None else self._user.username
@@ -44,7 +45,7 @@ class FakeUserService:
bio=update.bio if update.bio is not None else self._user.bio,
)
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
if request.query:
return self._search_results if self._search_results else [self._user]
return []
@@ -68,7 +69,7 @@ def _override_current_user(user_id: UUID) -> Callable[[], CurrentUser]:
def test_get_me_returns_user() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
user = UserContext(
id=str(user_id),
username="demo",
avatar_url=None,
@@ -91,7 +92,7 @@ def test_get_me_returns_user() -> None:
def test_patch_me_updates_user() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
user = UserContext(
id=str(user_id),
username="demo",
avatar_url=None,
@@ -117,7 +118,7 @@ def test_patch_me_updates_user() -> None:
def test_patch_me_validation_error_returns_problem_details() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
user = UserContext(
id=str(user_id),
username="demo",
avatar_url=None,
@@ -142,7 +143,7 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
def test_search_users_returns_list() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
user = UserContext(
id=str(user_id),
username="demo",
avatar_url=None,
@@ -167,7 +168,7 @@ def test_search_users_returns_list() -> None:
def test_search_users_empty_query_returns_422() -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
user = UserResponse(
user = UserContext(
id=str(user_id),
username="demo",
avatar_url=None,
@@ -115,6 +115,34 @@ class _FailingStreamAgentService(_FakeAgentService):
raise RuntimeError("redis timeout")
class _TerminalStreamAgentService(_FakeAgentService):
def __init__(self) -> None:
super().__init__()
self.stream_calls = 0
async def stream_events(
self,
*,
thread_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
del thread_id, last_event_id, current_user
self.stream_calls += 1
if self.stream_calls == 1:
return [
{
"id": "9-0",
"event": {
"type": "RUN_FINISHED",
"threadId": "00000000-0000-0000-0000-000000000001",
"runId": "run-1",
},
}
]
return []
def test_run_requires_auth_and_returns_202_task_id() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
client = TestClient(app)
@@ -129,13 +157,13 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
},
)
assert unauthorized.status_code == 401
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
authorized = client.post(
"/api/v1/agent/runs",
@@ -146,7 +174,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
},
)
assert authorized.status_code == 202
@@ -161,7 +189,7 @@ def test_run_requires_auth_and_returns_202_task_id() -> None:
def test_stream_reads_from_last_event_id() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
original_acquire = agent_router._acquire_sse_slot
@@ -197,7 +225,7 @@ def test_stream_reads_from_last_event_id() -> None:
def test_stream_handles_stream_backend_errors_without_connection_crash() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FailingStreamAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
original_acquire = agent_router._acquire_sse_slot
@@ -226,10 +254,45 @@ def test_stream_handles_stream_backend_errors_without_connection_crash() -> None
app.dependency_overrides = {}
def test_stream_stops_after_terminal_run_event() -> None:
service = _TerminalStreamAgentService()
app.dependency_overrides[get_agent_service] = lambda: service
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
original_acquire = agent_router._acquire_sse_slot
original_release = agent_router._release_sse_slot
async def _allow_slot(*, user_id: str) -> bool:
del user_id
return True
async def _noop_release(*, user_id: str) -> None:
del user_id
return None
agent_router._acquire_sse_slot = _allow_slot # type: ignore[assignment]
agent_router._release_sse_slot = _noop_release # type: ignore[assignment]
try:
response = client.get(
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/events?idle_limit=3"
)
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
assert "event: RUN_FINISHED" in response.text
assert service.stream_calls == 1
finally:
agent_router._acquire_sse_slot = original_acquire # type: ignore[assignment]
agent_router._release_sse_slot = original_release # type: ignore[assignment]
app.dependency_overrides = {}
def test_stream_rejects_invalid_last_event_id() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
@@ -255,7 +318,7 @@ def test_history_returns_state_snapshot() -> None:
assert unauthorized.status_code == 401
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
authorized = client.get(
"/api/v1/agent/history",
@@ -276,7 +339,7 @@ def test_history_returns_state_snapshot() -> None:
def test_user_history_returns_latest_snapshot() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
try:
@@ -292,7 +355,7 @@ def test_user_history_returns_latest_snapshot() -> None:
def test_run_rejects_oversized_user_text_payload() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
@@ -312,7 +375,7 @@ def test_run_rejects_oversized_user_text_payload() -> None:
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
},
)
assert response.status_code == 422
@@ -323,7 +386,7 @@ def test_run_rejects_oversized_user_text_payload() -> None:
def test_run_rejects_client_supplied_history_messages() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
@@ -340,7 +403,7 @@ def test_run_rejects_client_supplied_history_messages() -> None:
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
},
)
assert response.status_code == 422
@@ -351,7 +414,7 @@ def test_run_rejects_client_supplied_history_messages() -> None:
def test_upload_attachment_returns_reference() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
@@ -376,7 +439,7 @@ def test_upload_attachment_returns_reference() -> None:
def test_create_attachment_signed_url_returns_url() -> None:
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
@@ -399,7 +462,7 @@ def test_create_attachment_signed_url_returns_url() -> None:
def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
async def mock_transcribe_file(file_path: str, filename: str) -> str:
@@ -434,7 +497,7 @@ def test_asr_transcribe_returns_sync_transcript(monkeypatch) -> None:
def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
monkeypatch.setattr(agent_router, "_MAX_TRANSCRIBE_AUDIO_BYTES", 4)
@@ -457,7 +520,7 @@ def test_asr_transcribe_rejects_oversized_audio(monkeypatch) -> None:
def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
@@ -478,7 +541,7 @@ def test_asr_transcribe_rejects_non_wav_audio(monkeypatch) -> None:
def test_asr_transcribe_rejects_invalid_wav_payload(monkeypatch) -> None:
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
id=uuid4(), email="user@example.com"
id=uuid4(), phone="+8613812345678"
)
client = TestClient(app)
@@ -20,16 +20,16 @@ FIXTURE_IMAGE_PATH = (
async def _live_access_token(client: httpx.AsyncClient) -> str:
email = os.getenv("AGENT_LIVE_EMAIL")
phone = os.getenv("AGENT_LIVE_PHONE")
password = os.getenv("AGENT_LIVE_PASSWORD")
if not email or not password:
if not phone or not password:
pytest.fail(
"AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_EMAIL and AGENT_LIVE_PASSWORD"
"AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_PHONE and AGENT_LIVE_PASSWORD"
)
response = await client.post(
f"{BASE_URL}/api/v1/auth/sessions",
json={"email": email, "password": password},
json={"phone": phone, "password": password},
)
response_text = response.text.strip().replace("\n", " ")
truncated_text = response_text[:200]
@@ -67,7 +67,7 @@ async def test_agent_sse_closed_loop_live() -> None:
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
},
)
assert run_resp.status_code == 202
@@ -143,7 +143,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": "worker"},
},
)
assert run_resp.status_code == 202
@@ -10,7 +10,7 @@ from v1.agent.service import ensure_session_owner
def test_owner_guard_denies_non_owner() -> None:
user = CurrentUser(id=uuid4(), email="self@example.com")
user = CurrentUser(id=uuid4(), phone="self@example.com")
with pytest.raises(HTTPException):
ensure_session_owner(owner_id="other-user", current_user=user)
@@ -7,6 +7,8 @@ from uuid import uuid4
import pytest
from models.agent_chat_message import AgentChatMessageRole
from sqlalchemy import select
from models.agent_chat_message import AgentChatMessage
from v1.agent.repository import AgentRepository
@@ -79,6 +81,7 @@ async def test_persist_user_message_sets_session_title_when_empty() -> None:
session_id=session_id,
content=" 请帮我安排明天下午开会 ",
metadata=None,
visibility_mask=1,
)
assert session_row.title == "请帮我安排明天下午开会"
@@ -101,6 +104,7 @@ async def test_persist_user_message_keeps_existing_session_title() -> None:
session_id=session_id,
content="新的消息内容",
metadata=None,
visibility_mask=1,
)
assert session_row.title == "已有标题"
@@ -164,3 +168,13 @@ async def test_get_history_day_uses_target_day_queries_only() -> None:
messages = payload["messages"]
assert isinstance(messages, list)
assert len(messages) == 1
def test_apply_visibility_filter_adds_bitwise_expression() -> None:
repository = AgentRepository(session=SimpleNamespace()) # type: ignore[arg-type]
stmt = select(AgentChatMessage)
filtered = repository._apply_visibility_filter(stmt=stmt, visibility_mask=1)
assert "visibility_mask" in str(filtered)
assert "&" in str(filtered)
+106 -7
View File
@@ -20,6 +20,7 @@ class _FakeRepository:
def __init__(self) -> None:
self.committed = False
self.persisted_user_messages: list[dict[str, object]] = []
self.created_session_calls = 0
async def get_session_owner(self, *, session_id: str) -> str:
if session_id == "00000000-0000-0000-0000-000000000001":
@@ -30,6 +31,7 @@ class _FakeRepository:
self, *, user_id: str, session_id: str | None = None
) -> str:
del user_id
self.created_session_calls += 1
return session_id or "00000000-0000-0000-0000-000000000999"
async def commit(self) -> None:
@@ -39,9 +41,13 @@ class _FakeRepository:
return None
async def get_history_day(
self, *, session_id: str, before: date | None
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None:
del session_id, before
del session_id, before, visibility_mask
return None
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
@@ -54,15 +60,42 @@ class _FakeRepository:
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None:
self.persisted_user_messages.append(
{
"session_id": session_id,
"content": content,
"metadata": metadata,
"visibility_mask": visibility_mask,
}
)
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None:
normalized = agent_type.strip().lower()
mapping = {
"router": 16,
"worker": 17,
"memory": 18,
}
bit = mapping.get(normalized)
if bit is None:
return None
return {
"agent_type": normalized,
"status": "active",
"config": {
"temperature": 0.7,
"max_tokens": None,
"timeout_seconds": 30,
"visibility_consumer_bit": bit,
"context_messages": {"mode": "number", "count": 20},
"enabled_tools": [],
},
}
class _FakeQueue:
def __init__(self) -> None:
@@ -122,11 +155,11 @@ class _FakeAttachmentStorage:
def _user() -> CurrentUser:
return CurrentUser(
id=UUID("00000000-0000-0000-0000-000000000001"),
email="user@example.com",
phone="+8613812345678",
)
def _build_run_input(*, urls: list[str]) -> RunAgentInput:
def _build_run_input(*, urls: list[str], agent_type: str = "worker") -> RunAgentInput:
content: list[dict[str, str]] = [{"type": "text", "text": "hello"}]
for url in urls:
content.append({"type": "binary", "mimeType": "image/png", "url": url})
@@ -144,7 +177,7 @@ def _build_run_input(*, urls: list[str]) -> RunAgentInput:
],
"tools": [],
"context": [],
"forwardedProps": {},
"forwardedProps": {"agent_type": agent_type},
}
)
@@ -222,6 +255,68 @@ async def test_enqueue_run_persists_attachment_and_queue_without_user_token(
assert run_input["runId"] == "run-1"
@pytest.mark.asyncio
async def test_enqueue_run_rejects_unknown_agent_type(monkeypatch) -> None:
monkeypatch.setattr(
agent_service_module.config.storage, "bucket", "agent-test-bucket"
)
service = AgentService(
repository=_FakeRepository(),
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=_FakeAttachmentStorage(),
)
base_url = str(config.supabase.url).rstrip("/")
safe_path = quote(
"agent-inputs/00000000-0000-0000-0000-000000000001/"
"00000000-0000-0000-0000-000000000001/uploads/a.png"
)
run_input = _build_run_input(
urls=[
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
],
agent_type="planner",
)
with pytest.raises(HTTPException) as exc_info:
await service.enqueue_run(run_input=run_input, current_user=_user())
assert exc_info.value.status_code == 422
@pytest.mark.asyncio
async def test_enqueue_run_rejects_memory_mode_for_api(monkeypatch) -> None:
monkeypatch.setattr(
agent_service_module.config.storage, "bucket", "agent-test-bucket"
)
repository = _FakeRepository()
service = AgentService(
repository=repository,
queue=_FakeQueue(),
stream=_FakeStream(),
attachment_storage=_FakeAttachmentStorage(),
)
base_url = str(config.supabase.url).rstrip("/")
safe_path = quote(
"agent-inputs/00000000-0000-0000-0000-000000000001/"
"00000000-0000-0000-0000-000000000001/uploads/a.png"
)
run_input = _build_run_input(
urls=[
f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
],
agent_type="memory",
)
with pytest.raises(HTTPException) as exc_info:
await service.enqueue_run(run_input=run_input, current_user=_user())
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "memory mode is automation-only"
assert repository.created_session_calls == 0
assert repository.persisted_user_messages == []
@pytest.mark.asyncio
async def test_create_attachment_signed_url_returns_url(monkeypatch) -> None:
monkeypatch.setattr(
@@ -317,9 +412,13 @@ async def test_enqueue_run_rejects_too_many_attachments(monkeypatch) -> None:
async def test_get_history_snapshot_filters_out_tool_messages() -> None:
class _HistoryRepository(_FakeRepository):
async def get_history_day(
self, *, session_id: str, before: date | None
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None:
del session_id, before
del session_id, before, visibility_mask
return {
"day": "2026-03-17",
"hasMore": False,
+90 -295
View File
@@ -8,13 +8,9 @@ from fastapi import HTTPException
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
VerificationCreateRequest,
VerificationVerifyRequest,
VerificationResendRequest,
)
@@ -35,314 +31,83 @@ class TestSupabaseAuthGateway:
return SupabaseAuthGateway(), mock_client, mock_admin_client
@pytest.mark.asyncio
async def test_request_password_reset_calls_email_with_string(
async def test_send_otp_sets_should_create_user(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
mock_client.auth.reset_password_email = mock_reset_email
mock_sign_in_with_otp = MagicMock()
mock_client.auth.sign_in_with_otp = mock_sign_in_with_otp
request = PasswordResetRequest(email="test@example.com")
await sut.request_password_reset(request)
await sut.send_otp(OtpSendRequest(phone="+8613812345678"))
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_create_verification_maps_timeout_error_to_503(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_client.auth.sign_up = MagicMock(
side_effect=AuthError("request_timeout", None)
)
with pytest.raises(HTTPException) as exc_info:
await sut.create_verification(
VerificationCreateRequest(
username="tester",
email="test@example.com",
password="secret123",
)
)
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "Auth service temporarily unavailable"
@pytest.mark.asyncio
async def test_request_password_reset_with_redirect(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with(
"test@example.com",
options={"redirect_to": "http://localhost:3000/reset-password"},
)
@pytest.mark.asyncio
async def test_create_verification_rejects_untrusted_redirect_url(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
with pytest.raises(HTTPException) as exc_info:
await sut.create_verification(
VerificationCreateRequest(
username="tester",
email="test@example.com",
password="secret123",
redirect_to="https://evil.example.com/callback",
)
)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid redirect URL"
@pytest.mark.asyncio
async def test_request_password_reset_rejects_untrusted_redirect_url(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
with pytest.raises(HTTPException) as exc_info:
await sut.request_password_reset(
PasswordResetRequest(
email="test@example.com",
redirect_to="https://evil.example.com/reset",
)
)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid redirect URL"
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
result = await sut.request_password_reset(request)
mock_reset_email.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_request_password_reset_extracts_email_from_mapping(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest.model_construct(
email={"email": "test@example.com"},
redirect_to=None,
)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_rejects_invalid_email_shape(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
request = PasswordResetRequest.model_construct(
email={"unexpected": "value"},
redirect_to=None,
)
with pytest.raises(HTTPException) as exc_info:
await sut.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid email"
@pytest.mark.asyncio
async def test_confirm_password_reset_updates_password_by_user_id(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, mock_admin_client = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id="user-1"),
)
mock_verify_otp = MagicMock(return_value=verify_response)
mock_client.auth.verify_otp = mock_verify_otp
mock_update_user_by_id = MagicMock()
mock_admin_client.auth.admin = SimpleNamespace(
update_user_by_id=mock_update_user_by_id
)
request = PasswordResetConfirmRequest(
email="test@example.com",
token="123456",
new_password="newpassword123",
)
await sut.confirm_password_reset(request)
mock_verify_otp.assert_called_once_with(
mock_sign_in_with_otp.assert_called_once_with(
{
"type": "recovery",
"email": "test@example.com",
"token": "123456",
"phone": "+8613812345678",
"options": {"should_create_user": True},
}
)
mock_update_user_by_id.assert_called_once_with(
"user-1",
{"password": "newpassword123"},
)
@pytest.mark.asyncio
async def test_confirm_password_reset_raises_when_user_id_missing(
async def test_create_phone_session_uses_verify_otp(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id=""),
session=SimpleNamespace(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
),
user=SimpleNamespace(id="user-1", phone="+8613812345678"),
)
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
request = PasswordResetConfirmRequest(
email="test@example.com",
token="123456",
new_password="newpassword123",
response = await sut.create_phone_session(
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
)
assert response.user.id == "user-1"
assert response.access_token == "access"
@pytest.mark.asyncio
async def test_create_phone_session_normalizes_phone_without_plus_prefix(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
),
user=SimpleNamespace(id="user-1", phone="14155552671"),
)
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
response = await sut.create_phone_session(
PhoneSessionCreateRequest(phone="+14155552671", token="123456")
)
assert response.user.phone == "+14155552671"
@pytest.mark.asyncio
async def test_refresh_session_maps_invalid_token(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_client.auth.refresh_session = MagicMock(
return_value=SimpleNamespace(session=None, user=None)
)
with pytest.raises(HTTPException) as exc_info:
await sut.confirm_password_reset(request)
await sut.refresh_session(SessionRefreshRequest(refresh_token="bad"))
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"
@pytest.mark.asyncio
async def test_recovery_resend_calls_reset_password_email(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
mock_client.auth.reset_password_email = mock_reset_email
await sut.resend_verification(
VerificationResendRequest(
type="recovery",
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
)
mock_reset_email.assert_called_once_with(
"test@example.com",
options={"redirect_to": "http://localhost:3000/reset-password"},
)
@pytest.mark.asyncio
async def test_verify_verification_maps_internal_error_to_503(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_client.auth.verify_otp = MagicMock(
side_effect=AuthError("internal_server_error", None)
)
with pytest.raises(HTTPException) as exc_info:
await sut.verify_verification(
VerificationVerifyRequest(
type="signup",
email="test@example.com",
token="123456",
)
)
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "Auth service temporarily unavailable"
@pytest.mark.asyncio
async def test_create_session_maps_internal_error_to_503(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_client.auth.sign_in_with_password = MagicMock(
side_effect=AuthError("internal_server_error", None)
)
with pytest.raises(HTTPException) as exc_info:
await sut.create_session(
SessionCreateRequest(
email="test@example.com",
password="secret123",
)
)
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "Auth service temporarily unavailable"
@pytest.mark.asyncio
async def test_refresh_session_maps_bad_gateway_to_503(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_client.auth.refresh_session = MagicMock(
side_effect=AuthError("bad_gateway", None)
)
with pytest.raises(HTTPException) as exc_info:
await sut.refresh_session(SessionRefreshRequest(refresh_token="rt"))
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "Auth service temporarily unavailable"
@pytest.mark.asyncio
async def test_confirm_password_reset_maps_service_unavailable_to_503(
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_client.auth.verify_otp = MagicMock(
side_effect=AuthError("service_unavailable", None)
)
with pytest.raises(HTTPException) as exc_info:
await sut.confirm_password_reset(
PasswordResetConfirmRequest(
email="test@example.com",
token="123456",
new_password="newpassword123",
)
)
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "Auth service temporarily unavailable"
@pytest.mark.asyncio
async def test_get_user_by_email_uses_in_memory_cache(
async def test_get_user_by_phone_uses_in_memory_cache(
self,
gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock],
monkeypatch: pytest.MonkeyPatch,
@@ -350,9 +115,9 @@ class TestSupabaseAuthGateway:
sut, _, _ = gateway
user = SimpleNamespace(
id="user-1",
email="cached@example.com",
phone="+8613811112222",
created_at="2026-03-16T00:00:00Z",
email_confirmed_at=None,
phone_confirmed_at=None,
)
list_calls = {"count": 0}
@@ -362,9 +127,39 @@ class TestSupabaseAuthGateway:
monkeypatch.setattr("v1.auth.gateway._list_auth_users", _fake_list_auth_users)
first = await sut.get_user_by_email("cached@example.com")
second = await sut.get_user_by_email("CACHED@example.com")
first = await sut.get_user_by_phone("+8613811112222")
second = await sut.get_user_by_phone("+8613811112222")
assert first.id == "user-1"
assert second.email == "cached@example.com"
assert second.phone == "+8613811112222"
assert list_calls["count"] == 1
@pytest.mark.asyncio
async def test_search_user_ids_by_phone_supports_suffix_query(
self,
gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock],
monkeypatch: pytest.MonkeyPatch,
) -> None:
sut, _, _ = gateway
users = [
SimpleNamespace(
id="user-cn",
phone="+8613811112222",
created_at="2026-03-16T00:00:00Z",
phone_confirmed_at=None,
),
SimpleNamespace(
id="user-us",
phone="+14155552671",
created_at="2026-03-16T00:00:00Z",
phone_confirmed_at=None,
),
]
monkeypatch.setattr("v1.auth.gateway._list_auth_users", lambda _client: users)
matched_cn = await sut.search_user_ids_by_phone("13811112222")
matched_us = await sut.search_user_ids_by_phone("4155552671")
assert matched_cn == ["user-cn"]
assert matched_us == ["user-us"]
+20 -59
View File
@@ -5,72 +5,28 @@ from pydantic import ValidationError
from v1.auth.schemas import (
AuthUser,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
SessionResponse,
VerificationCreateRequest,
VerificationVerifyRequest,
VerificationResendRequest,
)
def test_signup_requires_valid_email() -> None:
def test_send_otp_requires_valid_phone() -> None:
with pytest.raises(ValidationError):
VerificationCreateRequest(
username="demo", email="not-an-email", password="secret123"
)
OtpSendRequest(phone="13812345678")
def test_signup_requires_username() -> None:
def test_send_otp_accepts_e164_phone() -> None:
request = OtpSendRequest(phone="+14155552671")
assert request.phone == "+14155552671"
def test_phone_session_requires_six_digit_token() -> None:
with pytest.raises(ValidationError):
VerificationCreateRequest.model_validate(
{"email": "user@example.com", "password": "secret123"}
)
def test_signup_allows_any_invite_code_input() -> None:
request = VerificationCreateRequest(
username="demo",
email="user@example.com",
password="secret123",
invite_code="abc123",
)
assert request.invite_code == "abc123"
def test_signup_verify_requires_six_digit_token() -> None:
with pytest.raises(ValidationError):
VerificationVerifyRequest(email="user@example.com", token="abc123")
def test_signup_verify_disallows_new_password() -> None:
with pytest.raises(ValidationError):
VerificationVerifyRequest(
type="signup",
email="user@example.com",
token="123456",
new_password="secret123",
)
def test_recovery_verify_requires_new_password() -> None:
with pytest.raises(ValidationError):
VerificationVerifyRequest(
type="recovery",
email="user@example.com",
token="123456",
)
def test_signup_resend_requires_valid_email() -> None:
with pytest.raises(ValidationError):
VerificationResendRequest(email="invalid")
def test_login_requires_valid_email() -> None:
with pytest.raises(ValidationError):
SessionCreateRequest(email="invalid", password="secret123")
PhoneSessionCreateRequest(phone="+8613812345678", token="abc123")
def test_refresh_requires_token() -> None:
@@ -78,8 +34,13 @@ def test_refresh_requires_token() -> None:
SessionRefreshRequest(refresh_token="")
def test_logout_requires_token() -> None:
with pytest.raises(ValidationError):
SessionDeleteRequest(refresh_token="")
def test_session_response_maps_user() -> None:
user = AuthUser(id="user-1", email="user@example.com")
user = AuthUser(id="user-1", phone="+14155552671")
response = SessionResponse(
access_token="access",
refresh_token="refresh",
@@ -89,4 +50,4 @@ def test_session_response_maps_user() -> None:
)
assert response.user.id == "user-1"
assert response.user.email == "user@example.com"
assert response.user.phone == "+14155552671"
+33 -160
View File
@@ -2,19 +2,12 @@ from __future__ import annotations
import pytest
import v1.auth.gateway as auth_gateway_module
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
from v1.auth.service import AuthService, AuthServiceGateway
@@ -22,23 +15,16 @@ from v1.auth.service import AuthService, AuthServiceGateway
class FakeGateway(AuthServiceGateway):
def __init__(self, response: SessionResponse) -> None:
self._response = response
self.last_create_verification_request: VerificationCreateRequest | None = None
self.last_send_otp_request: OtpSendRequest | None = None
self.last_phone_session_request: PhoneSessionCreateRequest | None = None
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
self.last_create_verification_request = request
return VerificationCreateResponse(email=request.email)
async def send_otp(self, request: OtpSendRequest) -> None:
self.last_send_otp_request = request
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
return self._response
async def resend_verification(self, request: VerificationResendRequest) -> None:
return None
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
self.last_phone_session_request = request
return self._response
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
@@ -47,85 +33,10 @@ class FakeGateway(AuthServiceGateway):
async def delete_session(self, refresh_token: str | None) -> None:
return None
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
class LogoutAssertingGateway(AuthServiceGateway):
def __init__(self, expected_refresh_token: str) -> None:
self._expected_refresh_token = expected_refresh_token
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
raise NotImplementedError
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
raise NotImplementedError
async def resend_verification(self, request: VerificationResendRequest) -> None:
raise NotImplementedError
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
raise NotImplementedError
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
raise NotImplementedError
async def delete_session(self, refresh_token: str | None) -> None:
assert refresh_token == self._expected_refresh_token
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
@pytest.mark.asyncio
async def test_logout_forwards_refresh_token() -> None:
service = AuthService(gateway=LogoutAssertingGateway("refresh-token"))
await service.delete_session("refresh-token")
@pytest.mark.asyncio
async def test_signup_resend_returns_none() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
service = AuthService(gateway=FakeGateway(token_response))
result = await service.resend_verification(
VerificationResendRequest(email="user@example.com")
)
assert result is None
@pytest.mark.asyncio
async def test_create_verification_ignores_invalid_invite_code() -> None:
user = AuthUser(id="user-1", email="user@example.com")
async def test_send_otp_forwards_payload() -> None:
user = AuthUser(id="user-1", phone="+8613812345678")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
@@ -136,22 +47,15 @@ async def test_create_verification_ignores_invalid_invite_code() -> None:
gateway = FakeGateway(token_response)
service = AuthService(gateway=gateway)
await service.create_verification(
VerificationCreateRequest(
username="demo",
email="user@example.com",
password="secret123",
invite_code="bad-code",
)
)
await service.send_otp(OtpSendRequest(phone="+8613812345678"))
assert gateway.last_create_verification_request is not None
assert gateway.last_create_verification_request.invite_code is None
assert gateway.last_send_otp_request is not None
assert gateway.last_send_otp_request.phone == "+8613812345678"
@pytest.mark.asyncio
async def test_create_verification_normalizes_valid_invite_code() -> None:
user = AuthUser(id="user-1", email="user@example.com")
async def test_create_phone_session_forwards_payload() -> None:
user = AuthUser(id="user-1", phone="+8613812345678")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
@@ -162,59 +66,28 @@ async def test_create_verification_normalizes_valid_invite_code() -> None:
gateway = FakeGateway(token_response)
service = AuthService(gateway=gateway)
await service.create_verification(
VerificationCreateRequest(
username="demo",
email="user@example.com",
password="secret123",
invite_code="a2b3",
)
response = await service.create_phone_session(
PhoneSessionCreateRequest(phone="+8613812345678", token="123456")
)
assert gateway.last_create_verification_request is not None
assert gateway.last_create_verification_request.invite_code == "A2B3"
assert gateway.last_phone_session_request is not None
assert gateway.last_phone_session_request.token == "123456"
assert response.user.phone == "+8613812345678"
@pytest.mark.asyncio
async def test_supabase_signup_passes_username_in_metadata(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured_payload: dict[str, object] = {}
class FakeSupabaseAuth:
def sign_up(self, payload: dict[str, object]) -> object:
captured_payload.update(payload)
class _User:
id = "user-1"
email = "user@example.com"
class _Session:
access_token = "access"
refresh_token = "refresh"
expires_in = 3600
token_type = "bearer"
class _Response:
user = _User()
session = None
return _Response()
class FakeClient:
auth = FakeSupabaseAuth()
monkeypatch.setattr(
auth_gateway_module.supabase_service, "get_client", lambda: FakeClient()
async def test_refresh_session_forwards_payload() -> None:
user = AuthUser(id="user-1", phone="+8613812345678")
token_response = SessionResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
gateway = FakeGateway(token_response)
service = AuthService(gateway=gateway)
gateway = auth_gateway_module.SupabaseAuthGateway()
await gateway.create_verification(
VerificationCreateRequest(
username="demo",
email="user@example.com",
password="secret123",
)
)
response = await service.refresh_session(SessionRefreshRequest(refresh_token="rt"))
assert captured_payload["data"] == {"username": "demo"}
assert response.access_token == "access"
@@ -1,11 +1,13 @@
from __future__ import annotations
import pytest
from datetime import datetime
from uuid import uuid4
import pytest
from pydantic import ValidationError
from schemas.user.context import UserContext
from v1.friendships.schemas import (
UserBasicInfo,
FriendRequestCreate,
FriendRequestResponse,
FriendResponse,
@@ -13,16 +15,16 @@ from v1.friendships.schemas import (
)
def test_user_basic_info_maps_fields() -> None:
user = UserBasicInfo(id="user-1", username="alice", avatar_url=None)
def test_user_context_maps_fields() -> None:
user = UserContext(id="user-1", username="alice", avatar_url=None)
assert user.id == "user-1"
assert user.username == "alice"
assert user.avatar_url is None
def test_user_basic_info_with_avatar() -> None:
user = UserBasicInfo(
def test_user_context_with_avatar() -> None:
user = UserContext(
id="user-2", username="bob", avatar_url="https://example.com/avatar.png"
)
@@ -49,13 +51,13 @@ def test_friend_request_create_without_content() -> None:
def test_friend_request_create_content_max_length() -> None:
target_id = uuid4()
with pytest.raises(Exception):
with pytest.raises(ValidationError):
FriendRequestCreate(target_user_id=target_id, content="x" * 201)
def test_friend_request_response_maps_fields() -> None:
sender = UserBasicInfo(id="user-1", username="alice", avatar_url=None)
recipient = UserBasicInfo(id="user-2", username="bob", avatar_url=None)
sender = UserContext(id="user-1", username="alice", avatar_url=None)
recipient = UserContext(id="user-2", username="bob", avatar_url=None)
request_id = uuid4()
created = datetime(2026, 1, 15, 10, 30, 0)
@@ -63,7 +65,7 @@ def test_friend_request_response_maps_fields() -> None:
id=request_id,
sender=sender,
recipient=recipient,
content="Hello!",
content={"text": "Hello!"},
status="pending",
created_at=created,
)
@@ -76,7 +78,7 @@ def test_friend_request_response_maps_fields() -> None:
def test_friend_response_maps_fields() -> None:
friend_user = UserBasicInfo(id="user-2", username="bob", avatar_url=None)
friend_user = UserContext(id="user-2", username="bob", avatar_url=None)
request_id = uuid4()
created = datetime(2026, 1, 15, 10, 30, 0)
accepted = datetime(2026, 1, 16, 12, 0, 0)
@@ -96,7 +98,7 @@ def test_friend_response_maps_fields() -> None:
def test_friend_response_accepted_at_optional() -> None:
friend_user = UserBasicInfo(id="user-2", username="bob", avatar_url=None)
friend_user = UserContext(id="user-2", username="bob", avatar_url=None)
request_id = uuid4()
created = datetime(2026, 1, 15, 10, 30, 0)
@@ -12,7 +12,7 @@ from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from models.inbox_messages import InboxMessage, InboxMessageType
from models.schedule_items import ScheduleItem
from v1.auth.schemas import UserByEmailResponse
from v1.auth.schemas import UserByPhoneResponse
from v1.schedule_items.repository import ScheduleItemRepository
from v1.schedule_items.schemas import ScheduleItemShareRequest
from v1.schedule_items.service import ScheduleItemService
@@ -20,18 +20,18 @@ from v1.schedule_items.service import ScheduleItemService
def test_share_request_schema() -> None:
request = ScheduleItemShareRequest(
email="friend@example.com",
phone="+8613810000000",
permission_view=True,
permission_edit=True,
permission_invite=False,
)
assert request.email == "friend@example.com"
assert request.phone == "+8613810000000"
assert request.permission_view is True
def test_permission_bits_calculation() -> None:
request = ScheduleItemShareRequest(
email="friend@example.com",
phone="+8613810000000",
permission_view=True,
permission_edit=True,
permission_invite=False,
@@ -71,12 +71,12 @@ class ShareRepo:
class AuthGatewayStub:
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
return UserByEmailResponse(
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
return UserByPhoneResponse(
id="00000000-0000-0000-0000-000000000222",
email=email,
phone=phone,
created_at="2026-02-28T10:00:00Z",
email_confirmed_at=None,
phone_confirmed_at=None,
)
@@ -119,12 +119,12 @@ class InboxRepoStub:
class AuthGatewayInvalidIdStub:
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
return UserByEmailResponse(
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
return UserByPhoneResponse(
id="not-a-uuid",
email=email,
phone=phone,
created_at="2026-02-28T10:00:00Z",
email_confirmed_at=None,
phone_confirmed_at=None,
)
@@ -148,7 +148,7 @@ async def test_share_forbidden_when_not_owner() -> None:
await service.share(
item_id,
ScheduleItemShareRequest(
email="friend@example.com",
phone="+8613810000000",
permission_view=True,
permission_edit=False,
permission_invite=False,
@@ -178,7 +178,7 @@ async def test_share_success_creates_calendar_invitation_message() -> None:
result = await service.share(
item_id,
ScheduleItemShareRequest(
email="friend@example.com",
phone="+8613810000000",
permission_view=True,
permission_edit=True,
permission_invite=False,
@@ -211,7 +211,7 @@ async def test_share_returns_not_found_when_item_missing() -> None:
await service.share(
uuid4(),
ScheduleItemShareRequest(
email="friend@example.com",
phone="+8613810000000",
permission_view=True,
permission_edit=False,
permission_invite=False,
@@ -241,7 +241,7 @@ async def test_share_invalid_auth_user_id_returns_503() -> None:
await service.share(
item_id,
ScheduleItemShareRequest(
email="friend@example.com",
phone="+8613810000000",
permission_view=True,
permission_edit=False,
permission_invite=False,
@@ -274,7 +274,7 @@ async def test_share_sqlalchemy_error_rolls_back() -> None:
await service.share(
item_id,
ScheduleItemShareRequest(
email="friend@example.com",
phone="+8613810000000",
permission_view=True,
permission_edit=False,
permission_invite=False,
@@ -22,7 +22,7 @@ async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -
del token
return deps.CurrentUser(
id=UUID("e8845a17-282b-4a63-8025-194a06235958"),
email="dagronl@126.com",
phone="dagronl@126.com",
role="authenticated",
)
@@ -31,7 +31,7 @@ async def test_get_current_user_falls_back_to_supabase_validation(monkeypatch) -
user = await deps.get_current_user(authorization="Bearer valid-token")
assert str(user.id) == "e8845a17-282b-4a63-8025-194a06235958"
assert user.email == "dagronl@126.com"
assert user.phone == "dagronl@126.com"
@pytest.mark.asyncio
@@ -6,7 +6,7 @@ from uuid import uuid4
import pytest
from core.auth.models import CurrentUser
from v1.users.schemas import UserUpdateRequest
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
from v1.users.service import UserService
@@ -16,6 +16,7 @@ class _FakeProfile:
username: str
avatar_url: str | None
bio: str | None
settings: dict | None = None
class _FakeRepository:
@@ -51,6 +52,37 @@ class _FakeSession:
self.rollback_called += 1
class _FakeSearchRepository:
def __init__(self, profiles: list[_FakeProfile]) -> None:
self._profiles_by_id = {profile.id: profile for profile in profiles}
async def get_by_user_ids(
self, user_ids: list[object]
) -> dict[object, _FakeProfile]:
return {
user_id: self._profiles_by_id[user_id]
for user_id in user_ids
if user_id in self._profiles_by_id
}
async def search_users(self, query: str, limit: int = 20) -> list[_FakeProfile]:
_ = limit
return [
profile
for profile in self._profiles_by_id.values()
if query.lower() in profile.username.lower()
]
class _FakeAuthLookup:
def __init__(self, mapping: dict[str, list[str]]) -> None:
self.mapping = mapping
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
_ = limit
return self.mapping.get(query, [])
class _FakeUserContextCache:
def __init__(self, *, should_fail: bool = False) -> None:
self.should_fail = should_fail
@@ -72,7 +104,7 @@ async def test_update_me_invalidates_user_context_cache() -> None:
session = _FakeSession()
cache = _FakeUserContextCache()
service = UserService(
repository=repo,
repository=repo, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
current_user=CurrentUser(id=user_id),
user_context_cache=cache, # type: ignore[arg-type]
@@ -94,7 +126,7 @@ async def test_update_me_succeeds_when_cache_invalidation_fails() -> None:
session = _FakeSession()
cache = _FakeUserContextCache(should_fail=True)
service = UserService(
repository=repo,
repository=repo, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
current_user=CurrentUser(id=user_id),
user_context_cache=cache, # type: ignore[arg-type]
@@ -105,3 +137,59 @@ async def test_update_me_succeeds_when_cache_invalidation_fails() -> None:
assert result.username == "new-name"
assert session.commit_called == 1
assert cache.invalidated_user_ids == [user_id]
@pytest.mark.asyncio
async def test_search_users_supports_phone_without_country_code() -> None:
user_id = uuid4()
repo = _FakeSearchRepository(
[
_FakeProfile(
id=user_id,
username="alice",
avatar_url=None,
bio=None,
)
]
)
session = _FakeSession()
auth_lookup = _FakeAuthLookup({"13812345678": [str(user_id)]})
service = UserService(
repository=repo, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
current_user=CurrentUser(id=user_id),
auth_gateway=auth_lookup, # type: ignore[arg-type]
)
results = await service.search_users(UserSearchRequest(query="13812345678"))
assert len(results) == 1
assert results[0].id == str(user_id)
@pytest.mark.asyncio
async def test_search_users_preserves_numeric_username_lookup() -> None:
user_id = uuid4()
repo = _FakeSearchRepository(
[
_FakeProfile(
id=user_id,
username="20260319",
avatar_url=None,
bio=None,
)
]
)
session = _FakeSession()
auth_lookup = _FakeAuthLookup({})
service = UserService(
repository=repo, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
current_user=CurrentUser(id=user_id),
auth_gateway=auth_lookup, # type: ignore[arg-type]
)
results = await service.search_users(UserSearchRequest(query="20260319"))
assert len(results) == 1
assert results[0].username == "20260319"