refactor(backend): update API routes and service layer

- Update agent router/service/repository with new endpoints
- Update auth routes with phone-based authentication
- Update users service with new phone lookup
- Update schedule_items with new schemas
- Update message schemas with visibility support
- Update settings with new automation scheduler config
- Update CLI with new commands
- Update tests to match new API contracts
This commit is contained in:
qzl
2026-03-19 18:42:59 +08:00
parent 641d847008
commit f0af44d840
36 changed files with 1083 additions and 1853 deletions
+44 -3
View File
@@ -6,7 +6,7 @@ from typing import Protocol
from uuid import UUID, uuid4
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
@@ -95,6 +95,7 @@ class AgentRepository:
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None:
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
@@ -124,6 +125,7 @@ class AgentRepository:
seq=next_seq,
role=AgentChatMessageRole.USER,
content=content,
visibility_mask=max(int(visibility_mask), 0),
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
)
self._session.add(message)
@@ -132,7 +134,11 @@ class AgentRepository:
await self._session.flush()
async def get_history_day(
self, *, session_id: str, before: date | None
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None:
try:
session_uuid = UUID(session_id)
@@ -152,6 +158,10 @@ class AgentRepository:
.order_by(AgentChatMessage.created_at.desc())
.limit(1)
)
target_created_at_stmt = self._apply_visibility_filter(
stmt=target_created_at_stmt,
visibility_mask=visibility_mask,
)
if before_start is not None:
target_created_at_stmt = target_created_at_stmt.where(
AgentChatMessage.created_at < before_start
@@ -175,6 +185,10 @@ class AgentRepository:
.where(AgentChatMessage.created_at < end)
.order_by(AgentChatMessage.seq.asc())
)
message_stmt = self._apply_visibility_filter(
stmt=message_stmt,
visibility_mask=visibility_mask,
)
messages = (await self._session.execute(message_stmt)).scalars().all()
has_more_stmt = (
select(AgentChatMessage.id)
@@ -183,6 +197,10 @@ class AgentRepository:
.where(AgentChatMessage.created_at < start)
.limit(1)
)
has_more_stmt = self._apply_visibility_filter(
stmt=has_more_stmt,
visibility_mask=visibility_mask,
)
has_more = (
await self._session.execute(has_more_stmt)
).scalar_one_or_none() is not None
@@ -196,7 +214,11 @@ class AgentRepository:
}
async def get_recent_messages_by_user_window(
self, *, session_id: str, user_message_limit: int
self,
*,
session_id: str,
user_message_limit: int,
visibility_mask: int | None = None,
) -> list[dict[str, object]]:
try:
session_uuid = UUID(session_id)
@@ -210,6 +232,10 @@ class AgentRepository:
.where(AgentChatMessage.deleted_at.is_(None))
.order_by(AgentChatMessage.seq.desc())
)
message_stmt = self._apply_visibility_filter(
stmt=message_stmt,
visibility_mask=visibility_mask,
)
messages_desc = (await self._session.execute(message_stmt)).scalars().all()
if not messages_desc:
return []
@@ -294,6 +320,21 @@ class AgentRepository:
)
return payload_model.model_dump(mode="json", exclude_none=True)
def _apply_visibility_filter(
self,
*,
stmt: Select,
visibility_mask: int | None,
) -> Select:
if visibility_mask is None:
return stmt
required_mask = max(int(visibility_mask), 0)
if required_mask == 0:
return stmt
return stmt.where(
(AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
)
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
+29 -3
View File
@@ -47,6 +47,7 @@ logger = get_logger("v1.agent.router")
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
@@ -72,8 +73,14 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
count = await redis.incr(key)
if count == 1:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
else:
ttl = await redis.ttl(key)
if int(ttl) < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
if int(count) > _MAX_SSE_CONNECTIONS_PER_USER:
await redis.decr(key)
after_decr = await redis.decr(key)
if int(after_decr) <= 0:
await redis.delete(key)
return False
return True
except Exception as exc: # noqa: BLE001
@@ -82,7 +89,7 @@ async def _acquire_sse_slot(*, user_id: str) -> bool:
user_id=user_id,
reason=str(exc),
)
return False
return True
async def _release_sse_slot(*, user_id: str) -> None:
@@ -92,10 +99,21 @@ async def _release_sse_slot(*, user_id: str) -> None:
count = await redis.decr(key)
if int(count) <= 0:
await redis.delete(key)
return None
ttl = await redis.ttl(key)
if int(ttl) < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
except Exception: # noqa: BLE001
return None
def _is_terminal_run_event(event: dict[str, object]) -> bool:
raw_event_type = event.get("type")
return (
isinstance(raw_event_type, str) and raw_event_type in _TERMINAL_RUN_EVENT_TYPES
)
@router.post(
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
@@ -145,8 +163,13 @@ async def stream_events(
async def _event_iter() -> AsyncIterator[str]:
cursor = last_event_id
idle_polls = 0
terminal_event_reached = False
try:
while not await request.is_disconnected() and idle_polls < idle_limit:
while (
not terminal_event_reached
and not await request.is_disconnected()
and idle_polls < idle_limit
):
try:
rows = await service.stream_events(
thread_id=thread_id,
@@ -181,6 +204,9 @@ async def stream_events(
continue
cursor = row_id
yield to_sse_event(row_id, event)
if _is_terminal_run_event(event):
terminal_event_reached = True
break
finally:
await _release_sse_slot(user_id=str(current_user.id))
+84 -9
View File
@@ -17,6 +17,9 @@ from core.auth.models import CurrentUser
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
from schemas.agent.forwarded_props import parse_forwarded_props_agent_type
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
UserMessageAttachment,
@@ -51,7 +54,11 @@ class AgentRepositoryLike(Protocol):
async def rollback(self) -> None: ...
async def get_history_day(
self, *, session_id: str, before: date | None
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None: ...
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
@@ -62,8 +69,13 @@ class AgentRepositoryLike(Protocol):
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None: ...
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None: ...
class QueueClientLike(Protocol):
async def enqueue(
@@ -138,6 +150,17 @@ class AgentService:
created = False
thread_id = run_input.thread_id
run_id = run_input.run_id
forwarded_props = getattr(run_input, "forwarded_props", None)
try:
agent_type = parse_forwarded_props_agent_type(forwarded_props)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
if agent_type == "memory":
raise HTTPException(
status_code=422,
detail="memory mode is automation-only",
)
try:
owner = await self._repository.get_session_owner(session_id=thread_id)
except HTTPException as exc:
@@ -161,25 +184,21 @@ class AgentService:
run_input=run_input,
current_user=current_user,
)
visibility_mask = await self._resolve_user_message_visibility_mask(
agent_type=agent_type
)
await self._repository.persist_user_message(
session_id=thread_id,
content=user_message_text,
metadata=user_message_metadata,
visibility_mask=visibility_mask,
)
await self._repository.commit()
forwarded_props = getattr(run_input, "forwarded_props", None)
system_agent_mode = "worker"
if isinstance(forwarded_props, dict):
raw_mode = forwarded_props.get("system_agent_mode")
if isinstance(raw_mode, str) and raw_mode.strip():
system_agent_mode = raw_mode.strip().lower()
task_id = await self._queue.enqueue(
command={
"command": "run",
"owner_id": str(current_user.id),
"system_agent_mode": system_agent_mode,
"run_input": run_input.model_dump(
mode="json", by_alias=True, exclude_none=True
),
@@ -193,6 +212,61 @@ class AgentService:
created=created,
)
async def _resolve_user_message_visibility_mask(self, *, agent_type: str) -> int:
normalized_agent_type = agent_type.strip().lower()
history_bit_mask = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
if normalized_agent_type == "memory":
return bit_mask(bit=18)
agent_config = await self._repository.get_system_agent_config(
agent_type=normalized_agent_type
)
if agent_config is None:
raise HTTPException(
status_code=422, detail="invalid forwarded_props.agent_type"
)
llm_config = SystemAgentLLMConfig.model_validate(
(agent_config.get("config") if isinstance(agent_config, dict) else {}) or {}
)
agent_mask = bit_mask(bit=llm_config.visibility_consumer_bit)
if normalized_agent_type == "worker":
router_config = await self._repository.get_system_agent_config(
agent_type="router"
)
worker_config = await self._repository.get_system_agent_config(
agent_type="worker"
)
if router_config is None or worker_config is None:
raise HTTPException(
status_code=500,
detail="system agent visibility config missing",
)
router_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
router_config.get("config")
if isinstance(router_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
worker_mask = bit_mask(
bit=SystemAgentLLMConfig.model_validate(
(
worker_config.get("config")
if isinstance(worker_config, dict)
else {}
)
or {}
).visibility_consumer_bit
)
return history_bit_mask | router_mask | worker_mask
return history_bit_mask | agent_mask
async def _prepare_user_message(
self,
*,
@@ -408,6 +482,7 @@ class AgentService:
day_payload = await self._repository.get_history_day(
session_id=thread_id,
before=before,
visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)),
)
messages: list[HistoryMessage] = []
+127 -199
View File
@@ -2,28 +2,22 @@ from __future__ import annotations
import asyncio
import time
from collections.abc import Mapping
from typing import Any, cast
from urllib.parse import urlparse
from pydantic import ValidationError
from fastapi import HTTPException
from supabase import AuthError
from core.config.settings import config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByEmailResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
UserByPhoneResponse,
)
from v1.auth.service import AuthServiceGateway
@@ -36,7 +30,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
def __init__(self) -> None:
self._user_lookup_cache_ttl_seconds: int = 60
self._user_lookup_cache_expires_at: float = 0.0
self._users_by_email: dict[str, Any] = {}
self._users_by_phone: dict[str, Any] = {}
def _get_client(self) -> Any:
return supabase_service.get_client()
@@ -44,47 +38,31 @@ class SupabaseAuthGateway(AuthServiceGateway):
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
async def send_otp(self, request: OtpSendRequest) -> None:
client = self._get_client()
metadata: dict[str, Any] = {"username": request.username}
if request.invite_code:
metadata["invite_code"] = request.invite_code
payload: dict[str, Any] = {
"email": request.email,
"password": request.password,
"data": metadata,
"phone": request.phone,
"options": {"should_create_user": True},
}
if request.redirect_to:
payload["options"] = {
"email_redirect_to": _validate_redirect_url(request.redirect_to)
}
try:
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
return VerificationCreateResponse(email=request.email)
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
await asyncio.to_thread(sign_in_with_otp, payload)
except AuthError as exc:
logger.warning("Signup failed", error_type=type(exc).__name__)
logger.warning("Send otp failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise HTTPException(
status_code=422, detail="Invalid signup request"
) from exc
raise HTTPException(status_code=429, detail="Too many requests") from exc
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
if request.type != "signup":
raise HTTPException(status_code=422, detail="Invalid request")
client = self._get_client()
payload: dict[str, Any] = {
"type": request.type,
"email": request.email,
"type": "sms",
"phone": request.phone,
"token": request.token,
}
try:
@@ -92,7 +70,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(response, "Invalid verification code")
except AuthError as exc:
logger.warning("Signup verify failed", error_type=type(exc).__name__)
logger.warning("Create phone session failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
@@ -102,45 +80,6 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid verification code"
) from exc
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
if request.type == "recovery":
await self.request_password_reset(
PasswordResetRequest(
email=request.email,
redirect_to=request.redirect_to,
)
)
return
payload: dict[str, Any] = {"type": request.type, "email": request.email}
try:
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
except AuthError as exc:
logger.warning("Signup resend failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
logger.warning("Login failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
@@ -189,98 +128,84 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid refresh token"
) from exc
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
admin_client = self._get_admin_client()
normalized_email = email.lower()
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
normalized_phone = _normalize_phone(phone)
if not normalized_phone:
raise HTTPException(status_code=404, detail="User not found")
now = time.monotonic()
if now >= self._user_lookup_cache_expires_at:
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_email: dict[str, Any] = {}
for candidate in users:
candidate_email = str(getattr(candidate, "email", "")).lower()
if candidate_email:
users_by_email[candidate_email] = candidate
self._users_by_email = users_by_email
self._user_lookup_cache_expires_at = (
now + self._user_lookup_cache_ttl_seconds
)
await self._refresh_user_lookup_cache_if_needed()
user = self._users_by_email.get(normalized_email)
user = self._users_by_phone.get(normalized_phone)
if user is None:
raise HTTPException(status_code=404, detail="User not found")
return UserByEmailResponse(
user_phone = _normalize_phone(getattr(user, "phone", ""))
if not user_phone:
raise HTTPException(status_code=404, detail="User not found")
return UserByPhoneResponse(
id=str(getattr(user, "id", "")),
email=str(getattr(user, "email", "")),
phone=user_phone,
created_at=str(getattr(user, "created_at", "")),
email_confirmed_at=(
str(getattr(user, "email_confirmed_at", ""))
if getattr(user, "email_confirmed_at", None)
phone_confirmed_at=(
str(getattr(user, "phone_confirmed_at", ""))
if getattr(user, "phone_confirmed_at", None)
else None
),
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
client = self._get_client()
try:
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {
"redirect_to": _validate_redirect_url(request.redirect_to)
}
await asyncio.to_thread(reset_email, email, options=options)
else:
await asyncio.to_thread(reset_email, email)
except AuthError as exc:
logger.warning(
"Password reset request failed",
error_type=type(exc).__name__,
)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
normalized_query = _normalize_phone_search_query(query)
if not normalized_query:
return []
await self._refresh_user_lookup_cache_if_needed()
if normalized_query.startswith("+"):
matched_user = self._users_by_phone.get(normalized_query)
if matched_user is None:
return []
user_id = str(getattr(matched_user, "id", ""))
return [user_id] if user_id else []
digits = _digits_only(normalized_query)
if not digits:
return []
matched_records: list[tuple[str, str]] = []
for cached_phone, candidate in self._users_by_phone.items():
candidate_digits = _digits_only(cached_phone)
if not candidate_digits.endswith(digits):
continue
user_id = str(getattr(candidate, "id", ""))
if user_id:
matched_records.append((cached_phone, user_id))
if not matched_records:
return []
unique_ids: list[str] = []
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
if user_id in unique_ids:
continue
unique_ids.append(user_id)
if len(unique_ids) >= max(1, limit):
break
return unique_ids
async def _refresh_user_lookup_cache_if_needed(self) -> None:
now = time.monotonic()
if now < self._user_lookup_cache_expires_at:
return
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
client = self._get_client()
admin_client = self._get_admin_client()
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
user_id = str(getattr(user, "id", "")) if user is not None else ""
if session is None or not user_id:
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)
except AuthError as exc:
logger.warning(
"Password reset confirm failed", error_type=type(exc).__name__
)
if _is_auth_upstream_unavailable(exc):
raise HTTPException(
status_code=503,
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise HTTPException(
status_code=401, detail="Invalid or expired verification code"
) from exc
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_phone: dict[str, Any] = {}
for candidate in users:
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
if candidate_phone:
users_by_phone[candidate_phone] = candidate
self._users_by_phone = users_by_phone
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
@@ -312,55 +237,24 @@ def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
return any(token in code or token in message for token in indicators)
def _validate_redirect_url(url: str) -> str:
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
origin = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
allowed_origins = {
_normalize_origin(candidate)
for candidate in config.cors.allow_origins
if _is_http_origin(candidate)
}
if origin not in allowed_origins:
raise HTTPException(status_code=422, detail="Invalid redirect URL")
return url
def _normalize_origin(value: str) -> str:
parsed = urlparse(value)
return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
def _is_http_origin(value: str) -> bool:
parsed = urlparse(value)
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
def _coerce_reset_email(value: object) -> str:
if isinstance(value, str):
return value
if isinstance(value, Mapping):
nested = value.get("email") or value.get("value")
if isinstance(nested, str):
return nested
raise HTTPException(status_code=422, detail="Invalid email")
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise HTTPException(status_code=401, detail=failure_message)
email = getattr(user, "email", None)
if not email:
phone = _normalize_phone(getattr(user, "phone", None))
if not phone:
raise HTTPException(status_code=401, detail=failure_message)
auth_user = AuthUser(id=str(user.id), email=str(email))
try:
auth_user = AuthUser(id=str(user.id), phone=str(phone))
except ValidationError as exc:
logger.warning(
"Auth response returned invalid phone format",
error_type=type(exc).__name__,
)
raise HTTPException(status_code=401, detail=failure_message) from exc
return SessionResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
@@ -389,3 +283,37 @@ def _list_auth_users(client: Any) -> list[Any]:
page += 1
return users
def _normalize_phone(raw_phone: object) -> str | None:
phone = str(raw_phone).strip()
for separator in (" ", "-", "(", ")"):
phone = phone.replace(separator, "")
if not phone:
return None
if phone.startswith("00") and len(phone) > 2:
return f"+{phone[2:]}"
if phone.startswith("+"):
return phone
if phone.isdigit():
return f"+{phone}"
return None
def _normalize_phone_search_query(raw_query: str) -> str | None:
query = raw_query.strip()
for separator in (" ", "-", "(", ")"):
query = query.replace(separator, "")
if not query:
return None
if query.startswith("00") and len(query) > 2:
return f"+{query[2:]}"
if query.startswith("+"):
return query
if query.isdigit():
return query
return None
def _digits_only(value: str) -> str:
return "".join(ch for ch in value if ch.isdigit())
+52 -77
View File
@@ -1,20 +1,16 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request, Response
from fastapi import HTTPException
from core.config.settings import config
from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service
from v1.auth.schemas import (
PasswordResetConfirmRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
SessionResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
from v1.auth.service import AuthService
@@ -22,80 +18,49 @@ from v1.auth.service import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post(
"/verifications", response_model=VerificationCreateResponse, status_code=202
)
async def create_verification(
payload: VerificationCreateRequest,
service: AuthService = Depends(get_auth_service),
) -> VerificationCreateResponse:
await enforce_rate_limit(
scope="signup_start",
identifier=payload.email,
limit=5,
window_seconds=60,
)
return await service.create_verification(payload)
@router.post("/verify", response_model=SessionResponse)
async def verify(
payload: VerificationVerifyRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse | Response:
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
limit = 10
window_seconds = 600
await enforce_rate_limit(
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=limit,
window_seconds=window_seconds,
)
if payload.type == "signup":
return await service.verify_verification(payload)
if payload.new_password is None:
raise HTTPException(status_code=422, detail="Invalid request")
await service.confirm_password_reset(
PasswordResetConfirmRequest(
email=payload.email,
token=payload.token,
new_password=payload.new_password,
)
)
return Response(status_code=204)
@router.post("/resend", status_code=204)
async def resend(
payload: VerificationResendRequest,
@router.post("/otp/send", status_code=204)
async def send_otp(
payload: OtpSendRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
client_ip = _client_ip(request)
await enforce_rate_limit(
scope=scope,
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
limit=5,
scope="otp_send_phone",
identifier=payload.phone,
limit=3,
window_seconds=60,
)
await service.resend_verification(payload)
await enforce_rate_limit(
scope="otp_send_ip",
identifier=client_ip,
limit=20,
window_seconds=60,
)
await service.send_otp(payload)
return Response(status_code=204)
@router.post("/sessions", response_model=SessionResponse)
async def create_session(
payload: SessionCreateRequest,
@router.post("/phone-session", response_model=SessionResponse)
async def create_phone_session(
payload: PhoneSessionCreateRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="login",
identifier=payload.email,
limit=10,
window_seconds=60,
scope="phone_session_phone",
identifier=payload.phone,
limit=6,
window_seconds=300,
)
return await service.create_session(payload)
await enforce_rate_limit(
scope="phone_session_ip",
identifier=client_ip,
limit=20,
window_seconds=300,
)
return await service.create_phone_session(payload)
@router.post("/sessions/refresh", response_model=SessionResponse)
@@ -130,13 +95,23 @@ async def delete_session(
def _client_ip(request: Request) -> str:
forwarded_for = request.headers.get("x-forwarded-for", "")
if forwarded_for:
first = forwarded_for.split(",")[0].strip()
if first:
return first
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
host = request.client.host if request.client else ""
return host or "unknown"
if not host:
return "unknown"
if _should_trust_proxy_headers(host):
forwarded_for = request.headers.get("x-forwarded-for", "")
if forwarded_for:
first = forwarded_for.split(",")[0].strip()
if first:
return first
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
return host
def _should_trust_proxy_headers(host: str) -> bool:
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
return host in trusted_proxies
+13 -51
View File
@@ -1,49 +1,22 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
OtpType = Literal["signup", "recovery"]
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
class VerificationCreateRequest(BaseModel):
class OtpSendRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
username: str = Field(min_length=3, max_length=30)
email: EmailStr
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
redirect_to: str | None = None
invite_code: str | None = None
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class VerificationResendRequest(BaseModel):
email: EmailStr
type: OtpType = "signup"
redirect_to: str | None = None
class PhoneSessionCreateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
class VerificationVerifyRequest(BaseModel):
type: OtpType = "signup"
email: EmailStr
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
token: str = Field(pattern=r"^\d{6}$")
new_password: str | None = Field(
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
)
@model_validator(mode="after")
def validate_type_payload(self) -> "VerificationVerifyRequest":
if self.type == "recovery" and self.new_password is None:
raise ValueError("new_password is required when type is recovery")
if self.type == "signup" and self.new_password is not None:
raise ValueError("new_password is only allowed when type is recovery")
return self
class SessionCreateRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
class SessionRefreshRequest(BaseModel):
@@ -56,7 +29,7 @@ class SessionDeleteRequest(BaseModel):
class AuthUser(BaseModel):
id: str
email: EmailStr
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class SessionResponse(BaseModel):
@@ -67,23 +40,12 @@ class SessionResponse(BaseModel):
user: AuthUser
class UserByEmailResponse(BaseModel):
class UserByPhoneResponse(BaseModel):
id: str
email: EmailStr
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
created_at: str
email_confirmed_at: str | None = None
phone_confirmed_at: str | None = None
class VerificationCreateResponse(BaseModel):
email: EmailStr
class PasswordResetRequest(BaseModel):
email: EmailStr
redirect_to: str | None = None
class PasswordResetConfirmRequest(BaseModel):
email: EmailStr
token: str = Field(pattern=r"^\d{6}$")
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
+10 -66
View File
@@ -1,52 +1,30 @@
from __future__ import annotations
import re
from typing import Protocol
from v1.auth.schemas import (
PasswordResetConfirmRequest,
PasswordResetRequest,
SessionCreateRequest,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
VerificationCreateRequest,
VerificationCreateResponse,
VerificationResendRequest,
VerificationVerifyRequest,
)
class AuthServiceGateway(Protocol):
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
async def send_otp(self, request: OtpSendRequest) -> None:
raise NotImplementedError
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
raise NotImplementedError
async def resend_verification(self, request: VerificationResendRequest) -> None:
raise NotImplementedError
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
raise NotImplementedError
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
raise NotImplementedError
async def delete_session(self, refresh_token: str | None) -> None:
raise NotImplementedError
async def request_password_reset(self, request: PasswordResetRequest) -> None:
raise NotImplementedError
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
raise NotImplementedError
class AuthService:
_gateway: AuthServiceGateway
@@ -54,50 +32,16 @@ class AuthService:
def __init__(self, gateway: AuthServiceGateway) -> None:
self._gateway = gateway
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
normalized_invite_code = _normalize_invite_code(request.invite_code)
normalized_request = request.model_copy(
update={"invite_code": normalized_invite_code}
)
return await self._gateway.create_verification(normalized_request)
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
async def verify_verification(
self, request: VerificationVerifyRequest
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
return await self._gateway.verify_verification(request)
async def resend_verification(self, request: VerificationResendRequest) -> None:
await self._gateway.resend_verification(request)
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
return await self._gateway.create_session(request)
return await self._gateway.create_phone_session(request)
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
await self._gateway.request_password_reset(request)
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
await self._gateway.confirm_password_reset(request)
_INVITE_CODE_PATTERN = re.compile(r"^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{4}$")
def _normalize_invite_code(value: str | None) -> str | None:
if value is None:
return None
normalized = value.strip().upper()
if not normalized:
return None
return normalized if _INVITE_CODE_PATTERN.fullmatch(normalized) else None
+6 -2
View File
@@ -5,7 +5,7 @@ from typing import ClassVar
from uuid import UUID
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from schemas.inbox.messages import (
CalendarContent,
@@ -154,7 +154,11 @@ _PERMISSION_EDIT = 4
class ScheduleItemShareRequest(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
email: EmailStr = Field(..., description="Email of user to share with")
phone: str = Field(
...,
pattern=r"^\+861[3-9]\d{9}$",
description="Phone of user to share with",
)
permission_view: bool = Field(True, description="Grant view permission")
permission_edit: bool = Field(False, description="Grant edit permission")
permission_invite: bool = Field(False, description="Grant invite permission")
+7 -7
View File
@@ -31,19 +31,19 @@ from v1.schedule_items.schemas import (
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from v1.auth.schemas import UserByEmailResponse
from v1.auth.schemas import UserByPhoneResponse
logger = get_logger("v1.schedule_items.service")
class AuthByEmailGateway(Protocol):
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
class AuthByPhoneGateway(Protocol):
async def get_user_by_phone(self, phone: str) -> "UserByPhoneResponse": ...
class ScheduleItemService(BaseService):
_repository: ScheduleItemRepository
_session: AsyncSession
_auth_gateway: AuthByEmailGateway
_auth_gateway: AuthByPhoneGateway
_inbox_repository: InboxMessageRepository
def __init__(
@@ -51,7 +51,7 @@ class ScheduleItemService(BaseService):
repository: ScheduleItemRepository,
session: AsyncSession,
current_user: CurrentUser | None,
auth_gateway: AuthByEmailGateway | None = None,
auth_gateway: AuthByPhoneGateway | None = None,
inbox_repository: InboxMessageRepository | None = None,
) -> None:
super().__init__(current_user=current_user)
@@ -329,7 +329,7 @@ class ScheduleItemService(BaseService):
detail=f"You can only share with permissions up to {inviter_permission}",
)
target_user = await self._auth_gateway.get_user_by_email(request.email)
target_user = await self._auth_gateway.get_user_by_phone(request.phone)
recipient_id = UUID(target_user.id)
existing = await self._repository.get_subscription(item_id, recipient_id)
@@ -404,7 +404,7 @@ class ScheduleItemService(BaseService):
except ValueError:
await self._session.rollback()
logger.exception(
"Auth lookup returned invalid user id", email=request.email
"Auth lookup returned invalid user id", phone=request.phone
)
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
+4 -4
View File
@@ -76,11 +76,11 @@ async def _verify_user_with_supabase(token: str) -> CurrentUser | None:
parsed_id = UUID(user_id)
except ValueError:
return None
email = getattr(user, "email", None)
phone = getattr(user, "phone", None)
role = getattr(user, "role", None)
return CurrentUser(
id=parsed_id,
email=email if isinstance(email, str) else None,
phone=phone if isinstance(phone, str) else None,
role=role if isinstance(role, str) else None,
)
@@ -125,9 +125,9 @@ async def get_current_user(
raise HTTPException(status_code=401, detail="Unauthorized")
logger.debug("JWT validation successful", user_id=str(user_id))
email = payload.get("email") if isinstance(payload.get("email"), str) else None
phone = payload.get("phone") if isinstance(payload.get("phone"), str) else None
role = payload.get("role") if isinstance(payload.get("role"), str) else None
return CurrentUser(id=user_id, email=email, role=role)
return CurrentUser(id=user_id, phone=phone, role=role)
async def get_user_repository(
+1 -1
View File
@@ -38,7 +38,7 @@ class UserRepository(Protocol):
...
async def search_users(self, query: str, limit: int = 20) -> list[Profile]:
"""Search users by username (ilike) or email (exact match)."""
"""Search users by username (ilike) or phone (exact match)."""
...
+63 -31
View File
@@ -21,19 +21,22 @@ if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from schemas.user.context import UserContext
from v1.auth.schemas import UserByEmailResponse
logger = get_logger("v1.users.service")
_EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
_PHONE_QUERY_PATTERN = re.compile(r"^[+()\-\s\d]{4,32}$")
class AuthLookupGateway(Protocol):
async def get_user_id_by_email(self, email: str) -> str | None: ...
async def search_user_ids_by_phone(
self, query: str, limit: int = 20
) -> list[str]: ...
class AuthByEmailGateway(Protocol):
async def get_user_by_email(self, email: str) -> "UserByEmailResponse": ...
class AuthByPhoneGateway(Protocol):
async def search_user_ids_by_phone(
self, query: str, limit: int = 20
) -> list[str]: ...
class UserContextInvalidator(Protocol):
@@ -41,15 +44,14 @@ class UserContextInvalidator(Protocol):
class AuthLookupAdapter:
def __init__(self, gateway: AuthByEmailGateway) -> None:
def __init__(self, gateway: AuthByPhoneGateway) -> None:
self._gateway = gateway
async def get_user_id_by_email(self, email: str) -> str | None:
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
try:
response = await self._gateway.get_user_by_email(email)
return response.id
return await self._gateway.search_user_ids_by_phone(query, limit=limit)
except HTTPException:
return None
return []
class UserService(BaseService):
@@ -92,11 +94,11 @@ class UserService(BaseService):
if user is None:
raise HTTPException(status_code=404, detail="User not found")
email = self._current_user.email if self._current_user else None
phone = self._current_user.phone if self._current_user else None
return UserContext(
id=str(user.id),
username=user.username,
email=email,
phone=phone,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
@@ -152,11 +154,11 @@ class UserService(BaseService):
error=str(exc),
)
email = self._current_user.email if self._current_user else None
phone = self._current_user.phone if self._current_user else None
return UserContext(
id=str(user.id),
username=user.username,
email=email,
phone=phone,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
@@ -181,36 +183,59 @@ class UserService(BaseService):
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
query = request.query.strip()
if _EMAIL_PATTERN.match(query):
return await self._search_by_email(query)
if _looks_like_phone_query(query):
phone_results = await self._search_by_phone(query)
if not query.isdigit():
return phone_results
username_results = await self._search_by_username(query)
if not phone_results:
return username_results
merged_by_id = {result.id: result for result in phone_results}
for result in username_results:
merged_by_id.setdefault(result.id, result)
return list(merged_by_id.values())
return await self._search_by_username(query)
async def _search_by_email(self, email: str) -> list[UserContext]:
async def _search_by_phone(self, phone: str) -> list[UserContext]:
if self._auth_gateway is None:
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
user_id_str = await self._auth_gateway.get_user_id_by_email(email)
if user_id_str is None:
user_id_values = await self._auth_gateway.search_user_ids_by_phone(
phone, limit=20
)
if not user_id_values:
return []
user_ids: list[UUID] = []
for raw_id in user_id_values:
try:
user_ids.append(UUID(raw_id))
except ValueError:
continue
if not user_ids:
return []
try:
user = await self._repository.get_by_user_id(UUID(user_id_str))
users_by_id = await self._repository.get_by_user_ids(user_ids)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="User store unavailable")
if user is None:
return []
return [
UserContext(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
results: list[UserContext] = []
for user_id in user_ids:
user = users_by_id.get(user_id)
if user is None:
continue
results.append(
UserContext(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
)
)
]
return results
async def _search_by_username(self, query: str) -> list[UserContext]:
try:
@@ -228,3 +253,10 @@ class UserService(BaseService):
)
for user in users
]
def _looks_like_phone_query(query: str) -> bool:
if not _PHONE_QUERY_PATTERN.fullmatch(query):
return False
digits_count = sum(char.isdigit() for char in query)
return digits_count >= 4