chore: checkpoint current backend/runtime changes
This commit is contained in:
+13
-12
@@ -13,7 +13,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from core.config.settings import config
|
||||
from core.http.response import build_problem_details
|
||||
from core.logging import configure_logging, get_logger, log_service_banner
|
||||
from services.base.redis import redis_service
|
||||
from services.base import close_registered_services, initialize_registered_services
|
||||
from v1.router import router as mobile_router
|
||||
|
||||
|
||||
@@ -29,22 +29,23 @@ log_service_banner(
|
||||
)
|
||||
|
||||
logger = get_logger("api.app")
|
||||
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
initialized = await redis_service.initialize()
|
||||
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
|
||||
if not initialized:
|
||||
logger.error("Redis service failed to initialize, aborting startup")
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
logger.info(
|
||||
"Redis service initialized",
|
||||
host=config.redis.host,
|
||||
db=config.redis.db,
|
||||
)
|
||||
yield
|
||||
await redis_service.close()
|
||||
logger.info("Redis service closed")
|
||||
logger.error("Service initialization failed, aborting startup")
|
||||
raise RuntimeError("Service initialization failed")
|
||||
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
closed = await close_registered_services(services)
|
||||
if not closed:
|
||||
logger.warning("Failed to close all base services")
|
||||
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, cast
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.celery.app import celery_app
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import redis_service
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("core.agent.infrastructure.queue.tasks")
|
||||
|
||||
@@ -29,13 +29,12 @@ class ResumeServiceLike(Protocol):
|
||||
|
||||
|
||||
async def _build_redis_publisher() -> PublishEvent:
|
||||
settings = cast(Any, config)
|
||||
client = redis_service.get_client()
|
||||
client = await get_or_init_redis_client()
|
||||
event_store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
read_count=settings.agent_runtime.redis_stream_read_count,
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
@@ -70,22 +69,27 @@ async def run_agent_task(
|
||||
raise ValueError("session_id is required")
|
||||
UUID(session_id)
|
||||
|
||||
tool_call_id = ""
|
||||
user_input = ""
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
|
||||
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
|
||||
await publisher(start_event, {"session_id": session_id})
|
||||
|
||||
try:
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
result = await service_resume.resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
result = await service_run.run(
|
||||
session_id=session_id,
|
||||
user_input=user_input,
|
||||
@@ -125,6 +129,16 @@ async def run_agent_task(
|
||||
raise
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.agent.run_command")
|
||||
@default_broker.task(task_name="tasks.agent.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
|
||||
@critical_broker.task(task_name="tasks.agent.run_command.critical")
|
||||
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
|
||||
@bulk_broker.task(task_name="tasks.agent.run_command.bulk")
|
||||
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals as celery_signals
|
||||
from kombu import Queue
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from core.logging.celery import configure_celery_app
|
||||
from services.base.redis import redis_service
|
||||
|
||||
logger = get_logger("core.celery")
|
||||
|
||||
|
||||
def _init_redis_on_worker_startup(**_: object) -> None:
|
||||
import asyncio
|
||||
|
||||
logger.info("Initializing Redis service for Celery worker")
|
||||
try:
|
||||
result = asyncio.run(redis_service.initialize())
|
||||
if result:
|
||||
logger.info("Redis service initialized for Celery worker")
|
||||
else:
|
||||
logger.warning("Redis service initialization returned False")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to initialize Redis for Celery worker", error=str(exc))
|
||||
|
||||
|
||||
def create_celery_app() -> Celery:
|
||||
"""Create and configure the Celery application."""
|
||||
celery_settings = config.celery
|
||||
|
||||
app = Celery(
|
||||
"social_app",
|
||||
broker=config.celery_broker_url,
|
||||
backend=config.celery_result_backend,
|
||||
include=["core.agent.infrastructure.queue.tasks"],
|
||||
)
|
||||
|
||||
app.conf.update(
|
||||
task_serializer=celery_settings.task_serializer,
|
||||
result_serializer=celery_settings.result_serializer,
|
||||
accept_content=celery_settings.accept_content,
|
||||
timezone=celery_settings.timezone,
|
||||
enable_utc=celery_settings.enable_utc,
|
||||
task_track_started=celery_settings.task_track_started,
|
||||
task_time_limit=celery_settings.task_time_limit,
|
||||
task_soft_time_limit=celery_settings.task_soft_time_limit,
|
||||
task_default_retry_delay=celery_settings.task_default_retry_delay,
|
||||
task_default_queue="default",
|
||||
task_create_missing_queues=False,
|
||||
task_queues=(
|
||||
Queue("default"),
|
||||
Queue("critical"),
|
||||
Queue("bulk"),
|
||||
),
|
||||
task_routes={
|
||||
"tasks.critical.*": {"queue": "critical"},
|
||||
"tasks.bulk.*": {"queue": "bulk"},
|
||||
},
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_prefetch_multiplier=1,
|
||||
)
|
||||
|
||||
configure_celery_app(app, settings=config)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
celery_app = create_celery_app()
|
||||
|
||||
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
|
||||
@@ -63,19 +63,9 @@ class RuntimeSettings(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class CelerySettings(BaseModel):
|
||||
class TaskiqSettings(BaseModel):
|
||||
broker_url: str | None = None
|
||||
result_backend: str | None = None
|
||||
task_serializer: str = "json"
|
||||
result_serializer: str = "json"
|
||||
accept_content: list[str] = Field(default_factory=lambda: ["json"])
|
||||
timezone: str = "UTC"
|
||||
enable_utc: bool = True
|
||||
task_track_started: bool = True
|
||||
task_time_limit: int = 300
|
||||
task_soft_time_limit: int = 240
|
||||
task_default_retry_delay: int = 30
|
||||
task_max_retries: int = 3
|
||||
result_backend_url: str | None = None
|
||||
|
||||
|
||||
class CorsSettings(BaseModel):
|
||||
@@ -189,7 +179,7 @@ class Settings(BaseSettings):
|
||||
storage: StorageSettings = StorageSettings()
|
||||
llm: LlmSettings = LlmSettings()
|
||||
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
|
||||
celery: CelerySettings = CelerySettings()
|
||||
taskiq: TaskiqSettings = TaskiqSettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
@computed_field
|
||||
@@ -199,13 +189,13 @@ class Settings(BaseSettings):
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def celery_broker_url(self) -> str:
|
||||
return self.celery.broker_url or self.redis.url
|
||||
def taskiq_broker_url(self) -> str:
|
||||
return self.taskiq.broker_url or self.redis.url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def celery_result_backend(self) -> str:
|
||||
return self.celery.result_backend or self.redis.url
|
||||
def taskiq_result_backend_url(self) -> str:
|
||||
return self.taskiq.result_backend_url or self.redis.url
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_file=_resolve_env_file(),
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging import celery
|
||||
from core.logging.banner import log_service_banner
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context, get_context
|
||||
@@ -8,7 +7,6 @@ from core.logging.logger import get_logger
|
||||
|
||||
__all__ = [
|
||||
"bind_context",
|
||||
"celery",
|
||||
"clear_context",
|
||||
"configure_logging",
|
||||
"get_context",
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery, signals
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.banner import log_service_banner
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CelerySignalHandlers:
|
||||
on_setup_logging: Callable[..., None]
|
||||
on_after_setup_task_logger: Callable[..., None]
|
||||
on_task_prerun: Callable[..., None]
|
||||
on_task_postrun: Callable[..., None]
|
||||
|
||||
|
||||
def build_celery_signal_handlers(
|
||||
settings: Settings | None = None,
|
||||
) -> CelerySignalHandlers:
|
||||
active_settings = settings or Settings()
|
||||
|
||||
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
log_service_banner(
|
||||
service_name=active_settings.runtime.service_name,
|
||||
environment=active_settings.runtime.environment,
|
||||
)
|
||||
|
||||
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
|
||||
configure_logging(settings)
|
||||
|
||||
def on_task_prerun(*_args: object, **kwargs: object) -> None:
|
||||
task_id = cast(str | None, kwargs.get("task_id"))
|
||||
task = kwargs.get("task")
|
||||
task_name = getattr(task, "name", None)
|
||||
bind_context(task_id=task_id, task_name=task_name)
|
||||
|
||||
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
|
||||
clear_context()
|
||||
|
||||
return CelerySignalHandlers(
|
||||
on_setup_logging=on_setup_logging,
|
||||
on_after_setup_task_logger=on_after_setup_task_logger,
|
||||
on_task_prerun=on_task_prerun,
|
||||
on_task_postrun=on_task_postrun,
|
||||
)
|
||||
|
||||
|
||||
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
|
||||
app.conf.worker_hijack_root_logger = False
|
||||
|
||||
handlers = build_celery_signal_handlers(settings)
|
||||
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
|
||||
signals.after_setup_task_logger.connect(
|
||||
handlers.on_after_setup_task_logger, weak=False
|
||||
)
|
||||
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
|
||||
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
|
||||
@@ -0,0 +1,3 @@
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import configure_logging
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
|
||||
def _build_broker(queue_name: str) -> ListQueueBroker:
|
||||
return ListQueueBroker(
|
||||
url=config.taskiq_broker_url,
|
||||
queue_name=queue_name,
|
||||
).with_result_backend(
|
||||
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
|
||||
)
|
||||
|
||||
|
||||
default_broker = _build_broker("default")
|
||||
critical_broker = _build_broker("critical")
|
||||
bulk_broker = _build_broker("bulk")
|
||||
|
||||
# Backward-compatible export name for existing imports/tests.
|
||||
broker = default_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
@@ -1,18 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, redis_service
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
resolve_registered_services,
|
||||
)
|
||||
from services.base.supabase import SupabaseService, supabase_service
|
||||
|
||||
__all__ = [
|
||||
"BaseServiceProvider",
|
||||
"RedisService",
|
||||
"ServiceRegistry",
|
||||
"SupabaseService",
|
||||
"close_registered_services",
|
||||
"get_or_init_redis_client",
|
||||
"initialize_registered_services",
|
||||
"redis_service",
|
||||
"register_service",
|
||||
"register_service_instance",
|
||||
"resolve_registered_services",
|
||||
"supabase_service",
|
||||
]
|
||||
|
||||
@@ -92,6 +92,14 @@ class RedisService(BaseServiceProvider):
|
||||
return self._require_client()
|
||||
|
||||
|
||||
async def get_or_init_redis_client() -> redis.Redis:
|
||||
if not redis_service.is_initialized:
|
||||
initialized = await redis_service.initialize()
|
||||
if not initialized:
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
return redis_service.get_client()
|
||||
|
||||
|
||||
redis_service: RedisService = register_service_instance("redis", RedisService())
|
||||
|
||||
__all__ = ["RedisService", "redis_service"]
|
||||
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
|
||||
|
||||
@@ -61,6 +61,12 @@ class ServiceRegistry:
|
||||
@classmethod
|
||||
def create_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
return cls.get_service(service_name, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
factory = cls.get_service_factory(service_name)
|
||||
if not factory:
|
||||
@@ -82,3 +88,71 @@ TService = TypeVar("TService", bound=BaseServiceProvider)
|
||||
def register_service_instance(service_name: str, service: TService) -> TService:
|
||||
ServiceRegistry.register(service_name, lambda: service)
|
||||
return service
|
||||
|
||||
|
||||
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
|
||||
services: list[BaseServiceProvider] = []
|
||||
for service_name in service_names:
|
||||
service = ServiceRegistry.get_service(service_name)
|
||||
if service is None:
|
||||
raise RuntimeError(f"Service is not registered: {service_name}")
|
||||
services.append(service)
|
||||
return services
|
||||
|
||||
|
||||
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
|
||||
lifecycle_logger = get_logger("services.base.lifecycle")
|
||||
all_closed = True
|
||||
for service in reversed(services):
|
||||
try:
|
||||
closed = await service.close()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
lifecycle_logger.warning(
|
||||
"Failed to close service",
|
||||
service=service.service_name,
|
||||
error=str(exc),
|
||||
)
|
||||
all_closed = False
|
||||
continue
|
||||
if not closed:
|
||||
lifecycle_logger.warning(
|
||||
"Service close returned false",
|
||||
service=service.service_name,
|
||||
)
|
||||
all_closed = False
|
||||
return all_closed
|
||||
|
||||
|
||||
async def initialize_registered_services(
|
||||
service_names: list[str],
|
||||
) -> tuple[bool, list[BaseServiceProvider]]:
|
||||
lifecycle_logger = get_logger("services.base.lifecycle")
|
||||
initialized_services: list[BaseServiceProvider] = []
|
||||
try:
|
||||
services = resolve_registered_services(service_names)
|
||||
except RuntimeError as exc:
|
||||
lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
|
||||
return False, []
|
||||
|
||||
for service in services:
|
||||
try:
|
||||
initialized = await service.initialize()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
lifecycle_logger.warning(
|
||||
"Service initialization raised exception",
|
||||
service=service.service_name,
|
||||
error=str(exc),
|
||||
)
|
||||
initialized = False
|
||||
|
||||
if not initialized:
|
||||
lifecycle_logger.error(
|
||||
"Service initialization failed, rolling back",
|
||||
service=service.service_name,
|
||||
)
|
||||
await close_registered_services(initialized_services)
|
||||
return False, []
|
||||
|
||||
initialized_services.append(service)
|
||||
|
||||
return True, initialized_services
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from supabase import create_client
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class SupabaseService(BaseServiceProvider):
|
||||
def __init__(self, settings: SupabaseSettings | None = None) -> None:
|
||||
super().__init__("supabase")
|
||||
self._settings = settings or config.supabase
|
||||
self._client: Any = None
|
||||
self._admin_client: Any = None
|
||||
|
||||
async def initialize(self, **_: Any) -> bool:
|
||||
try:
|
||||
self._client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.anon_key,
|
||||
)
|
||||
self._admin_client = create_client(
|
||||
self._settings.url,
|
||||
self._settings.service_role_key,
|
||||
)
|
||||
self._set_initialized(True)
|
||||
self.logger.info("Supabase service initialized")
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Supabase service initialization failed", error=str(exc))
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
return False
|
||||
|
||||
async def close(self) -> bool:
|
||||
self._client = None
|
||||
self._admin_client = None
|
||||
self._set_initialized(False)
|
||||
self.logger.info("Supabase service closed")
|
||||
return True
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
client = self._client
|
||||
admin_client = self._admin_client
|
||||
if client is None or admin_client is None:
|
||||
return {"status": "unhealthy", "details": {"error": "not initialized"}}
|
||||
try:
|
||||
await asyncio.to_thread(client.auth.get_session)
|
||||
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
"anon_client": "ready",
|
||||
"admin_client": "ready",
|
||||
},
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.warning("Supabase health check failed", error=str(exc))
|
||||
return {"status": "unhealthy", "details": {"error": str(exc)}}
|
||||
|
||||
def get_client(self) -> Any:
|
||||
return self._require_client()
|
||||
|
||||
def get_admin_client(self) -> Any:
|
||||
return self._require_admin_client()
|
||||
|
||||
def _require_client(self) -> Any:
|
||||
client = self._client
|
||||
if client is None:
|
||||
raise RuntimeError("Supabase client is not initialized")
|
||||
return client
|
||||
|
||||
def _require_admin_client(self) -> Any:
|
||||
admin_client = self._admin_client
|
||||
if admin_client is None:
|
||||
raise RuntimeError("Supabase admin client is not initialized")
|
||||
return admin_client
|
||||
|
||||
|
||||
supabase_service: SupabaseService = register_service_instance(
|
||||
"supabase", SupabaseService()
|
||||
)
|
||||
|
||||
__all__ = ["SupabaseService", "supabase_service"]
|
||||
@@ -1,57 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.agent.infrastructure.queue.tasks import run_command_task
|
||||
from core.agent.infrastructure.queue.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
run_command_task_critical,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from services.base.redis import redis_service
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
DEDUP_WAIT_RETRIES = 20
|
||||
DEDUP_WAIT_SECONDS = 0.05
|
||||
DEDUP_LOCK_SECONDS = 300
|
||||
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
||||
|
||||
class CeleryQueueClient:
|
||||
|
||||
class TaskiqQueueClient:
|
||||
def __init__(self) -> None:
|
||||
self._redis = redis_service.get_client()
|
||||
self._redis: Redis | None = None
|
||||
|
||||
async def _get_redis(self) -> Redis:
|
||||
if self._redis is None:
|
||||
self._redis = await get_or_init_redis_client()
|
||||
return self._redis
|
||||
|
||||
@staticmethod
|
||||
def _select_queue_task(command: dict[str, object]) -> Any:
|
||||
queue = str(command.get("queue", "default")).strip().lower()
|
||||
if queue == "critical":
|
||||
return run_command_task_critical
|
||||
if queue == "bulk":
|
||||
return run_command_task_bulk
|
||||
return run_command_task
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
redis_client = await self._get_redis()
|
||||
redis_key = None
|
||||
if dedup_key:
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
|
||||
locked = await redis_client.set(
|
||||
redis_key,
|
||||
DEDUP_INFLIGHT_MARKER,
|
||||
nx=True,
|
||||
ex=DEDUP_LOCK_SECONDS,
|
||||
)
|
||||
if not locked:
|
||||
existing = await self._redis.get(redis_key)
|
||||
if existing and existing != "__inflight__":
|
||||
return existing
|
||||
for _ in range(DEDUP_WAIT_RETRIES):
|
||||
existing = await redis_client.get(redis_key)
|
||||
if existing and existing != DEDUP_INFLIGHT_MARKER:
|
||||
return existing
|
||||
await asyncio.sleep(DEDUP_WAIT_SECONDS)
|
||||
raise RuntimeError("duplicate request is still in progress")
|
||||
|
||||
payload = dict(command)
|
||||
if dedup_key:
|
||||
payload["dedup_key"] = dedup_key
|
||||
delay = getattr(run_command_task, "delay")
|
||||
result = delay(payload)
|
||||
task_id = str(result.id)
|
||||
if redis_key is not None:
|
||||
await self._redis.set(redis_key, task_id, ex=300)
|
||||
return task_id
|
||||
queue_task = self._select_queue_task(payload)
|
||||
try:
|
||||
result = await queue_task.kiq(payload)
|
||||
task_id = str(result.task_id)
|
||||
if redis_key is not None:
|
||||
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
|
||||
return task_id
|
||||
except Exception:
|
||||
if redis_key is not None:
|
||||
await redis_client.delete(redis_key)
|
||||
raise
|
||||
|
||||
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
settings = cast(Any, config)
|
||||
client = redis_service.get_client()
|
||||
self._store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
read_count=settings.agent_runtime.redis_stream_read_count,
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
self._store: RedisStreamEventStore | None = None
|
||||
|
||||
async def _get_store(self) -> RedisStreamEventStore:
|
||||
if self._store is None:
|
||||
client = await get_or_init_redis_client()
|
||||
self._store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
return self._store
|
||||
|
||||
async def read(
|
||||
self,
|
||||
@@ -59,7 +100,8 @@ class RedisEventStream:
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
rows = await self._store.read_events(
|
||||
store = await self._get_store()
|
||||
rows = await store.read_events(
|
||||
session_id=UUID(session_id),
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
@@ -69,6 +111,6 @@ class RedisEventStream:
|
||||
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
|
||||
return AgentService(
|
||||
repository=AgentRepository(session),
|
||||
queue=CeleryQueueClient(),
|
||||
queue=TaskiqQueueClient(),
|
||||
stream=RedisEventStream(),
|
||||
)
|
||||
|
||||
@@ -5,10 +5,10 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError, create_client
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
@@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway")
|
||||
|
||||
|
||||
class SupabaseAuthGateway(AuthServiceGateway):
|
||||
_client: Any
|
||||
_admin_client: Any
|
||||
def _get_client(self) -> Any:
|
||||
return supabase_service.get_client()
|
||||
|
||||
def __init__(self) -> None:
|
||||
settings: SupabaseSettings = config.supabase
|
||||
self._client = create_client(settings.url, settings.anon_key)
|
||||
self._admin_client = create_client(settings.url, settings.service_role_key)
|
||||
def _get_admin_client(self) -> Any:
|
||||
return supabase_service.get_admin_client()
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
client = self._get_client()
|
||||
metadata: dict[str, Any] = {"username": request.username}
|
||||
if request.invite_code:
|
||||
metadata["invite_code"] = request.invite_code
|
||||
@@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
if request.redirect_to:
|
||||
payload["options"] = {"email_redirect_to": request.redirect_to}
|
||||
try:
|
||||
sign_up = cast(Any, self._client.auth.sign_up)
|
||||
sign_up = cast(Any, client.auth.sign_up)
|
||||
await asyncio.to_thread(sign_up, payload)
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
except AuthError as exc:
|
||||
@@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": "signup",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, payload)
|
||||
return _map_auth_response(response, "Invalid verification code")
|
||||
except AuthError as exc:
|
||||
@@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
) from exc
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"type": "signup", "email": request.email}
|
||||
try:
|
||||
resend = cast(Any, self._client.auth.resend)
|
||||
resend = cast(Any, client.auth.resend)
|
||||
await asyncio.to_thread(resend, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup resend failed", error_type=type(exc).__name__)
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, self._client.auth.sign_in_with_password)
|
||||
sign_in = cast(Any, client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
@@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
client.auth.refresh_session,
|
||||
request.refresh_token,
|
||||
)
|
||||
return _map_auth_response(response, "Invalid refresh token")
|
||||
@@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=401, detail="Missing refresh token")
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
client.auth.refresh_session,
|
||||
refresh_token,
|
||||
)
|
||||
session = getattr(response, "session", None)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
await asyncio.to_thread(
|
||||
self._client.auth.set_session,
|
||||
client.auth.set_session,
|
||||
str(session.access_token),
|
||||
str(session.refresh_token),
|
||||
)
|
||||
await asyncio.to_thread(self._client.auth.sign_out)
|
||||
await asyncio.to_thread(client.auth.sign_out)
|
||||
except AuthError as exc:
|
||||
logger.warning("Logout failed", error_type=type(exc).__name__)
|
||||
raise HTTPException(
|
||||
@@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
) from exc
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
|
||||
admin_client = self._get_admin_client()
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
normalized_email = email.lower()
|
||||
user = next(
|
||||
(
|
||||
@@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
reset_email = cast(Any, self._client.auth.reset_password_email)
|
||||
reset_email = cast(Any, client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {"redirect_to": request.redirect_to}
|
||||
@@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
client = self._get_client()
|
||||
admin_client = self._get_admin_client()
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
@@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self._admin_client.auth.admin.update_user_by_id,
|
||||
admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user