refactor(backend): update API routes and service layer
- Update agent router/service/repository with new endpoints - Update auth routes with phone-based authentication - Update users service with new phone lookup - Update schedule_items with new schemas - Update message schemas with visibility support - Update settings with new automation scheduler config - Update CLI with new commands - Update tests to match new API contracts
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
@@ -95,6 +95,7 @@ class AgentRepository:
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
visibility_mask: int,
|
||||
) -> None:
|
||||
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
|
||||
|
||||
@@ -124,6 +125,7 @@ class AgentRepository:
|
||||
seq=next_seq,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content=content,
|
||||
visibility_mask=max(int(visibility_mask), 0),
|
||||
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
|
||||
)
|
||||
self._session.add(message)
|
||||
@@ -132,7 +134,11 @@ class AgentRepository:
|
||||
await self._session.flush()
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
@@ -152,6 +158,10 @@ class AgentRepository:
|
||||
.order_by(AgentChatMessage.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
target_created_at_stmt = self._apply_visibility_filter(
|
||||
stmt=target_created_at_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if before_start is not None:
|
||||
target_created_at_stmt = target_created_at_stmt.where(
|
||||
AgentChatMessage.created_at < before_start
|
||||
@@ -175,6 +185,10 @@ class AgentRepository:
|
||||
.where(AgentChatMessage.created_at < end)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
message_stmt = self._apply_visibility_filter(
|
||||
stmt=message_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
has_more_stmt = (
|
||||
select(AgentChatMessage.id)
|
||||
@@ -183,6 +197,10 @@ class AgentRepository:
|
||||
.where(AgentChatMessage.created_at < start)
|
||||
.limit(1)
|
||||
)
|
||||
has_more_stmt = self._apply_visibility_filter(
|
||||
stmt=has_more_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
has_more = (
|
||||
await self._session.execute(has_more_stmt)
|
||||
).scalar_one_or_none() is not None
|
||||
@@ -196,7 +214,11 @@ class AgentRepository:
|
||||
}
|
||||
|
||||
async def get_recent_messages_by_user_window(
|
||||
self, *, session_id: str, user_message_limit: int
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
@@ -210,6 +232,10 @@ class AgentRepository:
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.order_by(AgentChatMessage.seq.desc())
|
||||
)
|
||||
message_stmt = self._apply_visibility_filter(
|
||||
stmt=message_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
messages_desc = (await self._session.execute(message_stmt)).scalars().all()
|
||||
if not messages_desc:
|
||||
return []
|
||||
@@ -294,6 +320,21 @@ class AgentRepository:
|
||||
)
|
||||
return payload_model.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
def _apply_visibility_filter(
|
||||
self,
|
||||
*,
|
||||
stmt: Select,
|
||||
visibility_mask: int | None,
|
||||
) -> Select:
|
||||
if visibility_mask is None:
|
||||
return stmt
|
||||
required_mask = max(int(visibility_mask), 0)
|
||||
if required_mask == 0:
|
||||
return stmt
|
||||
return stmt.where(
|
||||
(AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
|
||||
)
|
||||
|
||||
|
||||
def _has_title(title: object) -> bool:
|
||||
return isinstance(title, str) and bool(title.strip())
|
||||
|
||||
@@ -47,6 +47,7 @@ logger = get_logger("v1.agent.router")
|
||||
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
@@ -72,8 +73,14 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
count = await redis.incr(key)
|
||||
if count == 1:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if int(ttl) < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
|
||||
await redis.decr(key)
|
||||
after_decr = await redis.decr(key)
|
||||
if int(after_decr) <= 0:
|
||||
await redis.delete(key)
|
||||
return False
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@@ -82,7 +89,7 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
user_id=user_id,
|
||||
reason=str(exc),
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _release_sse_slot(*, user_id: str) -> None:
|
||||
@@ -92,10 +99,21 @@ async def _release_sse_slot(*, user_id: str) -> None:
|
||||
count = await redis.decr(key)
|
||||
if int(count) <= 0:
|
||||
await redis.delete(key)
|
||||
return None
|
||||
ttl = await redis.ttl(key)
|
||||
if int(ttl) < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
except Exception: # noqa: BLE001
|
||||
return None
|
||||
|
||||
|
||||
def _is_terminal_run_event(event: dict[str, object]) -> bool:
|
||||
raw_event_type = event.get("type")
|
||||
return (
|
||||
isinstance(raw_event_type, str) and raw_event_type in _TERMINAL_RUN_EVENT_TYPES
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
@@ -145,8 +163,13 @@ async def stream_events(
|
||||
async def _event_iter() -> AsyncIterator[str]:
|
||||
cursor = last_event_id
|
||||
idle_polls = 0
|
||||
terminal_event_reached = False
|
||||
try:
|
||||
while not await request.is_disconnected() and idle_polls < idle_limit:
|
||||
while (
|
||||
not terminal_event_reached
|
||||
and not await request.is_disconnected()
|
||||
and idle_polls < idle_limit
|
||||
):
|
||||
try:
|
||||
rows = await service.stream_events(
|
||||
thread_id=thread_id,
|
||||
@@ -181,6 +204,9 @@ async def stream_events(
|
||||
continue
|
||||
cursor = row_id
|
||||
yield to_sse_event(row_id, event)
|
||||
if _is_terminal_run_event(event):
|
||||
terminal_event_reached = True
|
||||
break
|
||||
|
||||
finally:
|
||||
await _release_sse_slot(user_id=str(current_user.id))
|
||||
|
||||
@@ -17,6 +17,9 @@ from core.auth.models import CurrentUser
|
||||
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachment,
|
||||
@@ -51,7 +54,11 @@ class AgentRepositoryLike(Protocol):
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
@@ -62,8 +69,13 @@ class AgentRepositoryLike(Protocol):
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
visibility_mask: int,
|
||||
) -> None: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -138,6 +150,17 @@ class AgentService:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
forwarded_props = getattr(run_input, "forwarded_props", None)
|
||||
try:
|
||||
agent_type = parse_forwarded_props_agent_type(forwarded_props)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
if agent_type == "memory":
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="memory mode is automation-only",
|
||||
)
|
||||
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except HTTPException as exc:
|
||||
@@ -161,25 +184,21 @@ class AgentService:
|
||||
run_input=run_input,
|
||||
current_user=current_user,
|
||||
)
|
||||
visibility_mask = await self._resolve_user_message_visibility_mask(
|
||||
agent_type=agent_type
|
||||
)
|
||||
await self._repository.persist_user_message(
|
||||
session_id=thread_id,
|
||||
content=user_message_text,
|
||||
metadata=user_message_metadata,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
await self._repository.commit()
|
||||
|
||||
forwarded_props = getattr(run_input, "forwarded_props", None)
|
||||
system_agent_mode = "worker"
|
||||
if isinstance(forwarded_props, dict):
|
||||
raw_mode = forwarded_props.get("system_agent_mode")
|
||||
if isinstance(raw_mode, str) and raw_mode.strip():
|
||||
system_agent_mode = raw_mode.strip().lower()
|
||||
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"system_agent_mode": system_agent_mode,
|
||||
"run_input": run_input.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
@@ -193,6 +212,61 @@ class AgentService:
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int:
|
||||
normalized_agent_type = agent_type.strip().lower()
|
||||
history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
|
||||
|
||||
if normalized_agent_type == "memory":
|
||||
return bit_mask(bit=18)
|
||||
|
||||
agent_config = await self._repository.get_system_agent_config(
|
||||
agent_type=normalized_agent_type
|
||||
)
|
||||
if agent_config is None:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="invalid forwarded_props.agent_type"
|
||||
)
|
||||
llm_config = SystemAgentLLMConfig.model_validate(
|
||||
(agent_config.get("config") if isinstance(agent_config, dict) else {}) or {}
|
||||
)
|
||||
agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit)
|
||||
|
||||
if normalized_agent_type == "worker":
|
||||
router_config = await self._repository.get_system_agent_config(
|
||||
agent_type="router"
|
||||
)
|
||||
worker_config = await self._repository.get_system_agent_config(
|
||||
agent_type="worker"
|
||||
)
|
||||
if router_config is None or worker_config is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="system agent visibility config missing",
|
||||
)
|
||||
router_mask = bit_mask(
|
||||
bit=SystemAgentLLMConfig.model_validate(
|
||||
(
|
||||
router_config.get("config")
|
||||
if isinstance(router_config, dict)
|
||||
else {}
|
||||
)
|
||||
or {}
|
||||
).visibility_consumer_bit
|
||||
)
|
||||
worker_mask = bit_mask(
|
||||
bit=SystemAgentLLMConfig.model_validate(
|
||||
(
|
||||
worker_config.get("config")
|
||||
if isinstance(worker_config, dict)
|
||||
else {}
|
||||
)
|
||||
or {}
|
||||
).visibility_consumer_bit
|
||||
)
|
||||
return history_bit_mask | router_mask | worker_mask
|
||||
|
||||
return history_bit_mask | agent_mask
|
||||
|
||||
async def _prepare_user_message(
|
||||
self,
|
||||
*,
|
||||
@@ -408,6 +482,7 @@ class AgentService:
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)),
|
||||
)
|
||||
|
||||
messages: list[HistoryMessage] = []
|
||||
|
||||
+127
-199
@@ -2,28 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByEmailResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
UserByPhoneResponse,
|
||||
)
|
||||
from v1.auth.service import AuthServiceGateway
|
||||
|
||||
@@ -36,7 +30,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def __init__(self) -> None:
|
||||
self._user_lookup_cache_ttl_seconds: int = 60
|
||||
self._user_lookup_cache_expires_at: float = 0.0
|
||||
self._users_by_email: dict[str, Any] = {}
|
||||
self._users_by_phone: dict[str, Any] = {}
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
return supabase_service.get_client()
|
||||
@@ -44,47 +38,31 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def _get_admin_client(self) -> Any:
|
||||
return supabase_service.get_admin_client()
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
client = self._get_client()
|
||||
metadata: dict[str, Any] = {"username": request.username}
|
||||
if request.invite_code:
|
||||
metadata["invite_code"] = request.invite_code
|
||||
payload: dict[str, Any] = {
|
||||
"email": request.email,
|
||||
"password": request.password,
|
||||
"data": metadata,
|
||||
"phone": request.phone,
|
||||
"options": {"should_create_user": True},
|
||||
}
|
||||
if request.redirect_to:
|
||||
payload["options"] = {
|
||||
"email_redirect_to": _validate_redirect_url(request.redirect_to)
|
||||
}
|
||||
try:
|
||||
sign_up = cast(Any, client.auth.sign_up)
|
||||
await asyncio.to_thread(sign_up, payload)
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
|
||||
await asyncio.to_thread(sign_in_with_otp, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup failed", error_type=type(exc).__name__)
|
||||
logger.warning("Send otp failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid signup request"
|
||||
) from exc
|
||||
raise HTTPException(status_code=429, detail="Too many requests") from exc
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
if request.type != "signup":
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": request.type,
|
||||
"email": request.email,
|
||||
"type": "sms",
|
||||
"phone": request.phone,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
@@ -92,7 +70,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
response = await asyncio.to_thread(verify_otp, payload)
|
||||
return _map_auth_response(response, "Invalid verification code")
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup verify failed", error_type=type(exc).__name__)
|
||||
logger.warning("Create phone session failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
@@ -102,45 +80,6 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid verification code"
|
||||
) from exc
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
client = self._get_client()
|
||||
if request.type == "recovery":
|
||||
await self.request_password_reset(
|
||||
PasswordResetRequest(
|
||||
email=request.email,
|
||||
redirect_to=request.redirect_to,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
payload: dict[str, Any] = {"type": request.type, "email": request.email}
|
||||
try:
|
||||
resend = cast(Any, client.auth.resend)
|
||||
await asyncio.to_thread(resend, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup resend failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
logger.warning("Login failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
@@ -189,98 +128,84 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid refresh token"
|
||||
) from exc
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
admin_client = self._get_admin_client()
|
||||
normalized_email = email.lower()
|
||||
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
|
||||
normalized_phone = _normalize_phone(phone)
|
||||
if not normalized_phone:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
now = time.monotonic()
|
||||
if now >= self._user_lookup_cache_expires_at:
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_email: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_email = str(getattr(candidate, "email", "")).lower()
|
||||
if candidate_email:
|
||||
users_by_email[candidate_email] = candidate
|
||||
self._users_by_email = users_by_email
|
||||
self._user_lookup_cache_expires_at = (
|
||||
now + self._user_lookup_cache_ttl_seconds
|
||||
)
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
|
||||
user = self._users_by_email.get(normalized_email)
|
||||
user = self._users_by_phone.get(normalized_phone)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserByEmailResponse(
|
||||
user_phone = _normalize_phone(getattr(user, "phone", ""))
|
||||
if not user_phone:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserByPhoneResponse(
|
||||
id=str(getattr(user, "id", "")),
|
||||
email=str(getattr(user, "email", "")),
|
||||
phone=user_phone,
|
||||
created_at=str(getattr(user, "created_at", "")),
|
||||
email_confirmed_at=(
|
||||
str(getattr(user, "email_confirmed_at", ""))
|
||||
if getattr(user, "email_confirmed_at", None)
|
||||
phone_confirmed_at=(
|
||||
str(getattr(user, "phone_confirmed_at", ""))
|
||||
if getattr(user, "phone_confirmed_at", None)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
reset_email = cast(Any, client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {
|
||||
"redirect_to": _validate_redirect_url(request.redirect_to)
|
||||
}
|
||||
await asyncio.to_thread(reset_email, email, options=options)
|
||||
else:
|
||||
await asyncio.to_thread(reset_email, email)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset request failed",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
normalized_query = _normalize_phone_search_query(query)
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
if normalized_query.startswith("+"):
|
||||
matched_user = self._users_by_phone.get(normalized_query)
|
||||
if matched_user is None:
|
||||
return []
|
||||
user_id = str(getattr(matched_user, "id", ""))
|
||||
return [user_id] if user_id else []
|
||||
|
||||
digits = _digits_only(normalized_query)
|
||||
if not digits:
|
||||
return []
|
||||
|
||||
matched_records: list[tuple[str, str]] = []
|
||||
for cached_phone, candidate in self._users_by_phone.items():
|
||||
candidate_digits = _digits_only(cached_phone)
|
||||
if not candidate_digits.endswith(digits):
|
||||
continue
|
||||
user_id = str(getattr(candidate, "id", ""))
|
||||
if user_id:
|
||||
matched_records.append((cached_phone, user_id))
|
||||
|
||||
if not matched_records:
|
||||
return []
|
||||
|
||||
unique_ids: list[str] = []
|
||||
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
|
||||
if user_id in unique_ids:
|
||||
continue
|
||||
unique_ids.append(user_id)
|
||||
if len(unique_ids) >= max(1, limit):
|
||||
break
|
||||
return unique_ids
|
||||
|
||||
async def _refresh_user_lookup_cache_if_needed(self) -> None:
|
||||
now = time.monotonic()
|
||||
if now < self._user_lookup_cache_expires_at:
|
||||
return
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
client = self._get_client()
|
||||
admin_client = self._get_admin_client()
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
user_id = str(getattr(user, "id", "")) if user is not None else ""
|
||||
if session is None or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset confirm failed", error_type=type(exc).__name__
|
||||
)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
) from exc
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_phone: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
|
||||
if candidate_phone:
|
||||
users_by_phone[candidate_phone] = candidate
|
||||
self._users_by_phone = users_by_phone
|
||||
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
|
||||
|
||||
|
||||
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
||||
@@ -312,55 +237,24 @@ def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
||||
return any(token in code or token in message for token in indicators)
|
||||
|
||||
|
||||
def _validate_redirect_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
|
||||
origin = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
|
||||
allowed_origins = {
|
||||
_normalize_origin(candidate)
|
||||
for candidate in config.cors.allow_origins
|
||||
if _is_http_origin(candidate)
|
||||
}
|
||||
if origin not in allowed_origins:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
return url
|
||||
|
||||
|
||||
def _normalize_origin(value: str) -> str:
|
||||
parsed = urlparse(value)
|
||||
return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
|
||||
|
||||
|
||||
def _is_http_origin(value: str) -> bool:
|
||||
parsed = urlparse(value)
|
||||
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
|
||||
|
||||
|
||||
def _coerce_reset_email(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
nested = value.get("email") or value.get("value")
|
||||
if isinstance(nested, str):
|
||||
return nested
|
||||
|
||||
raise HTTPException(status_code=422, detail="Invalid email")
|
||||
|
||||
|
||||
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
if session is None or user is None:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
email = getattr(user, "email", None)
|
||||
if not email:
|
||||
phone = _normalize_phone(getattr(user, "phone", None))
|
||||
if not phone:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
auth_user = AuthUser(id=str(user.id), email=str(email))
|
||||
try:
|
||||
auth_user = AuthUser(id=str(user.id), phone=str(phone))
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Auth response returned invalid phone format",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=failure_message) from exc
|
||||
return SessionResponse(
|
||||
access_token=str(session.access_token),
|
||||
refresh_token=str(session.refresh_token),
|
||||
@@ -389,3 +283,37 @@ def _list_auth_users(client: Any) -> list[Any]:
|
||||
page += 1
|
||||
|
||||
return users
|
||||
|
||||
|
||||
def _normalize_phone(raw_phone: object) -> str | None:
|
||||
phone = str(raw_phone).strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
phone = phone.replace(separator, "")
|
||||
if not phone:
|
||||
return None
|
||||
if phone.startswith("00") and len(phone) > 2:
|
||||
return f"+{phone[2:]}"
|
||||
if phone.startswith("+"):
|
||||
return phone
|
||||
if phone.isdigit():
|
||||
return f"+{phone}"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_phone_search_query(raw_query: str) -> str | None:
|
||||
query = raw_query.strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
query = query.replace(separator, "")
|
||||
if not query:
|
||||
return None
|
||||
if query.startswith("00") and len(query) > 2:
|
||||
return f"+{query[2:]}"
|
||||
if query.startswith("+"):
|
||||
return query
|
||||
if query.isdigit():
|
||||
return query
|
||||
return None
|
||||
|
||||
|
||||
def _digits_only(value: str) -> str:
|
||||
return "".join(ch for ch in value if ch.isdigit())
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.config.settings import config
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
@@ -22,80 +18,49 @@ from v1.auth.service import AuthService
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/verifications", response_model=VerificationCreateResponse, status_code=202
|
||||
)
|
||||
async def create_verification(
|
||||
payload: VerificationCreateRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> VerificationCreateResponse:
|
||||
await enforce_rate_limit(
|
||||
scope="signup_start",
|
||||
identifier=payload.email,
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
)
|
||||
return await service.create_verification(payload)
|
||||
|
||||
|
||||
@router.post("/verify", response_model=SessionResponse)
|
||||
async def verify(
|
||||
payload: VerificationVerifyRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse | Response:
|
||||
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
|
||||
limit = 10
|
||||
window_seconds = 600
|
||||
await enforce_rate_limit(
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
if payload.type == "signup":
|
||||
return await service.verify_verification(payload)
|
||||
if payload.new_password is None:
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
await service.confirm_password_reset(
|
||||
PasswordResetConfirmRequest(
|
||||
email=payload.email,
|
||||
token=payload.token,
|
||||
new_password=payload.new_password,
|
||||
)
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/resend", status_code=204)
|
||||
async def resend(
|
||||
payload: VerificationResendRequest,
|
||||
@router.post("/otp/send", status_code=204)
|
||||
async def send_otp(
|
||||
payload: OtpSendRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=5,
|
||||
scope="otp_send_phone",
|
||||
identifier=payload.phone,
|
||||
limit=3,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.resend_verification(payload)
|
||||
await enforce_rate_limit(
|
||||
scope="otp_send_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.send_otp(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(
|
||||
payload: SessionCreateRequest,
|
||||
@router.post("/phone-session", response_model=SessionResponse)
|
||||
async def create_phone_session(
|
||||
payload: PhoneSessionCreateRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope="login",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
scope="phone_session_phone",
|
||||
identifier=payload.phone,
|
||||
limit=6,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_session(payload)
|
||||
await enforce_rate_limit(
|
||||
scope="phone_session_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_phone_session(payload)
|
||||
|
||||
|
||||
@router.post("/sessions/refresh", response_model=SessionResponse)
|
||||
@@ -130,13 +95,23 @@ async def delete_session(
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
first = forwarded_for.split(",")[0].strip()
|
||||
if first:
|
||||
return first
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
host = request.client.host if request.client else ""
|
||||
return host or "unknown"
|
||||
if not host:
|
||||
return "unknown"
|
||||
|
||||
if _should_trust_proxy_headers(host):
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
first = forwarded_for.split(",")[0].strip()
|
||||
if first:
|
||||
return first
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
return host
|
||||
|
||||
|
||||
def _should_trust_proxy_headers(host: str) -> bool:
|
||||
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
|
||||
return host in trusted_proxies
|
||||
|
||||
@@ -1,49 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
OtpType = Literal["signup", "recovery"]
|
||||
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
|
||||
|
||||
|
||||
class VerificationCreateRequest(BaseModel):
|
||||
class OtpSendRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
username: str = Field(min_length=3, max_length=30)
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
redirect_to: str | None = None
|
||||
invite_code: str | None = None
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class VerificationResendRequest(BaseModel):
|
||||
email: EmailStr
|
||||
type: OtpType = "signup"
|
||||
redirect_to: str | None = None
|
||||
class PhoneSessionCreateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class VerificationVerifyRequest(BaseModel):
|
||||
type: OtpType = "signup"
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str | None = Field(
|
||||
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_type_payload(self) -> "VerificationVerifyRequest":
|
||||
if self.type == "recovery" and self.new_password is None:
|
||||
raise ValueError("new_password is required when type is recovery")
|
||||
if self.type == "signup" and self.new_password is not None:
|
||||
raise ValueError("new_password is only allowed when type is recovery")
|
||||
return self
|
||||
|
||||
|
||||
class SessionCreateRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
|
||||
|
||||
class SessionRefreshRequest(BaseModel):
|
||||
@@ -56,7 +29,7 @@ class SessionDeleteRequest(BaseModel):
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
@@ -67,23 +40,12 @@ class SessionResponse(BaseModel):
|
||||
user: AuthUser
|
||||
|
||||
|
||||
class UserByEmailResponse(BaseModel):
|
||||
class UserByPhoneResponse(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
created_at: str
|
||||
email_confirmed_at: str | None = None
|
||||
phone_confirmed_at: str | None = None
|
||||
|
||||
|
||||
class VerificationCreateResponse(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: EmailStr
|
||||
redirect_to: str | None = None
|
||||
|
||||
|
||||
class PasswordResetConfirmRequest(BaseModel):
|
||||
email: EmailStr
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
@@ -1,52 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Protocol
|
||||
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
|
||||
|
||||
class AuthServiceGateway(Protocol):
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
@@ -54,50 +32,16 @@ class AuthService:
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
normalized_invite_code = _normalize_invite_code(request.invite_code)
|
||||
normalized_request = request.model_copy(
|
||||
update={"invite_code": normalized_invite_code}
|
||||
)
|
||||
return await self._gateway.create_verification(normalized_request)
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return await self._gateway.verify_verification(request)
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
await self._gateway.resend_verification(request)
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
return await self._gateway.create_session(request)
|
||||
return await self._gateway.create_phone_session(request)
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
await self._gateway.request_password_reset(request)
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
await self._gateway.confirm_password_reset(request)
|
||||
|
||||
|
||||
_INVITE_CODE_PATTERN = re.compile(r"^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{4}$")
|
||||
|
||||
|
||||
def _normalize_invite_code(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
normalized = value.strip().upper()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
return normalized if _INVITE_CODE_PATTERN.fullmatch(normalized) else None
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import ClassVar
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from schemas.inbox.messages import (
|
||||
CalendarContent,
|
||||
@@ -154,7 +154,11 @@ _PERMISSION_EDIT = 4
|
||||
class ScheduleItemShareRequest(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
email: EmailStr = Field(..., description="Email of user to share with")
|
||||
phone: str = Field(
|
||||
...,
|
||||
pattern=r"^\+861[3-9]\d{9}$",
|
||||
description="Phone of user to share with",
|
||||
)
|
||||
permission_view: bool = Field(True, description="Grant view permission")
|
||||
permission_edit: bool = Field(False, description="Grant edit permission")
|
||||
permission_invite: bool = Field(False, description="Grant invite permission")
|
||||
|
||||
@@ -31,19 +31,19 @@ from v1.schedule_items.schemas import (
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.auth.schemas import UserByEmailResponse
|
||||
from v1.auth.schemas import UserByPhoneResponse
|
||||
|
||||
logger = get_logger("v1.schedule_items.service")
|
||||
|
||||
|
||||
class AuthByEmailGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
|
||||
class AuthByPhoneGateway(Protocol):
|
||||
async def get_user_by_phone(self, phone: str) -> "UserByPhoneResponse": ...
|
||||
|
||||
|
||||
class ScheduleItemService(BaseService):
|
||||
_repository: ScheduleItemRepository
|
||||
_session: AsyncSession
|
||||
_auth_gateway: AuthByEmailGateway
|
||||
_auth_gateway: AuthByPhoneGateway
|
||||
_inbox_repository: InboxMessageRepository
|
||||
|
||||
def __init__(
|
||||
@@ -51,7 +51,7 @@ class ScheduleItemService(BaseService):
|
||||
repository: ScheduleItemRepository,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
auth_gateway: AuthByEmailGateway | None = None,
|
||||
auth_gateway: AuthByPhoneGateway | None = None,
|
||||
inbox_repository: InboxMessageRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
@@ -329,7 +329,7 @@ class ScheduleItemService(BaseService):
|
||||
detail=f"You can only share with permissions up to {inviter_permission}",
|
||||
)
|
||||
|
||||
target_user = await self._auth_gateway.get_user_by_email(request.email)
|
||||
target_user = await self._auth_gateway.get_user_by_phone(request.phone)
|
||||
recipient_id = UUID(target_user.id)
|
||||
|
||||
existing = await self._repository.get_subscription(item_id, recipient_id)
|
||||
@@ -404,7 +404,7 @@ class ScheduleItemService(BaseService):
|
||||
except ValueError:
|
||||
await self._session.rollback()
|
||||
logger.exception(
|
||||
"Auth lookup returned invalid user id", email=request.email
|
||||
"Auth lookup returned invalid user id", phone=request.phone
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
|
||||
|
||||
|
||||
@@ -76,11 +76,11 @@ async def _verify_user_with_supabase(token: str) -> CurrentUser | None:
|
||||
parsed_id = UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
email = getattr(user, "email", None)
|
||||
phone = getattr(user, "phone", None)
|
||||
role = getattr(user, "role", None)
|
||||
return CurrentUser(
|
||||
id=parsed_id,
|
||||
email=email if isinstance(email, str) else None,
|
||||
phone=phone if isinstance(phone, str) else None,
|
||||
role=role if isinstance(role, str) else None,
|
||||
)
|
||||
|
||||
@@ -125,9 +125,9 @@ async def get_current_user(
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
logger.debug("JWT validation successful", user_id=str(user_id))
|
||||
email = payload.get("email") if isinstance(payload.get("email"), str) else None
|
||||
phone = payload.get("phone") if isinstance(payload.get("phone"), str) else None
|
||||
role = payload.get("role") if isinstance(payload.get("role"), str) else None
|
||||
return CurrentUser(id=user_id, email=email, role=role)
|
||||
return CurrentUser(id=user_id, phone=phone, role=role)
|
||||
|
||||
|
||||
async def get_user_repository(
|
||||
|
||||
@@ -38,7 +38,7 @@ class UserRepository(Protocol):
|
||||
...
|
||||
|
||||
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
|
||||
"""Search users by username (ilike) or email (exact match)."""
|
||||
"""Search users by username (ilike) or phone (exact match)."""
|
||||
...
|
||||
|
||||
|
||||
|
||||
@@ -21,19 +21,22 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from schemas.user.context import UserContext
|
||||
from v1.auth.schemas import UserByEmailResponse
|
||||
|
||||
logger = get_logger("v1.users.service")
|
||||
|
||||
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
_PHONE_QUERY_PATTERN = re.compile(r"^[+()\-\s\d]{4,32}$")
|
||||
|
||||
|
||||
class AuthLookupGateway(Protocol):
|
||||
async def get_user_id_by_email(self, email: str) -> str | None: ...
|
||||
async def search_user_ids_by_phone(
|
||||
self, query: str, limit: int = 20
|
||||
) -> list[str]: ...
|
||||
|
||||
|
||||
class AuthByEmailGateway(Protocol):
|
||||
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
|
||||
class AuthByPhoneGateway(Protocol):
|
||||
async def search_user_ids_by_phone(
|
||||
self, query: str, limit: int = 20
|
||||
) -> list[str]: ...
|
||||
|
||||
|
||||
class UserContextInvalidator(Protocol):
|
||||
@@ -41,15 +44,14 @@ class UserContextInvalidator(Protocol):
|
||||
|
||||
|
||||
class AuthLookupAdapter:
|
||||
def __init__(self, gateway: AuthByEmailGateway) -> None:
|
||||
def __init__(self, gateway: AuthByPhoneGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def get_user_id_by_email(self, email: str) -> str | None:
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
try:
|
||||
response = await self._gateway.get_user_by_email(email)
|
||||
return response.id
|
||||
return await self._gateway.search_user_ids_by_phone(query, limit=limit)
|
||||
except HTTPException:
|
||||
return None
|
||||
return []
|
||||
|
||||
|
||||
class UserService(BaseService):
|
||||
@@ -92,11 +94,11 @@ class UserService(BaseService):
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
email = self._current_user.email if self._current_user else None
|
||||
phone = self._current_user.phone if self._current_user else None
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
phone=phone,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
@@ -152,11 +154,11 @@ class UserService(BaseService):
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
email = self._current_user.email if self._current_user else None
|
||||
phone = self._current_user.phone if self._current_user else None
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
email=email,
|
||||
phone=phone,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
@@ -181,36 +183,59 @@ class UserService(BaseService):
|
||||
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
|
||||
query = request.query.strip()
|
||||
|
||||
if _EMAIL_PATTERN.match(query):
|
||||
return await self._search_by_email(query)
|
||||
if _looks_like_phone_query(query):
|
||||
phone_results = await self._search_by_phone(query)
|
||||
if not query.isdigit():
|
||||
return phone_results
|
||||
username_results = await self._search_by_username(query)
|
||||
if not phone_results:
|
||||
return username_results
|
||||
merged_by_id = {result.id: result for result in phone_results}
|
||||
for result in username_results:
|
||||
merged_by_id.setdefault(result.id, result)
|
||||
return list(merged_by_id.values())
|
||||
|
||||
return await self._search_by_username(query)
|
||||
|
||||
async def _search_by_email(self, email: str) -> list[UserContext]:
|
||||
async def _search_by_phone(self, phone: str) -> list[UserContext]:
|
||||
if self._auth_gateway is None:
|
||||
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
|
||||
|
||||
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
|
||||
if user_id_str is None:
|
||||
user_id_values = await self._auth_gateway.search_user_ids_by_phone(
|
||||
phone, limit=20
|
||||
)
|
||||
if not user_id_values:
|
||||
return []
|
||||
|
||||
user_ids: list[UUID] = []
|
||||
for raw_id in user_id_values:
|
||||
try:
|
||||
user_ids.append(UUID(raw_id))
|
||||
except ValueError:
|
||||
continue
|
||||
if not user_ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
user = await self._repository.get_by_user_id(UUID(user_id_str))
|
||||
users_by_id = await self._repository.get_by_user_ids(user_ids)
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(status_code=503, detail="User store unavailable")
|
||||
|
||||
if user is None:
|
||||
return []
|
||||
|
||||
return [
|
||||
UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
results: list[UserContext] = []
|
||||
for user_id in user_ids:
|
||||
user = users_by_id.get(user_id)
|
||||
if user is None:
|
||||
continue
|
||||
results.append(
|
||||
UserContext(
|
||||
id=str(user.id),
|
||||
username=user.username,
|
||||
avatar_url=user.avatar_url,
|
||||
bio=user.bio,
|
||||
settings=parse_profile_settings(user.settings),
|
||||
)
|
||||
)
|
||||
]
|
||||
return results
|
||||
|
||||
async def _search_by_username(self, query: str) -> list[UserContext]:
|
||||
try:
|
||||
@@ -228,3 +253,10 @@ class UserService(BaseService):
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
|
||||
def _looks_like_phone_query(query: str) -> bool:
|
||||
if not _PHONE_QUERY_PATTERN.fullmatch(query):
|
||||
return False
|
||||
digits_count = sum(char.isdigit() for char in query)
|
||||
return digits_count >= 4
|
||||
|
||||
Reference in New Issue
Block a user