chore: checkpoint current backend/runtime changes

This commit is contained in:
qzl
2026-03-06 17:28:17 +08:00
parent 2c59fe5ee2
commit b6087fd195
32 changed files with 1641 additions and 469 deletions
+13 -12
View File
@@ -13,7 +13,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from core.config.settings import config
from core.http.response import build_problem_details
from core.logging import configure_logging, get_logger, log_service_banner
from services.base.redis import redis_service
from services.base import close_registered_services, initialize_registered_services
from v1.router import router as mobile_router
@@ -29,22 +29,23 @@ log_service_banner(
)
logger = get_logger("api.app")
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
initialized = await redis_service.initialize()
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
if not initialized:
logger.error("Redis service failed to initialize, aborting startup")
raise RuntimeError("Redis service initialization failed")
logger.info(
"Redis service initialized",
host=config.redis.host,
db=config.redis.db,
)
yield
await redis_service.close()
logger.info("Redis service closed")
logger.error("Service initialization failed, aborting startup")
raise RuntimeError("Service initialization failed")
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
try:
yield
finally:
closed = await close_registered_services(services)
if not closed:
logger.warning("Failed to close all base services")
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
app = FastAPI(lifespan=lifespan)
@@ -1,15 +1,15 @@
from __future__ import annotations
from typing import Any, Protocol, cast
from typing import Any, Protocol
from uuid import UUID
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.celery.app import celery_app
from core.config.settings import config
from core.logging import get_logger
from services.base.redis import redis_service
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agent.infrastructure.queue.tasks")
@@ -29,13 +29,12 @@ class ResumeServiceLike(Protocol):
async def _build_redis_publisher() -> PublishEvent:
settings = cast(Any, config)
client = redis_service.get_client()
client = await get_or_init_redis_client()
event_store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
async def _publish(event_type: str, payload: dict[str, object]) -> None:
@@ -70,22 +69,27 @@ async def run_agent_task(
raise ValueError("session_id is required")
UUID(session_id)
tool_call_id = ""
user_input = ""
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
await publisher(start_event, {"session_id": session_id})
try:
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
result = await service_resume.resume(
session_id=session_id,
tool_call_id=tool_call_id,
)
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
result = await service_run.run(
session_id=session_id,
user_input=user_input,
@@ -125,6 +129,16 @@ async def run_agent_task(
raise
@celery_app.task(name="tasks.agent.run_command")
@default_broker.task(task_name="tasks.agent.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
@critical_broker.task(task_name="tasks.agent.run_command.critical")
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
@bulk_broker.task(task_name="tasks.agent.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
-73
View File
@@ -1,73 +0,0 @@
from __future__ import annotations
from celery import Celery
from celery import signals as celery_signals
from kombu import Queue
from core.config.settings import config
from core.logging import get_logger
from core.logging.celery import configure_celery_app
from services.base.redis import redis_service
logger = get_logger("core.celery")
def _init_redis_on_worker_startup(**_: object) -> None:
import asyncio
logger.info("Initializing Redis service for Celery worker")
try:
result = asyncio.run(redis_service.initialize())
if result:
logger.info("Redis service initialized for Celery worker")
else:
logger.warning("Redis service initialization returned False")
except Exception as exc: # noqa: BLE001
logger.error("Failed to initialize Redis for Celery worker", error=str(exc))
def create_celery_app() -> Celery:
"""Create and configure the Celery application."""
celery_settings = config.celery
app = Celery(
"social_app",
broker=config.celery_broker_url,
backend=config.celery_result_backend,
include=["core.agent.infrastructure.queue.tasks"],
)
app.conf.update(
task_serializer=celery_settings.task_serializer,
result_serializer=celery_settings.result_serializer,
accept_content=celery_settings.accept_content,
timezone=celery_settings.timezone,
enable_utc=celery_settings.enable_utc,
task_track_started=celery_settings.task_track_started,
task_time_limit=celery_settings.task_time_limit,
task_soft_time_limit=celery_settings.task_soft_time_limit,
task_default_retry_delay=celery_settings.task_default_retry_delay,
task_default_queue="default",
task_create_missing_queues=False,
task_queues=(
Queue("default"),
Queue("critical"),
Queue("bulk"),
),
task_routes={
"tasks.critical.*": {"queue": "critical"},
"tasks.bulk.*": {"queue": "bulk"},
},
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
)
configure_celery_app(app, settings=config)
return app
celery_app = create_celery_app()
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
+7 -17
View File
@@ -63,19 +63,9 @@ class RuntimeSettings(BaseModel):
return self
class CelerySettings(BaseModel):
class TaskiqSettings(BaseModel):
broker_url: str | None = None
result_backend: str | None = None
task_serializer: str = "json"
result_serializer: str = "json"
accept_content: list[str] = Field(default_factory=lambda: ["json"])
timezone: str = "UTC"
enable_utc: bool = True
task_track_started: bool = True
task_time_limit: int = 300
task_soft_time_limit: int = 240
task_default_retry_delay: int = 30
task_max_retries: int = 3
result_backend_url: str | None = None
class CorsSettings(BaseModel):
@@ -189,7 +179,7 @@ class Settings(BaseSettings):
storage: StorageSettings = StorageSettings()
llm: LlmSettings = LlmSettings()
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
celery: CelerySettings = CelerySettings()
taskiq: TaskiqSettings = TaskiqSettings()
database: DatabaseSettings = DatabaseSettings()
@computed_field
@@ -199,13 +189,13 @@ class Settings(BaseSettings):
@computed_field
@property
def celery_broker_url(self) -> str:
return self.celery.broker_url or self.redis.url
def taskiq_broker_url(self) -> str:
return self.taskiq.broker_url or self.redis.url
@computed_field
@property
def celery_result_backend(self) -> str:
return self.celery.result_backend or self.redis.url
def taskiq_result_backend_url(self) -> str:
return self.taskiq.result_backend_url or self.redis.url
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=_resolve_env_file(),
-2
View File
@@ -1,6 +1,5 @@
from __future__ import annotations
from core.logging import celery
from core.logging.banner import log_service_banner
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context, get_context
@@ -8,7 +7,6 @@ from core.logging.logger import get_logger
__all__ = [
"bind_context",
"celery",
"clear_context",
"configure_logging",
"get_context",
-64
View File
@@ -1,64 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import cast
from celery import Celery, signals
from core.config.settings import Settings
from core.logging.banner import log_service_banner
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context
@dataclass(frozen=True)
class CelerySignalHandlers:
on_setup_logging: Callable[..., None]
on_after_setup_task_logger: Callable[..., None]
on_task_prerun: Callable[..., None]
on_task_postrun: Callable[..., None]
def build_celery_signal_handlers(
settings: Settings | None = None,
) -> CelerySignalHandlers:
active_settings = settings or Settings()
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
log_service_banner(
service_name=active_settings.runtime.service_name,
environment=active_settings.runtime.environment,
)
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_task_prerun(*_args: object, **kwargs: object) -> None:
task_id = cast(str | None, kwargs.get("task_id"))
task = kwargs.get("task")
task_name = getattr(task, "name", None)
bind_context(task_id=task_id, task_name=task_name)
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
clear_context()
return CelerySignalHandlers(
on_setup_logging=on_setup_logging,
on_after_setup_task_logger=on_after_setup_task_logger,
on_task_prerun=on_task_prerun,
on_task_postrun=on_task_postrun,
)
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
app.conf.worker_hijack_root_logger = False
handlers = build_celery_signal_handlers(settings)
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
signals.after_setup_task_logger.connect(
handlers.on_after_setup_task_logger, weak=False
)
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
+3
View File
@@ -0,0 +1,3 @@
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
from core.config.settings import config
from core.logging import configure_logging
configure_logging(config)
def _build_broker(queue_name: str) -> ListQueueBroker:
return ListQueueBroker(
url=config.taskiq_broker_url,
queue_name=queue_name,
).with_result_backend(
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
)
default_broker = _build_broker("default")
critical_broker = _build_broker("critical")
bulk_broker = _build_broker("bulk")
# Backward-compatible export name for existing imports/tests.
broker = default_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
+11 -1
View File
@@ -1,18 +1,28 @@
from __future__ import annotations
from services.base.redis import RedisService, redis_service
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
close_registered_services,
initialize_registered_services,
register_service,
register_service_instance,
resolve_registered_services,
)
from services.base.supabase import SupabaseService, supabase_service
__all__ = [
"BaseServiceProvider",
"RedisService",
"ServiceRegistry",
"SupabaseService",
"close_registered_services",
"get_or_init_redis_client",
"initialize_registered_services",
"redis_service",
"register_service",
"register_service_instance",
"resolve_registered_services",
"supabase_service",
]
+9 -1
View File
@@ -92,6 +92,14 @@ class RedisService(BaseServiceProvider):
return self._require_client()
async def get_or_init_redis_client() -> redis.Redis:
if not redis_service.is_initialized:
initialized = await redis_service.initialize()
if not initialized:
raise RuntimeError("Redis service initialization failed")
return redis_service.get_client()
redis_service: RedisService = register_service_instance("redis", RedisService())
__all__ = ["RedisService", "redis_service"]
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
@@ -61,6 +61,12 @@ class ServiceRegistry:
@classmethod
def create_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
return cls.get_service(service_name, **kwargs)
@classmethod
def get_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
factory = cls.get_service_factory(service_name)
if not factory:
@@ -82,3 +88,71 @@ TService = TypeVar("TService", bound=BaseServiceProvider)
def register_service_instance(service_name: str, service: TService) -> TService:
ServiceRegistry.register(service_name, lambda: service)
return service
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
services: list[BaseServiceProvider] = []
for service_name in service_names:
service = ServiceRegistry.get_service(service_name)
if service is None:
raise RuntimeError(f"Service is not registered: {service_name}")
services.append(service)
return services
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
lifecycle_logger = get_logger("services.base.lifecycle")
all_closed = True
for service in reversed(services):
try:
closed = await service.close()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Failed to close service",
service=service.service_name,
error=str(exc),
)
all_closed = False
continue
if not closed:
lifecycle_logger.warning(
"Service close returned false",
service=service.service_name,
)
all_closed = False
return all_closed
async def initialize_registered_services(
service_names: list[str],
) -> tuple[bool, list[BaseServiceProvider]]:
lifecycle_logger = get_logger("services.base.lifecycle")
initialized_services: list[BaseServiceProvider] = []
try:
services = resolve_registered_services(service_names)
except RuntimeError as exc:
lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
return False, []
for service in services:
try:
initialized = await service.initialize()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Service initialization raised exception",
service=service.service_name,
error=str(exc),
)
initialized = False
if not initialized:
lifecycle_logger.error(
"Service initialization failed, rolling back",
service=service.service_name,
)
await close_registered_services(initialized_services)
return False, []
initialized_services.append(service)
return True, initialized_services
+89
View File
@@ -0,0 +1,89 @@
from __future__ import annotations
import asyncio
from typing import Any
from supabase import create_client
from core.config.settings import SupabaseSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class SupabaseService(BaseServiceProvider):
def __init__(self, settings: SupabaseSettings | None = None) -> None:
super().__init__("supabase")
self._settings = settings or config.supabase
self._client: Any = None
self._admin_client: Any = None
async def initialize(self, **_: Any) -> bool:
try:
self._client = create_client(
self._settings.url,
self._settings.anon_key,
)
self._admin_client = create_client(
self._settings.url,
self._settings.service_role_key,
)
self._set_initialized(True)
self.logger.info("Supabase service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase service initialization failed", error=str(exc))
self._client = None
self._admin_client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
self._client = None
self._admin_client = None
self._set_initialized(False)
self.logger.info("Supabase service closed")
return True
async def health_check(self) -> dict[str, Any]:
client = self._client
admin_client = self._admin_client
if client is None or admin_client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
await asyncio.to_thread(client.auth.get_session)
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
return {
"status": "healthy",
"details": {
"anon_client": "ready",
"admin_client": "ready",
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> Any:
return self._require_client()
def get_admin_client(self) -> Any:
return self._require_admin_client()
def _require_client(self) -> Any:
client = self._client
if client is None:
raise RuntimeError("Supabase client is not initialized")
return client
def _require_admin_client(self) -> Any:
admin_client = self._admin_client
if admin_client is None:
raise RuntimeError("Supabase admin client is not initialized")
return admin_client
supabase_service: SupabaseService = register_service_instance(
"supabase", SupabaseService()
)
__all__ = ["SupabaseService", "supabase_service"]
+69 -27
View File
@@ -1,57 +1,98 @@
from __future__ import annotations
from typing import Any, cast
import asyncio
from typing import Any
from uuid import UUID
from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.agent.infrastructure.queue.tasks import (
run_command_task,
run_command_task_bulk,
run_command_task_critical,
)
from core.config.settings import config
from core.db import get_db
from services.base.redis import redis_service
from services.base.redis import get_or_init_redis_client
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
DEDUP_WAIT_RETRIES = 20
DEDUP_WAIT_SECONDS = 0.05
DEDUP_LOCK_SECONDS = 300
DEDUP_INFLIGHT_MARKER = "__inflight__"
class CeleryQueueClient:
class TaskiqQueueClient:
def __init__(self) -> None:
self._redis = redis_service.get_client()
self._redis: Redis | None = None
async def _get_redis(self) -> Redis:
if self._redis is None:
self._redis = await get_or_init_redis_client()
return self._redis
@staticmethod
def _select_queue_task(command: dict[str, object]) -> Any:
queue = str(command.get("queue", "default")).strip().lower()
if queue == "critical":
return run_command_task_critical
if queue == "bulk":
return run_command_task_bulk
return run_command_task
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
redis_client = await self._get_redis()
redis_key = None
if dedup_key:
redis_key = f"agent:dedup:{dedup_key}"
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
locked = await redis_client.set(
redis_key,
DEDUP_INFLIGHT_MARKER,
nx=True,
ex=DEDUP_LOCK_SECONDS,
)
if not locked:
existing = await self._redis.get(redis_key)
if existing and existing != "__inflight__":
return existing
for _ in range(DEDUP_WAIT_RETRIES):
existing = await redis_client.get(redis_key)
if existing and existing != DEDUP_INFLIGHT_MARKER:
return existing
await asyncio.sleep(DEDUP_WAIT_SECONDS)
raise RuntimeError("duplicate request is still in progress")
payload = dict(command)
if dedup_key:
payload["dedup_key"] = dedup_key
delay = getattr(run_command_task, "delay")
result = delay(payload)
task_id = str(result.id)
if redis_key is not None:
await self._redis.set(redis_key, task_id, ex=300)
return task_id
queue_task = self._select_queue_task(payload)
try:
result = await queue_task.kiq(payload)
task_id = str(result.task_id)
if redis_key is not None:
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
return task_id
except Exception:
if redis_key is not None:
await redis_client.delete(redis_key)
raise
class RedisEventStream:
def __init__(self) -> None:
settings = cast(Any, config)
client = redis_service.get_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
)
self._store: RedisStreamEventStore | None = None
async def _get_store(self) -> RedisStreamEventStore:
if self._store is None:
client = await get_or_init_redis_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
return self._store
async def read(
self,
@@ -59,7 +100,8 @@ class RedisEventStream:
session_id: str,
last_event_id: str | None,
) -> list[dict[str, Any]]:
rows = await self._store.read_events(
store = await self._get_store()
rows = await store.read_events(
session_id=UUID(session_id),
last_event_id=last_event_id,
)
@@ -69,6 +111,6 @@ class RedisEventStream:
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
return AgentService(
repository=AgentRepository(session),
queue=CeleryQueueClient(),
queue=TaskiqQueueClient(),
stream=RedisEventStream(),
)
+28 -20
View File
@@ -5,10 +5,10 @@ from collections.abc import Mapping
from typing import Any, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from supabase import AuthError
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
@@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway")
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
_admin_client: Any
def _get_client(self) -> Any:
return supabase_service.get_client()
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
self._admin_client = create_client(settings.url, settings.service_role_key)
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
client = self._get_client()
metadata: dict[str, Any] = {"username": request.username}
if request.invite_code:
metadata["invite_code"] = request.invite_code
@@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
if request.redirect_to:
payload["options"] = {"email_redirect_to": request.redirect_to}
try:
sign_up = cast(Any, self._client.auth.sign_up)
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
return VerificationCreateResponse(email=request.email)
except AuthError as exc:
@@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {
"type": "signup",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(response, "Invalid verification code")
except AuthError as exc:
@@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {"type": "signup", "email": request.email}
try:
resend = cast(Any, self._client.auth.resend)
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
except AuthError as exc:
logger.warning("Signup resend failed", error_type=type(exc).__name__)
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
sign_in = cast(Any, client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
@@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway):
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
@@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def delete_session(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
await asyncio.to_thread(client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
raise HTTPException(
@@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
normalized_email = email.lower()
user = next(
(
@@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
client = self._get_client()
try:
reset_email = cast(Any, self._client.auth.reset_password_email)
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
@@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
client = self._get_client()
admin_client = self._get_admin_client()
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
@@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
self._admin_client.auth.admin.update_user_by_id,
admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)