chore: 迁移到 social-app 架构,集成 Supabase 和 taskiq worker

This commit is contained in:
qzl
2026-04-02 16:36:35 +08:00
parent 695adb7d6f
commit 92cdfd9fca
132 changed files with 5802 additions and 759 deletions
View File
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,35 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
import re
from typing import Any
import yaml
from schemas.domain.automation import AutomationJobConfig
_CONFIG_NAME_PATTERN = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
def _automation_yaml_path(config_name: str) -> Path:
if not _CONFIG_NAME_PATTERN.fullmatch(config_name):
raise ValueError("invalid automation config name")
return (
Path(__file__).resolve().parents[2]
/ "core"
/ "config"
/ "static"
/ "automation"
/ f"{config_name}.yaml"
)
@lru_cache(maxsize=16)
def load_static_automation_job_config(*, config_name: str) -> AutomationJobConfig:
path = _automation_yaml_path(config_name)
with path.open("r", encoding="utf-8") as file:
loaded: Any = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"invalid automation config format: {path}")
return AutomationJobConfig.model_validate(loaded)
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.db import get_db
from v1.auth.gateway import SupabaseAuthGateway
from v1.auth.registration_bootstrap import (
RegistrationAutomationBootstrapService,
RegistrationBootstrapRepository,
)
from v1.auth.service import AuthService
def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
bootstrapper = RegistrationAutomationBootstrapService(
repository=RegistrationBootstrapRepository(session=session),
session=session,
)
return AuthService(
gateway=SupabaseAuthGateway(),
registration_bootstrapper=bootstrapper,
)
+430
View File
@@ -0,0 +1,430 @@
from __future__ import annotations
import asyncio
import time
from typing import Any, cast
from pydantic import ValidationError
from supabase import AuthError
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
UserByIdResponse,
UserByPhoneResponse,
)
from v1.auth.service import AuthServiceGateway
logger = get_logger("v1.auth.gateway")
AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"
def _auth_error(
*,
status_code: int,
code: str,
detail: str,
) -> ApiProblemError:
return ApiProblemError(status_code=status_code, code=code, detail=detail)
class SupabaseAuthGateway(AuthServiceGateway):
def __init__(self) -> None:
self._user_lookup_cache_ttl_seconds: int = 60
self._user_lookup_cache_expires_at: float = 0.0
self._users_by_phone: dict[str, Any] = {}
self._users_by_id: dict[str, Any] = {}
def _get_client(self) -> Any:
return supabase_service.get_client()
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def send_otp(self, request: OtpSendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {
"phone": request.phone,
"options": {"should_create_user": True},
}
try:
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
await asyncio.to_thread(sign_in_with_otp, payload)
except AuthError as exc:
logger.warning("Send otp failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=429,
code="AUTH_TOO_MANY_REQUESTS",
detail="Too many requests",
) from exc
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {
"type": "sms",
"phone": request.phone,
"token": request.token,
}
try:
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(
response,
"Invalid verification code",
"AUTH_VERIFICATION_CODE_INVALID",
)
except AuthError as exc:
logger.warning("Create phone session failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_VERIFICATION_CODE_INVALID",
detail="Invalid verification code",
) from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
response = await asyncio.to_thread(
client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(
response,
"Invalid refresh token",
"AUTH_REFRESH_TOKEN_INVALID",
)
except AuthError as exc:
logger.warning("Refresh failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
) from exc
async def delete_session(self, refresh_token: str | None) -> None:
if not refresh_token:
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_MISSING",
detail="Missing refresh token",
)
client = self._get_client()
try:
response = await asyncio.to_thread(
client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
)
await asyncio.to_thread(
client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
if _is_auth_upstream_unavailable(exc):
raise _auth_error(
status_code=503,
code="AUTH_SERVICE_UNAVAILABLE",
detail=AUTH_UNAVAILABLE_DETAIL,
) from exc
raise _auth_error(
status_code=401,
code="AUTH_REFRESH_TOKEN_INVALID",
detail="Invalid refresh token",
) from exc
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
normalized_phone = _normalize_phone(phone)
if not normalized_phone:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
await self._refresh_user_lookup_cache_if_needed()
user = self._users_by_phone.get(normalized_phone)
if user is None:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
user_phone = _normalize_phone(getattr(user, "phone", ""))
if not user_phone:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
return UserByPhoneResponse(
id=str(getattr(user, "id", "")),
phone=user_phone,
created_at=str(getattr(user, "created_at", "")),
phone_confirmed_at=(
str(getattr(user, "phone_confirmed_at", ""))
if getattr(user, "phone_confirmed_at", None)
else None
),
)
async def get_user_by_id(self, user_id: str) -> UserByIdResponse:
users = await self.get_users_by_ids([user_id])
resolved = users.get(user_id)
if resolved is None:
raise _auth_error(
status_code=404,
code="AUTH_USER_NOT_FOUND",
detail="User not found",
)
return resolved
async def get_users_by_ids(
self, user_ids: list[str]
) -> dict[str, UserByIdResponse]:
await self._refresh_user_lookup_cache_if_needed()
resolved: dict[str, UserByIdResponse] = {}
for raw_user_id in user_ids:
normalized_user_id = raw_user_id.strip()
if not normalized_user_id:
continue
user = self._users_by_id.get(normalized_user_id)
if user is None:
continue
user_attrs = getattr(user, "user", user)
resolved[normalized_user_id] = UserByIdResponse(
id=str(getattr(user_attrs, "id", "")),
phone=getattr(user_attrs, "phone", None),
created_at=str(getattr(user_attrs, "created_at", "")),
phone_confirmed_at=(
str(getattr(user_attrs, "phone_confirmed_at", ""))
if getattr(user_attrs, "phone_confirmed_at", None)
else None
),
)
return resolved
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
normalized_query = _normalize_phone_search_query(query)
if not normalized_query:
return []
await self._refresh_user_lookup_cache_if_needed()
if normalized_query.startswith("+"):
matched_user = self._users_by_phone.get(normalized_query)
if matched_user is None:
return []
user_id = str(getattr(matched_user, "id", ""))
return [user_id] if user_id else []
digits = _digits_only(normalized_query)
if not digits:
return []
matched_records: list[tuple[str, str]] = []
for cached_phone, candidate in self._users_by_phone.items():
candidate_digits = _digits_only(cached_phone)
if not candidate_digits.endswith(digits):
continue
user_id = str(getattr(candidate, "id", ""))
if user_id:
matched_records.append((cached_phone, user_id))
if not matched_records:
return []
unique_ids: list[str] = []
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
if user_id in unique_ids:
continue
unique_ids.append(user_id)
if len(unique_ids) >= max(1, limit):
break
return unique_ids
async def _refresh_user_lookup_cache_if_needed(self) -> None:
now = time.monotonic()
if now < self._user_lookup_cache_expires_at:
return
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_phone: dict[str, Any] = {}
users_by_id: dict[str, Any] = {}
for candidate in users:
candidate_id = str(getattr(candidate, "id", "")).strip()
if candidate_id:
users_by_id[candidate_id] = candidate
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
if candidate_phone:
users_by_phone[candidate_phone] = candidate
self._users_by_id = users_by_id
self._users_by_phone = users_by_phone
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
raw_status = getattr(exc, "status", None)
if raw_status is None:
raw_status = getattr(exc, "status_code", None)
if isinstance(raw_status, int) and 500 <= raw_status < 600:
return True
raw_code = getattr(exc, "code", None)
code = str(raw_code).lower() if raw_code is not None else ""
message = str(exc).lower()
indicators = (
"request_timeout",
"timed out",
"timeout",
"gateway timeout",
"bad_gateway",
"service_unavailable",
"internal_server_error",
"unexpected_failure",
"upstream",
"500",
"502",
"503",
"504",
"5xx",
)
return any(token in code or token in message for token in indicators)
def _map_auth_response(
response: object, failure_message: str, failure_code: str
) -> SessionResponse:
session = getattr(response, "session", None)
user = getattr(response, "user", None)
if session is None or user is None:
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
)
phone = _normalize_phone(getattr(user, "phone", None))
if not phone:
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
)
try:
auth_user = AuthUser(id=str(user.id), phone=str(phone))
except ValidationError as exc:
logger.warning(
"Auth response returned invalid phone format",
error_type=type(exc).__name__,
)
raise _auth_error(
status_code=401,
code=failure_code,
detail=failure_message,
) from exc
return SessionResponse(
access_token=str(session.access_token),
refresh_token=str(session.refresh_token),
expires_in=int(session.expires_in or 0),
token_type=str(session.token_type),
user=auth_user,
)
def _list_auth_users(client: Any) -> list[Any]:
users: list[Any] = []
page = 1
max_pages = 100
while page <= max_pages:
response = client.auth.admin.list_users(page=page, per_page=100)
batch = (
list(response)
if isinstance(response, list)
else list(getattr(response, "users", []))
)
users.extend(batch)
if len(batch) < 100:
break
page += 1
return users
def _sanitize_phone_token(raw: object) -> str:
token = str(raw).strip()
for separator in (" ", "-", "(", ")"):
token = token.replace(separator, "")
return token
def _normalize_phone(raw_phone: object) -> str | None:
phone = _sanitize_phone_token(raw_phone)
if not phone:
return None
if phone.startswith("00") and len(phone) > 2:
return f"+{phone[2:]}"
if phone.startswith("+"):
return phone
if phone.isdigit():
return f"+{phone}"
return None
def _normalize_phone_search_query(raw_query: str) -> str | None:
query = _sanitize_phone_token(raw_query)
if not query:
return None
if query.startswith("00") and len(query) > 2:
return f"+{query[2:]}"
if query.startswith("+"):
return query
if query.isdigit():
return query
return None
def _digits_only(value: str) -> str:
return "".join(ch for ch in value if ch.isdigit())
+114
View File
@@ -0,0 +1,114 @@
from __future__ import annotations
import asyncio
from collections import deque
from time import monotonic
from core.http.errors import ApiProblemError
from core.logging import get_logger
from services.base.redis import get_or_init_redis_client
_BUCKETS: dict[str, deque[float]] = {}
_LAST_SEEN: dict[str, float] = {}
_LOCK = asyncio.Lock()
_CLEANUP_INTERVAL = 200
_CALL_COUNT = 0
logger = get_logger("v1.auth.rate_limit")
_REDIS_LIMIT_SCRIPT = """
local current = redis.call("INCR", KEYS[1])
if current == 1 then
redis.call("EXPIRE", KEYS[1], ARGV[1])
end
return current
"""
async def enforce_rate_limit(
*,
scope: str,
identifier: str,
limit: int,
window_seconds: int,
) -> None:
key = f"auth:rate_limit:{scope}:{identifier.lower()}"
try:
await _enforce_rate_limit_with_redis(
key=key,
limit=limit,
window_seconds=window_seconds,
)
return
except ApiProblemError:
raise
except Exception as exc: # noqa: BLE001
logger.warning(
"Rate limit fallback to in-memory",
scope=scope,
error_type=type(exc).__name__,
)
await _enforce_rate_limit_in_memory(
key=key,
limit=limit,
window_seconds=window_seconds,
)
async def _enforce_rate_limit_with_redis(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
client = await get_or_init_redis_client()
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds) # type: ignore[await]
if int(current) > limit:
raise ApiProblemError(
status_code=429,
code="AUTH_TOO_MANY_REQUESTS",
detail="Too many requests",
)
async def _enforce_rate_limit_in_memory(
*,
key: str,
limit: int,
window_seconds: int,
) -> None:
global _CALL_COUNT
now = monotonic()
async with _LOCK:
bucket = _BUCKETS.setdefault(key, deque())
_LAST_SEEN[key] = now
cutoff = now - float(window_seconds)
while bucket and bucket[0] <= cutoff:
bucket.popleft()
if len(bucket) >= limit:
raise ApiProblemError(
status_code=429,
code="AUTH_TOO_MANY_REQUESTS",
detail="Too many requests",
)
bucket.append(now)
_CALL_COUNT += 1
if _CALL_COUNT % _CLEANUP_INTERVAL == 0:
_cleanup_stale_buckets(now)
def _cleanup_stale_buckets(now: float) -> None:
stale_keys = [
key
for key, last_seen in _LAST_SEEN.items()
if key not in _BUCKETS or (not _BUCKETS[key] and now - last_seen > 3600)
]
for key in stale_keys:
_BUCKETS.pop(key, None)
_LAST_SEEN.pop(key, None)
def reset_rate_limit_state() -> None:
_BUCKETS.clear()
_LAST_SEEN.clear()
global _CALL_COUNT
_CALL_COUNT = 0
@@ -0,0 +1,239 @@
from __future__ import annotations
from datetime import UTC, datetime, time, timedelta
from typing import Protocol
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from core.logging import get_logger
from models.automation_jobs import AutomationJob
from schemas.enums import AutomationJobStatus, MemoryType, ScheduleType
from models.profile import Profile
from schemas.domain.automation import AutomationJobConfig, ScheduleConfig
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from schemas.shared.user import parse_profile_settings
from v1.auth.automation_static_config import load_static_automation_job_config
from v1.auth.schemas import RegistrationBootstrapRequest
from v1.memories.repository import SQLAlchemyMemoriesRepository
logger = get_logger("v1.auth.registration_bootstrap")
class RegistrationBootstrapRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
self._memories_repository = SQLAlchemyMemoriesRepository(session)
async def get_profile_timezone(self, *, user_id: UUID) -> str:
stmt = select(Profile.settings).where(Profile.id == user_id)
settings = (await self._session.execute(stmt)).scalar_one_or_none()
parsed = parse_profile_settings(
settings if isinstance(settings, dict) else None
)
return parsed.preferences.timezone
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
next_run_at: datetime,
) -> bool:
stmt = (
insert(AutomationJob)
.values(
id=uuid4(),
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=title,
config=config.model_dump(mode="json"),
next_run_at=next_run_at,
timezone=timezone_name,
status=AutomationJobStatus.ACTIVE,
created_by=owner_id,
)
.on_conflict_do_nothing(
index_elements=["owner_id", "bootstrap_key"],
index_where=AutomationJob.deleted_at.is_(None)
& AutomationJob.bootstrap_key.is_not(None),
)
.returning(AutomationJob.id)
)
inserted_id = (await self._session.execute(stmt)).scalar_one_or_none()
await self._session.flush()
return inserted_id is not None
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool:
return await self._memories_repository.create_if_absent(
owner_id=owner_id,
memory_type=memory_type,
content=content,
)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None: ...
class RegistrationBootstrapRepositoryLike(Protocol):
async def get_profile_timezone(self, *, user_id: UUID) -> str: ...
async def insert_bootstrap_automation_job_if_absent(
self,
*,
owner_id: UUID,
bootstrap_key: str,
title: str,
config: AutomationJobConfig,
timezone_name: str,
next_run_at: datetime,
) -> bool: ...
async def upsert_initial_memory(
self,
*,
owner_id: UUID,
memory_type: MemoryType,
content: dict,
) -> bool: ...
class SessionLike(Protocol):
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
def compute_first_run_at_utc(
*,
now_utc: datetime,
timezone_name: str,
schedule: ScheduleConfig,
) -> datetime:
try:
timezone_obj = ZoneInfo(timezone_name)
except ZoneInfoNotFoundError:
timezone_obj = ZoneInfo("UTC")
local_now = now_utc.astimezone(timezone_obj)
run_clock = time(
hour=schedule.run_at.hour,
minute=schedule.run_at.minute,
tzinfo=timezone_obj,
)
if schedule.type == ScheduleType.DAILY:
candidate_local = datetime.combine(local_now.date(), run_clock)
if candidate_local <= local_now:
candidate_local = candidate_local + timedelta(days=1)
return candidate_local.astimezone(UTC)
weekdays = schedule.weekdays or []
if not weekdays:
raise ValueError("weekly schedule requires weekdays")
normalized_weekdays = sorted(set(weekdays))
for day_offset in range(0, 8):
candidate_day = local_now.date() + timedelta(days=day_offset)
if candidate_day.isoweekday() not in normalized_weekdays:
continue
candidate_local = datetime.combine(candidate_day, run_clock)
if candidate_local > local_now:
return candidate_local.astimezone(UTC)
fallback_day = local_now.date() + timedelta(days=7)
while fallback_day.isoweekday() not in normalized_weekdays:
fallback_day = fallback_day + timedelta(days=1)
fallback_local = datetime.combine(fallback_day, run_clock)
return fallback_local.astimezone(UTC)
class RegistrationAutomationBootstrapService:
def __init__(
self,
*,
repository: RegistrationBootstrapRepositoryLike,
session: SessionLike,
) -> None:
self._repository = repository
self._session = session
async def ensure_user_automation_jobs(self, *, user_id: str | UUID) -> None:
request = RegistrationBootstrapRequest.model_validate({"user_id": user_id})
owner_id = request.user_id
timezone_name = await self._repository.get_profile_timezone(user_id=owner_id)
definitions = [
{
"bootstrap_key": "memory_extraction",
"config_name": "memory_extraction",
"title": "记忆推送",
}
]
try:
inserted_any = False
created_or_updated_memory = False
user_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.USER,
content=UserMemoryContent().model_dump(mode="json"),
)
work_initialized = await self._repository.upsert_initial_memory(
owner_id=owner_id,
memory_type=MemoryType.WORK,
content=WorkProfileContent().model_dump(mode="json"),
)
created_or_updated_memory = user_initialized or work_initialized
for definition in definitions:
bootstrap_key = str(definition["bootstrap_key"])
job_config = load_static_automation_job_config(
config_name=str(definition["config_name"])
)
schedule = job_config.schedule
if schedule is None:
raise ValueError(
f"bootstrap job {bootstrap_key} has no schedule configured"
)
next_run_at = compute_first_run_at_utc(
now_utc=datetime.now(UTC),
timezone_name=timezone_name,
schedule=schedule,
)
inserted = (
await self._repository.insert_bootstrap_automation_job_if_absent(
owner_id=owner_id,
bootstrap_key=bootstrap_key,
title=str(definition["title"]),
config=job_config,
timezone_name=timezone_name,
next_run_at=next_run_at,
)
)
inserted_any = inserted_any or inserted
if inserted_any or created_or_updated_memory:
await self._session.commit()
logger.info(
"user automation jobs bootstrapped",
user_id=user_id,
timezone=timezone_name,
memory_initialized=created_or_updated_memory,
)
except Exception:
await self._session.rollback()
raise
+117
View File
@@ -0,0 +1,117 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request, Response
from core.config.settings import config
from v1.auth.rate_limit import enforce_rate_limit
from v1.auth.dependencies import get_auth_service
from v1.auth.schemas import (
OtpSendRequest,
PhoneSessionCreateRequest,
SessionDeleteRequest,
SessionRefreshRequest,
SessionResponse,
)
from v1.auth.service import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/otp/send", status_code=204)
async def send_otp(
payload: OtpSendRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="otp_send_phone",
identifier=payload.phone,
limit=3,
window_seconds=60,
)
await enforce_rate_limit(
scope="otp_send_ip",
identifier=client_ip,
limit=20,
window_seconds=60,
)
await service.send_otp(payload)
return Response(status_code=204)
@router.post("/phone-session", response_model=SessionResponse)
async def create_phone_session(
payload: PhoneSessionCreateRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
client_ip = _client_ip(request)
await enforce_rate_limit(
scope="phone_session_phone",
identifier=payload.phone,
limit=6,
window_seconds=300,
)
await enforce_rate_limit(
scope="phone_session_ip",
identifier=client_ip,
limit=20,
window_seconds=300,
)
return await service.create_phone_session(payload)
@router.post("/sessions/refresh", response_model=SessionResponse)
async def refresh_session(
payload: SessionRefreshRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> SessionResponse:
await enforce_rate_limit(
scope="refresh",
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
return await service.refresh_session(payload)
@router.delete("/sessions", status_code=204)
async def delete_session(
payload: SessionDeleteRequest,
request: Request,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="logout",
identifier=_client_ip(request),
limit=10,
window_seconds=60,
)
await service.delete_session(payload.refresh_token)
return Response(status_code=204)
def _client_ip(request: Request) -> str:
host = request.client.host if request.client else ""
if not host:
return "unknown"
if _should_trust_proxy_headers(host):
forwarded_for = request.headers.get("x-forwarded-for", "")
if forwarded_for:
first = forwarded_for.split(",")[0].strip()
if first:
return first
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
return host
def _should_trust_proxy_headers(host: str) -> bool:
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
return host in trusted_proxies
+66
View File
@@ -0,0 +1,66 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
SUPABASE_PASSWORD_MIN_LENGTH = 6
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
class OtpSendRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class PhoneSessionCreateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
token: str = Field(pattern=r"^\d{6}$")
class SessionRefreshRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class SessionDeleteRequest(BaseModel):
refresh_token: str = Field(min_length=1)
class AuthUser(BaseModel):
id: str
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class SessionResponse(BaseModel):
access_token: str
refresh_token: str
expires_in: int
token_type: str
user: AuthUser
class UserByPhoneResponse(BaseModel):
id: str
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
created_at: str
phone_confirmed_at: str | None = None
class UserByIdResponse(BaseModel):
id: str
phone: str | None = None
created_at: str
phone_confirmed_at: str | None = None
class OtpSendResponse(BaseModel):
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
class RegistrationBootstrapRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
user_id: UUID
+63
View File
@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Protocol
from v1.auth.schemas import (
OtpSendRequest,
PhoneSessionCreateRequest,
SessionRefreshRequest,
SessionResponse,
)
class AuthServiceGateway(Protocol):
async def send_otp(self, request: OtpSendRequest) -> None:
raise NotImplementedError
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
raise NotImplementedError
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
raise NotImplementedError
async def delete_session(self, refresh_token: str | None) -> None:
raise NotImplementedError
class AuthService:
_gateway: AuthServiceGateway
_registration_bootstrapper: RegistrationBootstrapper | None
def __init__(
self,
gateway: AuthServiceGateway,
registration_bootstrapper: "RegistrationBootstrapper | None" = None,
) -> None:
self._gateway = gateway
self._registration_bootstrapper = registration_bootstrapper
async def send_otp(self, request: OtpSendRequest) -> None:
await self._gateway.send_otp(request)
async def create_phone_session(
self, request: PhoneSessionCreateRequest
) -> SessionResponse:
response = await self._gateway.create_phone_session(request)
if self._registration_bootstrapper is not None:
await self._registration_bootstrapper.ensure_user_automation_jobs(
user_id=response.user.id
)
return response
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
return await self._gateway.refresh_session(request)
async def delete_session(self, refresh_token: str | None) -> None:
await self._gateway.delete_session(refresh_token)
class RegistrationBootstrapper(Protocol):
async def ensure_user_automation_jobs(self, *, user_id: str) -> None:
raise NotImplementedError