refactor(backend): update API routes and service layer
- Update agent router/service/repository with new endpoints - Update auth routes with phone-based authentication - Update users service with new phone lookup - Update schedule_items with new schemas - Update message schemas with visibility support - Update settings with new automation scheduler config - Update CLI with new commands - Update tests to match new API contracts
This commit is contained in:
+127
-199
@@ -2,28 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
UserByEmailResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
UserByPhoneResponse,
|
||||
)
|
||||
from v1.auth.service import AuthServiceGateway
|
||||
|
||||
@@ -36,7 +30,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def __init__(self) -> None:
|
||||
self._user_lookup_cache_ttl_seconds: int = 60
|
||||
self._user_lookup_cache_expires_at: float = 0.0
|
||||
self._users_by_email: dict[str, Any] = {}
|
||||
self._users_by_phone: dict[str, Any] = {}
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
return supabase_service.get_client()
|
||||
@@ -44,47 +38,31 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
def _get_admin_client(self) -> Any:
|
||||
return supabase_service.get_admin_client()
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
client = self._get_client()
|
||||
metadata: dict[str, Any] = {"username": request.username}
|
||||
if request.invite_code:
|
||||
metadata["invite_code"] = request.invite_code
|
||||
payload: dict[str, Any] = {
|
||||
"email": request.email,
|
||||
"password": request.password,
|
||||
"data": metadata,
|
||||
"phone": request.phone,
|
||||
"options": {"should_create_user": True},
|
||||
}
|
||||
if request.redirect_to:
|
||||
payload["options"] = {
|
||||
"email_redirect_to": _validate_redirect_url(request.redirect_to)
|
||||
}
|
||||
try:
|
||||
sign_up = cast(Any, client.auth.sign_up)
|
||||
await asyncio.to_thread(sign_up, payload)
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
sign_in_with_otp = cast(Any, client.auth.sign_in_with_otp)
|
||||
await asyncio.to_thread(sign_in_with_otp, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup failed", error_type=type(exc).__name__)
|
||||
logger.warning("Send otp failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid signup request"
|
||||
) from exc
|
||||
raise HTTPException(status_code=429, detail="Too many requests") from exc
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
if request.type != "signup":
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": request.type,
|
||||
"email": request.email,
|
||||
"type": "sms",
|
||||
"phone": request.phone,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
@@ -92,7 +70,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
response = await asyncio.to_thread(verify_otp, payload)
|
||||
return _map_auth_response(response, "Invalid verification code")
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup verify failed", error_type=type(exc).__name__)
|
||||
logger.warning("Create phone session failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
@@ -102,45 +80,6 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid verification code"
|
||||
) from exc
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
client = self._get_client()
|
||||
if request.type == "recovery":
|
||||
await self.request_password_reset(
|
||||
PasswordResetRequest(
|
||||
email=request.email,
|
||||
redirect_to=request.redirect_to,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
payload: dict[str, Any] = {"type": request.type, "email": request.email}
|
||||
try:
|
||||
resend = cast(Any, client.auth.resend)
|
||||
await asyncio.to_thread(resend, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup resend failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
logger.warning("Login failed", error_type=type(exc).__name__)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
@@ -189,98 +128,84 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid refresh token"
|
||||
) from exc
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
admin_client = self._get_admin_client()
|
||||
normalized_email = email.lower()
|
||||
async def get_user_by_phone(self, phone: str) -> UserByPhoneResponse:
|
||||
normalized_phone = _normalize_phone(phone)
|
||||
if not normalized_phone:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
now = time.monotonic()
|
||||
if now >= self._user_lookup_cache_expires_at:
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_email: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_email = str(getattr(candidate, "email", "")).lower()
|
||||
if candidate_email:
|
||||
users_by_email[candidate_email] = candidate
|
||||
self._users_by_email = users_by_email
|
||||
self._user_lookup_cache_expires_at = (
|
||||
now + self._user_lookup_cache_ttl_seconds
|
||||
)
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
|
||||
user = self._users_by_email.get(normalized_email)
|
||||
user = self._users_by_phone.get(normalized_phone)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserByEmailResponse(
|
||||
user_phone = _normalize_phone(getattr(user, "phone", ""))
|
||||
if not user_phone:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return UserByPhoneResponse(
|
||||
id=str(getattr(user, "id", "")),
|
||||
email=str(getattr(user, "email", "")),
|
||||
phone=user_phone,
|
||||
created_at=str(getattr(user, "created_at", "")),
|
||||
email_confirmed_at=(
|
||||
str(getattr(user, "email_confirmed_at", ""))
|
||||
if getattr(user, "email_confirmed_at", None)
|
||||
phone_confirmed_at=(
|
||||
str(getattr(user, "phone_confirmed_at", ""))
|
||||
if getattr(user, "phone_confirmed_at", None)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
reset_email = cast(Any, client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {
|
||||
"redirect_to": _validate_redirect_url(request.redirect_to)
|
||||
}
|
||||
await asyncio.to_thread(reset_email, email, options=options)
|
||||
else:
|
||||
await asyncio.to_thread(reset_email, email)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset request failed",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
normalized_query = _normalize_phone_search_query(query)
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
await self._refresh_user_lookup_cache_if_needed()
|
||||
if normalized_query.startswith("+"):
|
||||
matched_user = self._users_by_phone.get(normalized_query)
|
||||
if matched_user is None:
|
||||
return []
|
||||
user_id = str(getattr(matched_user, "id", ""))
|
||||
return [user_id] if user_id else []
|
||||
|
||||
digits = _digits_only(normalized_query)
|
||||
if not digits:
|
||||
return []
|
||||
|
||||
matched_records: list[tuple[str, str]] = []
|
||||
for cached_phone, candidate in self._users_by_phone.items():
|
||||
candidate_digits = _digits_only(cached_phone)
|
||||
if not candidate_digits.endswith(digits):
|
||||
continue
|
||||
user_id = str(getattr(candidate, "id", ""))
|
||||
if user_id:
|
||||
matched_records.append((cached_phone, user_id))
|
||||
|
||||
if not matched_records:
|
||||
return []
|
||||
|
||||
unique_ids: list[str] = []
|
||||
for _, user_id in sorted(matched_records, key=lambda item: item[0]):
|
||||
if user_id in unique_ids:
|
||||
continue
|
||||
unique_ids.append(user_id)
|
||||
if len(unique_ids) >= max(1, limit):
|
||||
break
|
||||
return unique_ids
|
||||
|
||||
async def _refresh_user_lookup_cache_if_needed(self) -> None:
|
||||
now = time.monotonic()
|
||||
if now < self._user_lookup_cache_expires_at:
|
||||
return
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
client = self._get_client()
|
||||
admin_client = self._get_admin_client()
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
user_id = str(getattr(user, "id", "")) if user is not None else ""
|
||||
if session is None or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
except AuthError as exc:
|
||||
logger.warning(
|
||||
"Password reset confirm failed", error_type=type(exc).__name__
|
||||
)
|
||||
if _is_auth_upstream_unavailable(exc):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=AUTH_UNAVAILABLE_DETAIL,
|
||||
) from exc
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
) from exc
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
users_by_phone: dict[str, Any] = {}
|
||||
for candidate in users:
|
||||
candidate_phone = _normalize_phone(getattr(candidate, "phone", ""))
|
||||
if candidate_phone:
|
||||
users_by_phone[candidate_phone] = candidate
|
||||
self._users_by_phone = users_by_phone
|
||||
self._user_lookup_cache_expires_at = now + self._user_lookup_cache_ttl_seconds
|
||||
|
||||
|
||||
def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
||||
@@ -312,55 +237,24 @@ def _is_auth_upstream_unavailable(exc: AuthError) -> bool:
|
||||
return any(token in code or token in message for token in indicators)
|
||||
|
||||
|
||||
def _validate_redirect_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
|
||||
origin = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
|
||||
allowed_origins = {
|
||||
_normalize_origin(candidate)
|
||||
for candidate in config.cors.allow_origins
|
||||
if _is_http_origin(candidate)
|
||||
}
|
||||
if origin not in allowed_origins:
|
||||
raise HTTPException(status_code=422, detail="Invalid redirect URL")
|
||||
return url
|
||||
|
||||
|
||||
def _normalize_origin(value: str) -> str:
|
||||
parsed = urlparse(value)
|
||||
return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
|
||||
|
||||
|
||||
def _is_http_origin(value: str) -> bool:
|
||||
parsed = urlparse(value)
|
||||
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
|
||||
|
||||
|
||||
def _coerce_reset_email(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
nested = value.get("email") or value.get("value")
|
||||
if isinstance(nested, str):
|
||||
return nested
|
||||
|
||||
raise HTTPException(status_code=422, detail="Invalid email")
|
||||
|
||||
|
||||
def _map_auth_response(response: object, failure_message: str) -> SessionResponse:
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
if session is None or user is None:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
email = getattr(user, "email", None)
|
||||
if not email:
|
||||
phone = _normalize_phone(getattr(user, "phone", None))
|
||||
if not phone:
|
||||
raise HTTPException(status_code=401, detail=failure_message)
|
||||
|
||||
auth_user = AuthUser(id=str(user.id), email=str(email))
|
||||
try:
|
||||
auth_user = AuthUser(id=str(user.id), phone=str(phone))
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Auth response returned invalid phone format",
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=failure_message) from exc
|
||||
return SessionResponse(
|
||||
access_token=str(session.access_token),
|
||||
refresh_token=str(session.refresh_token),
|
||||
@@ -389,3 +283,37 @@ def _list_auth_users(client: Any) -> list[Any]:
|
||||
page += 1
|
||||
|
||||
return users
|
||||
|
||||
|
||||
def _normalize_phone(raw_phone: object) -> str | None:
|
||||
phone = str(raw_phone).strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
phone = phone.replace(separator, "")
|
||||
if not phone:
|
||||
return None
|
||||
if phone.startswith("00") and len(phone) > 2:
|
||||
return f"+{phone[2:]}"
|
||||
if phone.startswith("+"):
|
||||
return phone
|
||||
if phone.isdigit():
|
||||
return f"+{phone}"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_phone_search_query(raw_query: str) -> str | None:
|
||||
query = raw_query.strip()
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
query = query.replace(separator, "")
|
||||
if not query:
|
||||
return None
|
||||
if query.startswith("00") and len(query) > 2:
|
||||
return f"+{query[2:]}"
|
||||
if query.startswith("+"):
|
||||
return query
|
||||
if query.isdigit():
|
||||
return query
|
||||
return None
|
||||
|
||||
|
||||
def _digits_only(value: str) -> str:
|
||||
return "".join(ch for ch in value if ch.isdigit())
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.config.settings import config
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.auth.dependencies import get_auth_service
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionDeleteRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
from v1.auth.service import AuthService
|
||||
|
||||
@@ -22,80 +18,49 @@ from v1.auth.service import AuthService
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/verifications", response_model=VerificationCreateResponse, status_code=202
|
||||
)
|
||||
async def create_verification(
|
||||
payload: VerificationCreateRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> VerificationCreateResponse:
|
||||
await enforce_rate_limit(
|
||||
scope="signup_start",
|
||||
identifier=payload.email,
|
||||
limit=5,
|
||||
window_seconds=60,
|
||||
)
|
||||
return await service.create_verification(payload)
|
||||
|
||||
|
||||
@router.post("/verify", response_model=SessionResponse)
|
||||
async def verify(
|
||||
payload: VerificationVerifyRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse | Response:
|
||||
scope = "signup_verify" if payload.type == "signup" else "password_reset_confirm"
|
||||
limit = 10
|
||||
window_seconds = 600
|
||||
await enforce_rate_limit(
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=limit,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
if payload.type == "signup":
|
||||
return await service.verify_verification(payload)
|
||||
if payload.new_password is None:
|
||||
raise HTTPException(status_code=422, detail="Invalid request")
|
||||
await service.confirm_password_reset(
|
||||
PasswordResetConfirmRequest(
|
||||
email=payload.email,
|
||||
token=payload.token,
|
||||
new_password=payload.new_password,
|
||||
)
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/resend", status_code=204)
|
||||
async def resend(
|
||||
payload: VerificationResendRequest,
|
||||
@router.post("/otp/send", status_code=204)
|
||||
async def send_otp(
|
||||
payload: OtpSendRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
scope = "signup_resend" if payload.type == "signup" else "password_reset_request"
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope=scope,
|
||||
identifier=f"{payload.email.lower()}:{_client_ip(request)}",
|
||||
limit=5,
|
||||
scope="otp_send_phone",
|
||||
identifier=payload.phone,
|
||||
limit=3,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.resend_verification(payload)
|
||||
await enforce_rate_limit(
|
||||
scope="otp_send_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.send_otp(payload)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(
|
||||
payload: SessionCreateRequest,
|
||||
@router.post("/phone-session", response_model=SessionResponse)
|
||||
async def create_phone_session(
|
||||
payload: PhoneSessionCreateRequest,
|
||||
request: Request,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> SessionResponse:
|
||||
client_ip = _client_ip(request)
|
||||
await enforce_rate_limit(
|
||||
scope="login",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
scope="phone_session_phone",
|
||||
identifier=payload.phone,
|
||||
limit=6,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_session(payload)
|
||||
await enforce_rate_limit(
|
||||
scope="phone_session_ip",
|
||||
identifier=client_ip,
|
||||
limit=20,
|
||||
window_seconds=300,
|
||||
)
|
||||
return await service.create_phone_session(payload)
|
||||
|
||||
|
||||
@router.post("/sessions/refresh", response_model=SessionResponse)
|
||||
@@ -130,13 +95,23 @@ async def delete_session(
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
first = forwarded_for.split(",")[0].strip()
|
||||
if first:
|
||||
return first
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
host = request.client.host if request.client else ""
|
||||
return host or "unknown"
|
||||
if not host:
|
||||
return "unknown"
|
||||
|
||||
if _should_trust_proxy_headers(host):
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
first = forwarded_for.split(",")[0].strip()
|
||||
if first:
|
||||
return first
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
return host
|
||||
|
||||
|
||||
def _should_trust_proxy_headers(host: str) -> bool:
|
||||
trusted_proxies = {entry.strip() for entry in config.runtime.trusted_proxy_ips}
|
||||
return host in trusted_proxies
|
||||
|
||||
@@ -1,49 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
SUPABASE_PASSWORD_MIN_LENGTH = 6
|
||||
OtpType = Literal["signup", "recovery"]
|
||||
SUPABASE_PHONE_PATTERN = r"^\+[1-9]\d{7,14}$"
|
||||
|
||||
|
||||
class VerificationCreateRequest(BaseModel):
|
||||
class OtpSendRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
username: str = Field(min_length=3, max_length=30)
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
redirect_to: str | None = None
|
||||
invite_code: str | None = None
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class VerificationResendRequest(BaseModel):
|
||||
email: EmailStr
|
||||
type: OtpType = "signup"
|
||||
redirect_to: str | None = None
|
||||
class PhoneSessionCreateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class VerificationVerifyRequest(BaseModel):
|
||||
type: OtpType = "signup"
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str | None = Field(
|
||||
default=None, min_length=SUPABASE_PASSWORD_MIN_LENGTH
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_type_payload(self) -> "VerificationVerifyRequest":
|
||||
if self.type == "recovery" and self.new_password is None:
|
||||
raise ValueError("new_password is required when type is recovery")
|
||||
if self.type == "signup" and self.new_password is not None:
|
||||
raise ValueError("new_password is only allowed when type is recovery")
|
||||
return self
|
||||
|
||||
|
||||
class SessionCreateRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
|
||||
|
||||
class SessionRefreshRequest(BaseModel):
|
||||
@@ -56,7 +29,7 @@ class SessionDeleteRequest(BaseModel):
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
@@ -67,23 +40,12 @@ class SessionResponse(BaseModel):
|
||||
user: AuthUser
|
||||
|
||||
|
||||
class UserByEmailResponse(BaseModel):
|
||||
class UserByPhoneResponse(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
created_at: str
|
||||
email_confirmed_at: str | None = None
|
||||
phone_confirmed_at: str | None = None
|
||||
|
||||
|
||||
class VerificationCreateResponse(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: EmailStr
|
||||
redirect_to: str | None = None
|
||||
|
||||
|
||||
class PasswordResetConfirmRequest(BaseModel):
|
||||
email: EmailStr
|
||||
token: str = Field(pattern=r"^\d{6}$")
|
||||
new_password: str = Field(min_length=SUPABASE_PASSWORD_MIN_LENGTH)
|
||||
class OtpSendResponse(BaseModel):
|
||||
phone: str = Field(pattern=SUPABASE_PHONE_PATTERN)
|
||||
|
||||
@@ -1,52 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Protocol
|
||||
|
||||
from v1.auth.schemas import (
|
||||
PasswordResetConfirmRequest,
|
||||
PasswordResetRequest,
|
||||
SessionCreateRequest,
|
||||
OtpSendRequest,
|
||||
PhoneSessionCreateRequest,
|
||||
SessionRefreshRequest,
|
||||
SessionResponse,
|
||||
VerificationCreateRequest,
|
||||
VerificationCreateResponse,
|
||||
VerificationResendRequest,
|
||||
VerificationVerifyRequest,
|
||||
)
|
||||
|
||||
|
||||
class AuthServiceGateway(Protocol):
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AuthService:
|
||||
_gateway: AuthServiceGateway
|
||||
@@ -54,50 +32,16 @@ class AuthService:
|
||||
def __init__(self, gateway: AuthServiceGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
normalized_invite_code = _normalize_invite_code(request.invite_code)
|
||||
normalized_request = request.model_copy(
|
||||
update={"invite_code": normalized_invite_code}
|
||||
)
|
||||
return await self._gateway.create_verification(normalized_request)
|
||||
async def send_otp(self, request: OtpSendRequest) -> None:
|
||||
await self._gateway.send_otp(request)
|
||||
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
async def create_phone_session(
|
||||
self, request: PhoneSessionCreateRequest
|
||||
) -> SessionResponse:
|
||||
return await self._gateway.verify_verification(request)
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
await self._gateway.resend_verification(request)
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
return await self._gateway.create_session(request)
|
||||
return await self._gateway.create_phone_session(request)
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
return await self._gateway.refresh_session(request)
|
||||
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
await self._gateway.delete_session(refresh_token)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
await self._gateway.request_password_reset(request)
|
||||
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
await self._gateway.confirm_password_reset(request)
|
||||
|
||||
|
||||
_INVITE_CODE_PATTERN = re.compile(r"^[ABCDEFGHJKMNPQRSTUVWXYZ23456789]{4}$")
|
||||
|
||||
|
||||
def _normalize_invite_code(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
normalized = value.strip().upper()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
return normalized if _INVITE_CODE_PATTERN.fullmatch(normalized) else None
|
||||
|
||||
Reference in New Issue
Block a user