chore: checkpoint current backend/runtime changes
This commit is contained in:
+6
-7
@@ -35,17 +35,16 @@ SOCIAL_REDIS__DB=0
|
||||
# default: 常规异步任务
|
||||
# bulk: 批处理/重计算/可延迟任务
|
||||
SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY=2
|
||||
SOCIAL_WORKER__GROUPS__CRITICAL__PREFETCH_MULTIPLIER=1
|
||||
SOCIAL_WORKER__GROUPS__CRITICAL__TIME_LIMIT=300
|
||||
|
||||
SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY=2
|
||||
SOCIAL_WORKER__GROUPS__DEFAULT__PREFETCH_MULTIPLIER=4
|
||||
SOCIAL_WORKER__GROUPS__DEFAULT__TIME_LIMIT=600
|
||||
|
||||
SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY=1
|
||||
SOCIAL_WORKER__GROUPS__BULK__PREFETCH_MULTIPLIER=1
|
||||
SOCIAL_WORKER__GROUPS__BULK__TIME_LIMIT=3600
|
||||
SOCIAL_WORKER__GROUPS__BULK__MAX_TASKS_PER_CHILD=100
|
||||
|
||||
############
|
||||
# Taskiq(可选,默认回落到 Redis URL)
|
||||
############
|
||||
# SOCIAL_TASKIQ__BROKER_URL=redis://:password@localhost:6379/0
|
||||
# SOCIAL_TASKIQ__RESULT_BACKEND_URL=redis://:password@localhost:6379/0
|
||||
|
||||
############
|
||||
# Supabase(本地 Docker 与阿里云自托管保持同一变量)
|
||||
|
||||
+13
-12
@@ -13,7 +13,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from core.config.settings import config
|
||||
from core.http.response import build_problem_details
|
||||
from core.logging import configure_logging, get_logger, log_service_banner
|
||||
from services.base.redis import redis_service
|
||||
from services.base import close_registered_services, initialize_registered_services
|
||||
from v1.router import router as mobile_router
|
||||
|
||||
|
||||
@@ -29,22 +29,23 @@ log_service_banner(
|
||||
)
|
||||
|
||||
logger = get_logger("api.app")
|
||||
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
initialized = await redis_service.initialize()
|
||||
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
|
||||
if not initialized:
|
||||
logger.error("Redis service failed to initialize, aborting startup")
|
||||
raise RuntimeError("Redis service initialization failed")
|
||||
logger.info(
|
||||
"Redis service initialized",
|
||||
host=config.redis.host,
|
||||
db=config.redis.db,
|
||||
)
|
||||
yield
|
||||
await redis_service.close()
|
||||
logger.info("Redis service closed")
|
||||
logger.error("Service initialization failed, aborting startup")
|
||||
raise RuntimeError("Service initialization failed")
|
||||
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
closed = await close_registered_services(services)
|
||||
if not closed:
|
||||
logger.warning("Failed to close all base services")
|
||||
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, cast
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.celery.app import celery_app
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import redis_service
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("core.agent.infrastructure.queue.tasks")
|
||||
|
||||
@@ -29,13 +29,12 @@ class ResumeServiceLike(Protocol):
|
||||
|
||||
|
||||
async def _build_redis_publisher() -> PublishEvent:
|
||||
settings = cast(Any, config)
|
||||
client = redis_service.get_client()
|
||||
client = await get_or_init_redis_client()
|
||||
event_store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
read_count=settings.agent_runtime.redis_stream_read_count,
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
@@ -70,22 +69,27 @@ async def run_agent_task(
|
||||
raise ValueError("session_id is required")
|
||||
UUID(session_id)
|
||||
|
||||
tool_call_id = ""
|
||||
user_input = ""
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
|
||||
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
|
||||
await publisher(start_event, {"session_id": session_id})
|
||||
|
||||
try:
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
result = await service_resume.resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
result = await service_run.run(
|
||||
session_id=session_id,
|
||||
user_input=user_input,
|
||||
@@ -125,6 +129,16 @@ async def run_agent_task(
|
||||
raise
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.agent.run_command")
|
||||
@default_broker.task(task_name="tasks.agent.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
|
||||
@critical_broker.task(task_name="tasks.agent.run_command.critical")
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
    """Run an agent command as a task registered on ``critical_broker``.

    The command is handed to ``run_agent_task`` untouched and that
    coroutine's result is returned as-is.
    """
    outcome = await run_agent_task(command)
    return outcome
|
||||
|
||||
|
||||
@bulk_broker.task(task_name="tasks.agent.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
    """Run an agent command as a task registered on ``bulk_broker``.

    The command is handed to ``run_agent_task`` untouched and that
    coroutine's result is returned as-is.
    """
    outcome = await run_agent_task(command)
    return outcome
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals as celery_signals
|
||||
from kombu import Queue
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from core.logging.celery import configure_celery_app
|
||||
from services.base.redis import redis_service
|
||||
|
||||
logger = get_logger("core.celery")
|
||||
|
||||
|
||||
def _init_redis_on_worker_startup(**_: object) -> None:
    """Initialize ``redis_service`` when a Celery worker process starts.

    Connected to ``worker_process_init`` further down in this module.
    Any keyword arguments supplied by the signal are ignored. Failures are
    logged and never propagated, so this handler itself never raises.
    """
    import asyncio

    logger.info("Initializing Redis service for Celery worker")
    try:
        # The handler is synchronous, so drive the async initializer on a
        # fresh event loop.
        succeeded = asyncio.run(redis_service.initialize())
        if not succeeded:
            logger.warning("Redis service initialization returned False")
        else:
            logger.info("Redis service initialized for Celery worker")
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to initialize Redis for Celery worker", error=str(exc))
|
||||
|
||||
|
||||
def create_celery_app() -> Celery:
    """Create and configure the Celery application.

    Builds the ``social_app`` Celery instance from the global ``config``,
    applies serializer/time-limit/queue settings, wires logging through
    ``configure_celery_app`` and returns the ready application.
    """
    celery_settings = config.celery

    application = Celery(
        "social_app",
        broker=config.celery_broker_url,
        backend=config.celery_result_backend,
        include=["core.agent.infrastructure.queue.tasks"],
    )

    # Values copied straight from the settings model.
    from_settings = {
        "task_serializer": celery_settings.task_serializer,
        "result_serializer": celery_settings.result_serializer,
        "accept_content": celery_settings.accept_content,
        "timezone": celery_settings.timezone,
        "enable_utc": celery_settings.enable_utc,
        "task_track_started": celery_settings.task_track_started,
        "task_time_limit": celery_settings.task_time_limit,
        "task_soft_time_limit": celery_settings.task_soft_time_limit,
        "task_default_retry_delay": celery_settings.task_default_retry_delay,
    }

    # Queue topology and delivery behaviour are fixed here, not configurable.
    fixed = {
        "task_default_queue": "default",
        "task_create_missing_queues": False,
        "task_queues": tuple(
            Queue(queue_name) for queue_name in ("default", "critical", "bulk")
        ),
        "task_routes": {
            "tasks.critical.*": {"queue": "critical"},
            "tasks.bulk.*": {"queue": "bulk"},
        },
        "task_acks_late": True,
        "task_reject_on_worker_lost": True,
        "worker_prefetch_multiplier": 1,
    }

    application.conf.update(**from_settings, **fixed)

    configure_celery_app(application, settings=config)

    return application
|
||||
|
||||
|
||||
celery_app = create_celery_app()
|
||||
|
||||
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
|
||||
@@ -63,19 +63,9 @@ class RuntimeSettings(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class CelerySettings(BaseModel):
|
||||
class TaskiqSettings(BaseModel):
|
||||
broker_url: str | None = None
|
||||
result_backend: str | None = None
|
||||
task_serializer: str = "json"
|
||||
result_serializer: str = "json"
|
||||
accept_content: list[str] = Field(default_factory=lambda: ["json"])
|
||||
timezone: str = "UTC"
|
||||
enable_utc: bool = True
|
||||
task_track_started: bool = True
|
||||
task_time_limit: int = 300
|
||||
task_soft_time_limit: int = 240
|
||||
task_default_retry_delay: int = 30
|
||||
task_max_retries: int = 3
|
||||
result_backend_url: str | None = None
|
||||
|
||||
|
||||
class CorsSettings(BaseModel):
|
||||
@@ -189,7 +179,7 @@ class Settings(BaseSettings):
|
||||
storage: StorageSettings = StorageSettings()
|
||||
llm: LlmSettings = LlmSettings()
|
||||
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
|
||||
celery: CelerySettings = CelerySettings()
|
||||
taskiq: TaskiqSettings = TaskiqSettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
@computed_field
|
||||
@@ -199,13 +189,13 @@ class Settings(BaseSettings):
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def celery_broker_url(self) -> str:
|
||||
return self.celery.broker_url or self.redis.url
|
||||
def taskiq_broker_url(self) -> str:
|
||||
return self.taskiq.broker_url or self.redis.url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def celery_result_backend(self) -> str:
|
||||
return self.celery.result_backend or self.redis.url
|
||||
def taskiq_result_backend_url(self) -> str:
|
||||
return self.taskiq.result_backend_url or self.redis.url
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_file=_resolve_env_file(),
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.logging import celery
|
||||
from core.logging.banner import log_service_banner
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context, get_context
|
||||
@@ -8,7 +7,6 @@ from core.logging.logger import get_logger
|
||||
|
||||
__all__ = [
|
||||
"bind_context",
|
||||
"celery",
|
||||
"clear_context",
|
||||
"configure_logging",
|
||||
"get_context",
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery, signals
|
||||
|
||||
from core.config.settings import Settings
|
||||
from core.logging.banner import log_service_banner
|
||||
from core.logging.config import configure_logging
|
||||
from core.logging.context import bind_context, clear_context
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CelerySignalHandlers:
    """Immutable bundle of the four callables wired to Celery signals.

    Each field is connected to the signal of the matching name by
    ``configure_celery_app``.
    """

    # Receiver for ``signals.setup_logging``.
    on_setup_logging: Callable[..., None]
    # Receiver for ``signals.after_setup_task_logger``.
    on_after_setup_task_logger: Callable[..., None]
    # Receiver for ``signals.task_prerun``.
    on_task_prerun: Callable[..., None]
    # Receiver for ``signals.task_postrun``.
    on_task_postrun: Callable[..., None]
|
||||
|
||||
|
||||
def build_celery_signal_handlers(
    settings: Settings | None = None,
) -> CelerySignalHandlers:
    """Build the logging-related Celery signal receivers.

    Args:
        settings: Settings to use; when falsy a fresh ``Settings()`` is
            created for the banner values.

    Returns:
        A ``CelerySignalHandlers`` holding the four receiver callables.
    """
    resolved = settings or Settings()

    def on_setup_logging(*_args: object, **_kwargs: object) -> None:
        # NOTE(review): ``configure_logging`` receives the caller's
        # ``settings`` (possibly None), while the banner uses the resolved
        # instance — confirm this asymmetry is intended.
        configure_logging(settings)
        runtime = resolved.runtime
        log_service_banner(
            service_name=runtime.service_name,
            environment=runtime.environment,
        )

    def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
        configure_logging(settings)

    def on_task_prerun(*_args: object, **kwargs: object) -> None:
        # Bind the task id and name into the logging context; cleared again
        # by ``on_task_postrun``.
        bind_context(
            task_id=cast(str | None, kwargs.get("task_id")),
            task_name=getattr(kwargs.get("task"), "name", None),
        )

    def on_task_postrun(*_args: object, **_kwargs: object) -> None:
        clear_context()

    return CelerySignalHandlers(
        on_setup_logging=on_setup_logging,
        on_after_setup_task_logger=on_after_setup_task_logger,
        on_task_prerun=on_task_prerun,
        on_task_postrun=on_task_postrun,
    )
|
||||
|
||||
|
||||
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
    """Attach this project's logging receivers to Celery's signals.

    Sets ``worker_hijack_root_logger`` to ``False`` on the app config, then
    connects the receivers built by ``build_celery_signal_handlers`` with
    ``weak=False``.

    Args:
        app: The Celery application whose config is adjusted.
        settings: Forwarded to ``build_celery_signal_handlers``.
    """
    app.conf.worker_hijack_root_logger = False

    receivers = build_celery_signal_handlers(settings)
    wiring = (
        (signals.setup_logging, receivers.on_setup_logging),
        (signals.after_setup_task_logger, receivers.on_after_setup_task_logger),
        (signals.task_prerun, receivers.on_task_prerun),
        (signals.task_postrun, receivers.on_task_postrun),
    )
    for signal, receiver in wiring:
        # The receivers are closures with no other owner, so they are
        # connected with strong references.
        signal.connect(receiver, weak=False)
|
||||
@@ -0,0 +1,3 @@
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import configure_logging
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
|
||||
def _build_broker(queue_name: str) -> ListQueueBroker:
    """Create a Redis list-queue broker for ``queue_name``.

    The broker URL and the result-backend URL are both read from the
    global ``config``; the returned broker has a
    ``RedisAsyncResultBackend`` attached.
    """
    queue_broker = ListQueueBroker(
        url=config.taskiq_broker_url,
        queue_name=queue_name,
    )
    result_backend = RedisAsyncResultBackend(
        redis_url=config.taskiq_result_backend_url,
    )
    return queue_broker.with_result_backend(result_backend)
|
||||
|
||||
|
||||
default_broker = _build_broker("default")
|
||||
critical_broker = _build_broker("critical")
|
||||
bulk_broker = _build_broker("bulk")
|
||||
|
||||
# Backward-compatible export name for existing imports/tests.
|
||||
broker = default_broker
|
||||
|
||||
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
|
||||
@@ -1,18 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, redis_service
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
resolve_registered_services,
|
||||
)
|
||||
from services.base.supabase import SupabaseService, supabase_service
|
||||
|
||||
__all__ = [
|
||||
"BaseServiceProvider",
|
||||
"RedisService",
|
||||
"ServiceRegistry",
|
||||
"SupabaseService",
|
||||
"close_registered_services",
|
||||
"get_or_init_redis_client",
|
||||
"initialize_registered_services",
|
||||
"redis_service",
|
||||
"register_service",
|
||||
"register_service_instance",
|
||||
"resolve_registered_services",
|
||||
"supabase_service",
|
||||
]
|
||||
|
||||
@@ -92,6 +92,14 @@ class RedisService(BaseServiceProvider):
|
||||
return self._require_client()
|
||||
|
||||
|
||||
async def get_or_init_redis_client() -> redis.Redis:
    """Return the shared Redis client, initializing the service on demand.

    Raises:
        RuntimeError: If ``redis_service.initialize()`` reports failure.
    """
    if redis_service.is_initialized:
        return redis_service.get_client()

    # First use in this process: bring the service up before handing out
    # the client.
    if not await redis_service.initialize():
        raise RuntimeError("Redis service initialization failed")
    return redis_service.get_client()
|
||||
|
||||
|
||||
redis_service: RedisService = register_service_instance("redis", RedisService())
|
||||
|
||||
__all__ = ["RedisService", "redis_service"]
|
||||
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
|
||||
|
||||
@@ -61,6 +61,12 @@ class ServiceRegistry:
|
||||
@classmethod
def create_service(
    cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
    """Alias for ``get_service``.

    Forwards ``service_name`` and all keyword arguments unchanged and
    returns whatever ``get_service`` returns.
    """
    provider = cls.get_service(service_name, **kwargs)
    return provider
|
||||
|
||||
@classmethod
|
||||
def get_service(
|
||||
cls, service_name: str, **kwargs: Any
|
||||
) -> Optional[BaseServiceProvider]:
|
||||
factory = cls.get_service_factory(service_name)
|
||||
if not factory:
|
||||
@@ -82,3 +88,71 @@ TService = TypeVar("TService", bound=BaseServiceProvider)
|
||||
def register_service_instance(service_name: str, service: TService) -> TService:
    """Register an already-built service under ``service_name``.

    The registry is given a zero-argument factory that always yields this
    same instance. The instance is returned so the call can be used in an
    assignment.
    """

    def _existing_instance() -> TService:
        return service

    ServiceRegistry.register(service_name, _existing_instance)
    return service
|
||||
|
||||
|
||||
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
    """Look up each named service in ``ServiceRegistry``, preserving order.

    Raises:
        RuntimeError: On the first name for which the registry returns
            ``None``.
    """

    def _lookup(name: str) -> BaseServiceProvider:
        provider = ServiceRegistry.get_service(name)
        if provider is None:
            raise RuntimeError(f"Service is not registered: {name}")
        return provider

    return [_lookup(name) for name in service_names]
|
||||
|
||||
|
||||
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
    """Close the given services in reverse order, best effort.

    Every service is attempted even if an earlier one fails. A failure is
    either an exception from ``close()`` or a falsy return value; both are
    logged as warnings.

    Returns:
        ``True`` only if every service closed without failure.
    """
    log = get_logger("services.base.lifecycle")
    failure_count = 0

    for provider in reversed(services):
        try:
            ok = await provider.close()
        except Exception as exc:  # noqa: BLE001
            log.warning(
                "Failed to close service",
                service=provider.service_name,
                error=str(exc),
            )
            failure_count += 1
            continue

        if ok:
            continue
        log.warning(
            "Service close returned false",
            service=provider.service_name,
        )
        failure_count += 1

    return failure_count == 0
|
||||
|
||||
|
||||
async def initialize_registered_services(
    service_names: list[str],
) -> tuple[bool, list[BaseServiceProvider]]:
    """Initialize the named services in order, rolling back on failure.

    Args:
        service_names: Registry names, in the order they should start.

    Returns:
        ``(True, services)`` with the initialized services in start order,
        or ``(False, [])`` if a name cannot be resolved or any service
        fails to initialize. On failure, the services that had already
        started are closed again via ``close_registered_services``.
    """
    log = get_logger("services.base.lifecycle")

    try:
        pending = resolve_registered_services(service_names)
    except RuntimeError as exc:
        log.error("Failed to resolve registered services", error=str(exc))
        return False, []

    started: list[BaseServiceProvider] = []
    for provider in pending:
        try:
            ok = await provider.initialize()
        except Exception as exc:  # noqa: BLE001
            # An exception is treated the same as a ``False`` result.
            log.warning(
                "Service initialization raised exception",
                service=provider.service_name,
                error=str(exc),
            )
            ok = False

        if ok:
            started.append(provider)
            continue

        log.error(
            "Service initialization failed, rolling back",
            service=provider.service_name,
        )
        # Only the services that fully started are closed; the failing
        # one is not passed to the rollback.
        await close_registered_services(started)
        return False, []

    return True, started
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from supabase import create_client
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
|
||||
from .service_interface import BaseServiceProvider, register_service_instance
|
||||
|
||||
|
||||
class SupabaseService(BaseServiceProvider):
    """Service provider that owns two Supabase clients.

    One client is created with the anon key and one with the service-role
    key; both are built in ``initialize`` and dropped in ``close``.
    """

    def __init__(self, settings: SupabaseSettings | None = None) -> None:
        """Store settings (defaulting to ``config.supabase``); no clients yet."""
        super().__init__("supabase")
        self._settings = settings or config.supabase
        self._client: Any = None
        self._admin_client: Any = None

    def _reset_state(self) -> None:
        """Forget both clients and mark the service as not initialized."""
        self._client = None
        self._admin_client = None
        self._set_initialized(False)

    async def initialize(self, **_: Any) -> bool:
        """Create both clients.

        Returns:
            ``True`` on success. On any exception the failure is logged,
            both clients are cleared and ``False`` is returned.
        """
        supabase_settings = self._settings
        try:
            self._client = create_client(
                supabase_settings.url,
                supabase_settings.anon_key,
            )
            self._admin_client = create_client(
                supabase_settings.url,
                supabase_settings.service_role_key,
            )
            self._set_initialized(True)
            self.logger.info("Supabase service initialized")
            return True
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Supabase service initialization failed", error=str(exc))
            self._reset_state()
            return False

    async def close(self) -> bool:
        """Drop both clients and mark the service uninitialized; always ``True``."""
        self._reset_state()
        self.logger.info("Supabase service closed")
        return True

    async def health_check(self) -> dict[str, Any]:
        """Probe both clients and report ``healthy`` or ``unhealthy``.

        The anon client is probed with ``auth.get_session`` and the admin
        client with ``auth.admin.list_users(page=1, per_page=1)``, each run
        via ``asyncio.to_thread``.
        """
        anon, admin = self._client, self._admin_client
        if anon is None or admin is None:
            return {"status": "unhealthy", "details": {"error": "not initialized"}}

        try:
            await asyncio.to_thread(anon.auth.get_session)
            await asyncio.to_thread(admin.auth.admin.list_users, page=1, per_page=1)
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Supabase health check failed", error=str(exc))
            return {"status": "unhealthy", "details": {"error": str(exc)}}

        return {
            "status": "healthy",
            "details": {
                "anon_client": "ready",
                "admin_client": "ready",
            },
        }

    def get_client(self) -> Any:
        """Return the anon-key client; raises ``RuntimeError`` if absent."""
        return self._require_client()

    def get_admin_client(self) -> Any:
        """Return the service-role client; raises ``RuntimeError`` if absent."""
        return self._require_admin_client()

    def _require_client(self) -> Any:
        """Return the anon-key client or raise if it has not been created."""
        if self._client is None:
            raise RuntimeError("Supabase client is not initialized")
        return self._client

    def _require_admin_client(self) -> Any:
        """Return the service-role client or raise if it has not been created."""
        if self._admin_client is None:
            raise RuntimeError("Supabase admin client is not initialized")
        return self._admin_client
|
||||
|
||||
|
||||
supabase_service: SupabaseService = register_service_instance(
|
||||
"supabase", SupabaseService()
|
||||
)
|
||||
|
||||
__all__ = ["SupabaseService", "supabase_service"]
|
||||
@@ -1,57 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.agent.infrastructure.queue.tasks import run_command_task
|
||||
from core.agent.infrastructure.queue.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
run_command_task_critical,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from services.base.redis import redis_service
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
DEDUP_WAIT_RETRIES = 20
|
||||
DEDUP_WAIT_SECONDS = 0.05
|
||||
DEDUP_LOCK_SECONDS = 300
|
||||
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
||||
|
||||
class CeleryQueueClient:
|
||||
|
||||
class TaskiqQueueClient:
|
||||
def __init__(self) -> None:
|
||||
self._redis = redis_service.get_client()
|
||||
self._redis: Redis | None = None
|
||||
|
||||
async def _get_redis(self) -> Redis:
    """Return the Redis client, fetching and caching it on first use."""
    cached = self._redis
    if cached is None:
        # Deferred until the first call so the constructor does not need an
        # initialized Redis service.
        cached = await get_or_init_redis_client()
        self._redis = cached
    return cached
|
||||
|
||||
@staticmethod
def _select_queue_task(command: dict[str, object]) -> Any:
    """Pick the task object matching the command's ``queue`` field.

    The value is stringified, stripped and lower-cased. ``"critical"`` and
    ``"bulk"`` select their dedicated tasks; anything else (including a
    missing key) selects ``run_command_task``.
    """
    dedicated = {
        "critical": run_command_task_critical,
        "bulk": run_command_task_bulk,
    }
    requested = str(command.get("queue", "default")).strip().lower()
    return dedicated.get(requested, run_command_task)
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
redis_client = await self._get_redis()
|
||||
redis_key = None
|
||||
if dedup_key:
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
|
||||
locked = await redis_client.set(
|
||||
redis_key,
|
||||
DEDUP_INFLIGHT_MARKER,
|
||||
nx=True,
|
||||
ex=DEDUP_LOCK_SECONDS,
|
||||
)
|
||||
if not locked:
|
||||
existing = await self._redis.get(redis_key)
|
||||
if existing and existing != "__inflight__":
|
||||
return existing
|
||||
for _ in range(DEDUP_WAIT_RETRIES):
|
||||
existing = await redis_client.get(redis_key)
|
||||
if existing and existing != DEDUP_INFLIGHT_MARKER:
|
||||
return existing
|
||||
await asyncio.sleep(DEDUP_WAIT_SECONDS)
|
||||
raise RuntimeError("duplicate request is still in progress")
|
||||
|
||||
payload = dict(command)
|
||||
if dedup_key:
|
||||
payload["dedup_key"] = dedup_key
|
||||
delay = getattr(run_command_task, "delay")
|
||||
result = delay(payload)
|
||||
task_id = str(result.id)
|
||||
if redis_key is not None:
|
||||
await self._redis.set(redis_key, task_id, ex=300)
|
||||
return task_id
|
||||
queue_task = self._select_queue_task(payload)
|
||||
try:
|
||||
result = await queue_task.kiq(payload)
|
||||
task_id = str(result.task_id)
|
||||
if redis_key is not None:
|
||||
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
|
||||
return task_id
|
||||
except Exception:
|
||||
if redis_key is not None:
|
||||
await redis_client.delete(redis_key)
|
||||
raise
|
||||
|
||||
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
settings = cast(Any, config)
|
||||
client = redis_service.get_client()
|
||||
self._store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
read_count=settings.agent_runtime.redis_stream_read_count,
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
self._store: RedisStreamEventStore | None = None
|
||||
|
||||
async def _get_store(self) -> RedisStreamEventStore:
    """Return the event store, building and caching it on first use.

    The store is built from the shared Redis client and the stream
    settings under ``config.agent_runtime``.
    """
    if self._store is not None:
        return self._store

    redis_client = await get_or_init_redis_client()
    runtime = config.agent_runtime
    self._store = RedisStreamEventStore(
        client=redis_client,
        stream_prefix=runtime.redis_stream_prefix,
        read_count=runtime.redis_stream_read_count,
        block_ms=runtime.redis_stream_block_ms,
    )
    return self._store
|
||||
|
||||
async def read(
|
||||
self,
|
||||
@@ -59,7 +100,8 @@ class RedisEventStream:
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
rows = await self._store.read_events(
|
||||
store = await self._get_store()
|
||||
rows = await store.read_events(
|
||||
session_id=UUID(session_id),
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
@@ -69,6 +111,6 @@ class RedisEventStream:
|
||||
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
|
||||
return AgentService(
|
||||
repository=AgentRepository(session),
|
||||
queue=CeleryQueueClient(),
|
||||
queue=TaskiqQueueClient(),
|
||||
stream=RedisEventStream(),
|
||||
)
|
||||
|
||||
@@ -5,10 +5,10 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from supabase import AuthError, create_client
|
||||
from supabase import AuthError
|
||||
|
||||
from core.config.settings import SupabaseSettings, config
|
||||
from core.logging import get_logger
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.auth.schemas import (
|
||||
AuthUser,
|
||||
PasswordResetConfirmRequest,
|
||||
@@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway")
|
||||
|
||||
|
||||
class SupabaseAuthGateway(AuthServiceGateway):
|
||||
_client: Any
|
||||
_admin_client: Any
|
||||
def _get_client(self) -> Any:
|
||||
return supabase_service.get_client()
|
||||
|
||||
def __init__(self) -> None:
|
||||
settings: SupabaseSettings = config.supabase
|
||||
self._client = create_client(settings.url, settings.anon_key)
|
||||
self._admin_client = create_client(settings.url, settings.service_role_key)
|
||||
def _get_admin_client(self) -> Any:
|
||||
return supabase_service.get_admin_client()
|
||||
|
||||
async def create_verification(
|
||||
self, request: VerificationCreateRequest
|
||||
) -> VerificationCreateResponse:
|
||||
client = self._get_client()
|
||||
metadata: dict[str, Any] = {"username": request.username}
|
||||
if request.invite_code:
|
||||
metadata["invite_code"] = request.invite_code
|
||||
@@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
if request.redirect_to:
|
||||
payload["options"] = {"email_redirect_to": request.redirect_to}
|
||||
try:
|
||||
sign_up = cast(Any, self._client.auth.sign_up)
|
||||
sign_up = cast(Any, client.auth.sign_up)
|
||||
await asyncio.to_thread(sign_up, payload)
|
||||
return VerificationCreateResponse(email=request.email)
|
||||
except AuthError as exc:
|
||||
@@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def verify_verification(
|
||||
self, request: VerificationVerifyRequest
|
||||
) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {
|
||||
"type": "signup",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, payload)
|
||||
return _map_auth_response(response, "Invalid verification code")
|
||||
except AuthError as exc:
|
||||
@@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
) from exc
|
||||
|
||||
async def resend_verification(self, request: VerificationResendRequest) -> None:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"type": "signup", "email": request.email}
|
||||
try:
|
||||
resend = cast(Any, self._client.auth.resend)
|
||||
resend = cast(Any, client.auth.resend)
|
||||
await asyncio.to_thread(resend, payload)
|
||||
except AuthError as exc:
|
||||
logger.warning("Signup resend failed", error_type=type(exc).__name__)
|
||||
|
||||
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
payload: dict[str, Any] = {"email": request.email, "password": request.password}
|
||||
try:
|
||||
sign_in = cast(Any, self._client.auth.sign_in_with_password)
|
||||
sign_in = cast(Any, client.auth.sign_in_with_password)
|
||||
response = await asyncio.to_thread(sign_in, payload)
|
||||
return _map_auth_response(response, "Invalid credentials")
|
||||
except AuthError as exc:
|
||||
@@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
|
||||
|
||||
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
client.auth.refresh_session,
|
||||
request.refresh_token,
|
||||
)
|
||||
return _map_auth_response(response, "Invalid refresh token")
|
||||
@@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def delete_session(self, refresh_token: str | None) -> None:
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=401, detail="Missing refresh token")
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self._client.auth.refresh_session,
|
||||
client.auth.refresh_session,
|
||||
refresh_token,
|
||||
)
|
||||
session = getattr(response, "session", None)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
await asyncio.to_thread(
|
||||
self._client.auth.set_session,
|
||||
client.auth.set_session,
|
||||
str(session.access_token),
|
||||
str(session.refresh_token),
|
||||
)
|
||||
await asyncio.to_thread(self._client.auth.sign_out)
|
||||
await asyncio.to_thread(client.auth.sign_out)
|
||||
except AuthError as exc:
|
||||
logger.warning("Logout failed", error_type=type(exc).__name__)
|
||||
raise HTTPException(
|
||||
@@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
) from exc
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
|
||||
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
|
||||
admin_client = self._get_admin_client()
|
||||
users = await asyncio.to_thread(_list_auth_users, admin_client)
|
||||
normalized_email = email.lower()
|
||||
user = next(
|
||||
(
|
||||
@@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
)
|
||||
|
||||
async def request_password_reset(self, request: PasswordResetRequest) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
reset_email = cast(Any, self._client.auth.reset_password_email)
|
||||
reset_email = cast(Any, client.auth.reset_password_email)
|
||||
email = _coerce_reset_email(request.email)
|
||||
if request.redirect_to:
|
||||
options: dict[str, str] = {"redirect_to": request.redirect_to}
|
||||
@@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
async def confirm_password_reset(
|
||||
self, request: PasswordResetConfirmRequest
|
||||
) -> None:
|
||||
client = self._get_client()
|
||||
admin_client = self._get_admin_client()
|
||||
verify_payload: dict[str, Any] = {
|
||||
"type": "recovery",
|
||||
"email": request.email,
|
||||
"token": request.token,
|
||||
}
|
||||
try:
|
||||
verify_otp = cast(Any, self._client.auth.verify_otp)
|
||||
verify_otp = cast(Any, client.auth.verify_otp)
|
||||
response = await asyncio.to_thread(verify_otp, verify_payload)
|
||||
session = getattr(response, "session", None)
|
||||
user = getattr(response, "user", None)
|
||||
@@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
|
||||
status_code=401, detail="Invalid or expired verification code"
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self._admin_client.auth.admin.update_user_by_id,
|
||||
admin_client.auth.admin.update_user_by_id,
|
||||
user_id,
|
||||
{"password": request.new_password},
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
@@ -67,3 +67,35 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
)
|
||||
|
||||
assert events == ["RUN_STARTED", "RUN_ERROR"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_command() -> None:
    """An unrecognised ``command`` value makes ``run_agent_task`` raise ``ValueError``."""
    bad_payload = {
        "command": "invalid",
        "session_id": "00000000-0000-0000-0000-000000000001",
    }

    with pytest.raises(ValueError, match="invalid command type"):
        await run_agent_task(bad_payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
    """A ``resume`` payload that omits ``tool_call_id`` is rejected with ``ValueError``."""
    resume_without_call_id = {
        "command": "resume",
        "session_id": "00000000-0000-0000-0000-000000000001",
    }

    with pytest.raises(ValueError, match="tool_call_id is required"):
        await run_agent_task(resume_without_call_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_redis_publisher_init_fail_raises_runtime_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``RuntimeError`` from the Redis client lookup propagates out of the builder."""
    from core.agent.infrastructure.queue import tasks as tasks_module

    failure_message = "Redis service initialization failed"

    async def _raise_on_lookup() -> object:
        raise RuntimeError(failure_message)

    # Replace the lookup the tasks module resolves by name.
    monkeypatch.setattr(tasks_module, "get_or_init_redis_client", _raise_on_lookup)

    with pytest.raises(RuntimeError, match=failure_message):
        await _build_redis_publisher()
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_taskiq_uses_redis_url_by_default() -> None:
    """Freshly built ``Settings`` expose Taskiq URLs that use the ``redis://`` scheme."""
    redis_scheme = "redis://"
    defaults = Settings()

    broker_url = defaults.taskiq_broker_url
    assert broker_url.startswith(redis_scheme)

    backend_url = defaults.taskiq_result_backend_url
    assert backend_url.startswith(redis_scheme)
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
|
||||
|
||||
|
||||
def test_taskiq_broker_is_configured() -> None:
    """All brokers exist and ``default_broker`` is the very same object as ``broker``."""
    assert broker is not None
    assert default_broker is broker
    for queue_broker in (critical_broker, bulk_broker):
        assert queue_broker is not None
|
||||
|
||||
|
||||
def test_taskiq_app_configures_logging_on_import(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Importing ``core.taskiq.app`` calls ``configure_logging`` exactly once with ``config``.

    The cached modules are evicted through ``monkeypatch.delitem`` rather than
    a bare ``sys.modules.pop``: monkeypatch puts the original entries back on
    teardown, so the module re-imported here (wired to the fake logger) does
    not leak into later tests or diverge from the broker objects this file
    imported at module level.
    """
    for module_name in ("core.taskiq.app", "core.taskiq"):
        # raising=False: tolerate the module not having been imported yet.
        monkeypatch.delitem(sys.modules, module_name, raising=False)

    recorded_calls: list[tuple[object, ...]] = []

    def _fake_configure_logging(*args: object, **__: object) -> None:
        # Only positional arguments are recorded; keyword arguments are ignored.
        recorded_calls.append(args)

    monkeypatch.setattr("core.logging.configure_logging", _fake_configure_logging)

    # Fresh import: module-level code runs again against the patched function.
    importlib.import_module("core.taskiq.app")

    from core.config.settings import config

    assert len(recorded_calls) == 1
    assert recorded_calls[0] == (config,)
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[4]
|
||||
APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
|
||||
|
||||
|
||||
def test_worker_commands_use_taskiq() -> None:
    """``app.sh`` drives workers through Taskiq and no longer mentions the removed runner."""
    script_text = APP_SCRIPT.read_text(encoding="utf-8")
    # NOTE(review): presumably split in two so this file never contains the
    # removed command verbatim — confirm before merging the literals.
    removed_runner = "uv run c" "elery"

    required_snippets = (
        "uv run taskiq worker",
        "core.taskiq.app:critical_broker",
        "core.taskiq.app:default_broker",
        "core.taskiq.app:bulk_broker",
        'pgrep -f "taskiq.*worker"',
        'pkill -f "taskiq.*worker"',
    )
    for snippet in required_snippets:
        assert snippet in script_text

    assert removed_runner not in script_text
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
from core.config.settings import RedisSettings
|
||||
from services.base.redis import RedisService
|
||||
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
|
||||
|
||||
|
||||
class _FakeRedisClient:
|
||||
@@ -96,3 +96,35 @@ def test_get_client_raises_before_init() -> None:
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
service.get_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_or_init_redis_client_initializes_when_needed(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the service reports uninitialized and init succeeds, its client is returned."""
    expected_client = _FakeRedisClient()

    async def _initialize_ok() -> bool:
        return True

    def _return_expected() -> _FakeRedisClient:
        return expected_client

    # ``is_initialized`` is replaced on the class with a property, so the
    # singleton always reports "not initialized".
    never_initialized = property(lambda _: False)
    monkeypatch.setattr(type(redis_service), "is_initialized", never_initialized)
    monkeypatch.setattr(redis_service, "initialize", _initialize_ok)
    monkeypatch.setattr(redis_service, "get_client", _return_expected)

    resolved = await get_or_init_redis_client()

    assert resolved is expected_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_or_init_redis_client_raises_when_init_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A falsy result from ``initialize`` surfaces as ``RuntimeError``."""

    async def _initialize_fails() -> bool:
        return False

    never_initialized = property(lambda _: False)
    monkeypatch.setattr(type(redis_service), "is_initialized", never_initialized)
    monkeypatch.setattr(redis_service, "initialize", _initialize_fails)

    expected_error = pytest.raises(
        RuntimeError, match="Redis service initialization failed"
    )
    with expected_error:
        await get_or_init_redis_client()
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from services.base.service_interface import (
|
||||
BaseServiceProvider,
|
||||
ServiceRegistry,
|
||||
close_registered_services,
|
||||
initialize_registered_services,
|
||||
register_service,
|
||||
register_service_instance,
|
||||
)
|
||||
@@ -35,6 +39,17 @@ def test_register_service_and_create_service() -> None:
|
||||
assert created.get_service_info()["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_and_get_service() -> None:
    """A class registered with ``@register_service`` resolves through ``get_service``."""
    registry_key = "dummy-service-get"

    @register_service(registry_key)
    class _RegisteredService(_DummyService):
        pass

    service = ServiceRegistry.get_service(registry_key)

    assert service is not None
    info = service.get_service_info()
    assert info["name"] == "dummy"
|
||||
|
||||
|
||||
def test_register_service_instance_returns_same_instance() -> None:
|
||||
instance = _DummyService("singleton")
|
||||
|
||||
@@ -47,3 +62,77 @@ def test_register_service_instance_returns_same_instance() -> None:
|
||||
|
||||
def test_create_service_returns_none_for_missing() -> None:
|
||||
assert ServiceRegistry.create_service("missing-service") is None
|
||||
|
||||
|
||||
def test_get_service_returns_none_for_missing() -> None:
    """Looking up a name that was never registered yields ``None``."""
    lookup = ServiceRegistry.get_service("missing-service")
    assert lookup is None
|
||||
|
||||
|
||||
class _LifecycleService(BaseServiceProvider):
    """Service double that logs ``init:<name>`` / ``close:<name>`` into a shared list.

    ``fail_on_init`` makes ``initialize`` report failure without marking the
    service as initialized.
    """

    def __init__(self, name: str, recorder: list[str], fail_on_init: bool = False) -> None:
        super().__init__(name)
        self._recorder = recorder
        self._fail_on_init = fail_on_init

    async def initialize(self, **_: object) -> bool:
        """Record the attempt, then succeed unless configured to fail."""
        self._recorder.append(f"init:{self.service_name}")
        succeeded = not self._fail_on_init
        if succeeded:
            self._set_initialized(True)
        return succeeded

    async def close(self) -> bool:
        """Record the close, clear the initialized flag, and report success."""
        self._recorder.append(f"close:{self.service_name}")
        self._set_initialized(False)
        return True

    async def health_check(self) -> dict[str, object]:
        """Always report a healthy status with empty details."""
        report: dict[str, object] = {"status": "healthy", "details": {}}
        return report
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_initialize_registered_services_success() -> None:
    """Services initialize in the requested order and come back in that same order."""
    recorder: list[str] = []
    first_key = "lifecycle-success-first"
    second_key = "lifecycle-success-second"

    first = register_service_instance(first_key, _LifecycleService("first", recorder))
    second = register_service_instance(second_key, _LifecycleService("second", recorder))

    outcome = await initialize_registered_services([first_key, second_key])
    initialized, services = outcome

    assert initialized is True
    assert services == [first, second]
    assert recorder == ["init:first", "init:second"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_initialize_registered_services_failure_rolls_back() -> None:
    """If a later service fails to initialize, the earlier one is closed again."""
    recorder: list[str] = []
    healthy_key = "lifecycle-fail-first"
    failing_key = "lifecycle-fail-second"

    register_service_instance(healthy_key, _LifecycleService("first", recorder))
    failing_service = _LifecycleService("second", recorder, fail_on_init=True)
    register_service_instance(failing_key, failing_service)

    outcome = await initialize_registered_services([healthy_key, failing_key])
    initialized, services = outcome

    assert initialized is False
    assert services == []
    # Both inits are attempted; only the one that succeeded is closed.
    assert recorder == ["init:first", "init:second", "close:first"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_registered_services_closes_in_reverse_order() -> None:
    """Services are closed last-to-first relative to the list handed in."""
    recorder: list[str] = []
    ordered_services = [
        _LifecycleService("first", recorder),
        _LifecycleService("second", recorder),
    ]

    closed = await close_registered_services(ordered_services)

    assert closed is True
    assert recorder == ["close:second", "close:first"]
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import SupabaseSettings
|
||||
from services.base.supabase import SupabaseService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
    """``initialize`` creates two clients: the first is the regular one, the second the admin one."""
    service = SupabaseService(settings=SupabaseSettings())
    anon_client, admin_client = MagicMock(), MagicMock()
    seen_credentials: list[tuple[str, str]] = []

    def _create_client_stub(url: str, key: str) -> object:
        # First call hands out the regular client; every later call the admin one.
        seen_credentials.append((url, key))
        if len(seen_credentials) == 1:
            return anon_client
        return admin_client

    monkeypatch.setattr("services.base.supabase.create_client", _create_client_stub)

    outcome = await service.initialize()

    assert outcome is True
    assert service.is_initialized is True
    assert service.get_client() is anon_client
    assert service.get_admin_client() is admin_client
    assert len(seen_credentials) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """If client creation raises, ``initialize`` returns ``False`` and no client is exposed."""
    service = SupabaseService(settings=SupabaseSettings())

    def _exploding_create_client(_: str, __: str) -> object:
        raise RuntimeError("boom")

    monkeypatch.setattr(
        "services.base.supabase.create_client", _exploding_create_client
    )

    outcome = await service.initialize()

    assert outcome is False
    assert service.is_initialized is False
    with pytest.raises(RuntimeError):
        service.get_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
    """After ``close`` the service is uninitialized and both client getters raise."""
    service = SupabaseService(settings=SupabaseSettings())

    def _mock_create_client(_: str, __: str) -> object:
        return MagicMock()

    monkeypatch.setattr("services.base.supabase.create_client", _mock_create_client)

    started = await service.initialize()
    assert started is True
    stopped = await service.close()
    assert stopped is True

    assert service.is_initialized is False
    with pytest.raises(RuntimeError):
        service.get_client()
    with pytest.raises(RuntimeError):
        service.get_admin_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check_uninitialized() -> None:
    """A service that was never initialized reports ``unhealthy``."""
    uninitialized_service = SupabaseService(settings=SupabaseSettings())

    report = await uninitialized_service.health_check()

    assert report["status"] == "unhealthy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check_initialized(monkeypatch: pytest.MonkeyPatch) -> None:
    """An initialized service reports ``healthy`` and lists one user via the admin client."""
    service = SupabaseService(settings=SupabaseSettings())

    anon_client = MagicMock()
    anon_client.auth.get_session = MagicMock(return_value=None)

    admin_list_users = MagicMock(return_value=SimpleNamespace(users=[]))
    admin_client = MagicMock()
    admin_client.auth.admin = SimpleNamespace(list_users=admin_list_users)

    # Clients are handed out front-to-back: regular first, then admin.
    pending_clients = [anon_client, admin_client]

    def _next_client(_: str, __: str) -> object:
        return pending_clients.pop(0)

    monkeypatch.setattr("services.base.supabase.create_client", _next_client)

    started = await service.initialize()
    assert started is True

    report = await service.health_check()

    assert report["status"] == "healthy"
    admin_list_users.assert_called_once_with(page=1, per_page=1)
|
||||
|
||||
|
||||
def test_get_client_raises_before_init() -> None:
    """Both client getters raise ``RuntimeError`` on a service that was never initialized."""
    fresh_service = SupabaseService(settings=SupabaseSettings())

    with pytest.raises(RuntimeError):
        fresh_service.get_client()

    with pytest.raises(RuntimeError):
        fresh_service.get_admin_client()
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import app as app_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_lifespan_uses_registered_services(monkeypatch: pytest.MonkeyPatch) -> None:
    """The lifespan initializes ``redis`` then ``supabase`` and closes exactly what was returned."""
    started_services = [object(), object()]
    observed: dict[str, object] = {}

    async def _record_initialize(service_names: list[str]) -> tuple[bool, list[object]]:
        observed["init_names"] = service_names
        return True, started_services

    async def _record_close(services: list[object]) -> bool:
        observed["close_services"] = services
        return True

    monkeypatch.setattr(app_module, "initialize_registered_services", _record_initialize)
    monkeypatch.setattr(app_module, "close_registered_services", _record_close)

    # Enter and immediately leave the lifespan without an exception.
    async with app_module.lifespan(app_module.app):
        pass

    assert observed["init_names"] == ["redis", "supabase"]
    assert observed["close_services"] == started_services
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_lifespan_raises_when_initialization_failed(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Entering the lifespan raises ``RuntimeError`` when initialization reports failure."""

    async def _initialize_fails(_: list[str]) -> tuple[bool, list[object]]:
        return False, []

    monkeypatch.setattr(app_module, "initialize_registered_services", _initialize_fails)

    lifespan_context = app_module.lifespan(app_module.app)
    expected_error = pytest.raises(RuntimeError, match="Service initialization failed")
    with expected_error:
        await lifespan_context.__aenter__()
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.logging import celery as celery_logging
|
||||
from core.logging.context import clear_context, get_context
|
||||
|
||||
|
||||
class DummyTask:
|
||||
name: str = "tasks.sample"
|
||||
|
||||
|
||||
def test_celery_prerun_binds_task_context() -> None:
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
|
||||
context = get_context()
|
||||
|
||||
assert context["task_id"] == "task-123"
|
||||
assert context["task_name"] == "tasks.sample"
|
||||
|
||||
clear_context()
|
||||
|
||||
|
||||
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
|
||||
called = {"value": False}
|
||||
|
||||
def fake_configure_logging(settings: object | None = None) -> None:
|
||||
called["value"] = True
|
||||
|
||||
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
|
||||
handlers = celery_logging.build_celery_signal_handlers()
|
||||
|
||||
handlers.on_setup_logging()
|
||||
|
||||
assert called["value"] is True
|
||||
|
||||
|
||||
def test_configure_celery_app_disables_hijack() -> None:
|
||||
app = Celery("test")
|
||||
|
||||
celery_logging.configure_celery_app(app)
|
||||
|
||||
assert app.conf.worker_hijack_root_logger is False
|
||||
@@ -0,0 +1,227 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.dependencies import TaskiqQueueClient
|
||||
|
||||
|
||||
class _FakeRedis:
    """In-memory stand-in for the async Redis client used by the queue tests.

    Values live in ``store``; every key passed to ``delete`` is appended to
    ``delete_calls`` so tests can assert on cleanup.
    """

    def __init__(self) -> None:
        self.store: dict[str, str] = {}
        self.delete_calls: list[str] = []

    async def set(
        self,
        key: str,
        value: str,
        *,
        nx: bool = False,
        ex: int | None = None,
    ) -> bool:
        """Store ``value``; with ``nx`` an existing key is left untouched and ``False`` is returned."""
        del ex  # Expiry is accepted but not simulated.
        blocked = nx and key in self.store
        if not blocked:
            self.store[key] = value
        return not blocked

    async def get(self, key: str) -> str | None:
        """Return the stored value, or ``None`` when the key is absent."""
        try:
            return self.store[key]
        except KeyError:
            return None

    async def delete(self, key: str) -> int:
        """Record the call, remove the key, and return 1 if it existed, else 0."""
        self.delete_calls.append(key)
        if key not in self.store:
            return 0
        del self.store[key]
        return 1
|
||||
|
||||
|
||||
class _FakeAsyncResult:
    """Minimal stand-in for the object the patched ``kiq`` returns; only carries ``task_id``."""

    task_id: str

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """Enqueueing a ``run`` command resolves the Redis client and returns the task id from ``kiq``."""
    from v1.agent import dependencies as deps

    redis_double = _FakeRedis()
    redis_resolved = False

    async def _kiq_stub(payload: dict[str, object]) -> _FakeAsyncResult:
        assert payload["command"] == "run"
        return _FakeAsyncResult("task-123")

    async def _resolve_redis() -> _FakeRedis:
        nonlocal redis_resolved
        redis_resolved = True
        return redis_double

    monkeypatch.setattr(deps, "get_or_init_redis_client", _resolve_redis)
    monkeypatch.setattr(deps.run_command_task, "kiq", _kiq_stub)

    queue_client = TaskiqQueueClient()
    returned_id = await queue_client.enqueue(command={"command": "run"}, dedup_key=None)

    assert redis_resolved is True
    assert returned_id == "task-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_resume_dedup_returns_existing_task_id(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If the dedup key already maps to a task id, that id is returned rather than a new one."""
    from v1.agent import dependencies as deps

    redis_double = _FakeRedis()
    redis_resolved = False

    async def _kiq_stub(payload: dict[str, object]) -> _FakeAsyncResult:
        del payload
        return _FakeAsyncResult("new-task-id")

    async def _resolve_redis() -> _FakeRedis:
        nonlocal redis_resolved
        redis_resolved = True
        return redis_double

    monkeypatch.setattr(deps, "get_or_init_redis_client", _resolve_redis)
    monkeypatch.setattr(deps.run_command_task, "kiq", _kiq_stub)

    # Seed the store as if an earlier enqueue had already recorded its task id.
    dedup_key = "resume:session-1:call-1"
    redis_double.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"

    resume_command = {
        "command": "resume",
        "session_id": "session-1",
        "tool_call_id": "call-1",
    }
    queue_client = TaskiqQueueClient()
    returned_id = await queue_client.enqueue(command=resume_command, dedup_key=dedup_key)

    assert redis_resolved is True
    assert returned_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """While the dedup slot holds the in-flight marker, ``enqueue`` polls until a task id appears."""
    from v1.agent import dependencies as deps

    redis_double = _FakeRedis()
    dedup_key = "resume:session-1:call-1"
    redis_key = f"agent:dedup:{dedup_key}"
    redis_double.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
    poll_count = 0

    async def _resolve_redis() -> _FakeRedis:
        return redis_double

    async def _get_with_late_task_id(key: str) -> str | None:
        # The first read sees the in-flight marker; from the second read on
        # the concurrent enqueue has "finished" and stored its task id.
        nonlocal poll_count
        poll_count += 1
        if poll_count > 1:
            redis_double.store[key] = "existing-task-id"
        return redis_double.store.get(key)

    async def _instant_sleep(_: float) -> None:
        return None

    async def _kiq_must_not_run(payload: dict[str, object]) -> _FakeAsyncResult:
        del payload
        raise AssertionError("should not enqueue when dedup task id appears")

    monkeypatch.setattr(deps, "get_or_init_redis_client", _resolve_redis)
    monkeypatch.setattr(redis_double, "get", _get_with_late_task_id)
    monkeypatch.setattr(deps.asyncio, "sleep", _instant_sleep)
    monkeypatch.setattr(deps.run_command_task, "kiq", _kiq_must_not_run)

    resume_command = {
        "command": "resume",
        "session_id": "session-1",
        "tool_call_id": "call-1",
    }
    queue_client = TaskiqQueueClient()
    returned_id = await queue_client.enqueue(command=resume_command, dedup_key=dedup_key)

    assert returned_id == "existing-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
    """When ``kiq`` raises, the error propagates and the dedup key is deleted from Redis."""
    from v1.agent import dependencies as deps

    redis_double = _FakeRedis()
    dedup_key = "resume:session-1:call-1"
    redis_key = f"agent:dedup:{dedup_key}"

    async def _resolve_redis() -> _FakeRedis:
        return redis_double

    async def _kiq_fails(payload: dict[str, object]) -> _FakeAsyncResult:
        del payload
        raise RuntimeError("enqueue failed")

    monkeypatch.setattr(deps, "get_or_init_redis_client", _resolve_redis)
    monkeypatch.setattr(deps.run_command_task, "kiq", _kiq_fails)

    resume_command = {
        "command": "resume",
        "session_id": "session-1",
        "tool_call_id": "call-1",
    }
    queue_client = TaskiqQueueClient()
    with pytest.raises(RuntimeError, match="enqueue failed"):
        await queue_client.enqueue(command=resume_command, dedup_key=dedup_key)

    assert redis_key in redis_double.delete_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_uses_critical_queue_when_requested(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A command carrying ``queue="critical"`` is dispatched through the critical task."""
    from v1.agent import dependencies as deps

    redis_double = _FakeRedis()

    async def _resolve_redis() -> _FakeRedis:
        return redis_double

    async def _default_kiq_forbidden(_: dict[str, object]) -> _FakeAsyncResult:
        raise AssertionError("default queue should not be selected")

    async def _critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
        return _FakeAsyncResult("critical-task-id")

    monkeypatch.setattr(deps, "get_or_init_redis_client", _resolve_redis)
    monkeypatch.setattr(deps.run_command_task, "kiq", _default_kiq_forbidden)
    monkeypatch.setattr(deps.run_command_task_critical, "kiq", _critical_kiq)

    critical_command = {"command": "run", "queue": "critical"}
    queue_client = TaskiqQueueClient()
    returned_id = await queue_client.enqueue(command=critical_command, dedup_key=None)

    assert returned_id == "critical-task-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_uses_bulk_queue_when_requested(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A command carrying ``queue="bulk"`` is dispatched through the bulk task."""
    from v1.agent import dependencies as deps

    redis_double = _FakeRedis()

    async def _resolve_redis() -> _FakeRedis:
        return redis_double

    async def _default_kiq_forbidden(_: dict[str, object]) -> _FakeAsyncResult:
        raise AssertionError("default queue should not be selected")

    async def _bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
        return _FakeAsyncResult("bulk-task-id")

    monkeypatch.setattr(deps, "get_or_init_redis_client", _resolve_redis)
    monkeypatch.setattr(deps.run_command_task, "kiq", _default_kiq_forbidden)
    monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _bulk_kiq)

    bulk_command = {"command": "run", "queue": "bulk"}
    queue_client = TaskiqQueueClient()
    returned_id = await queue_client.enqueue(command=bulk_command, dedup_key=None)

    assert returned_id == "bulk-task-id"
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
@@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
|
||||
|
||||
class TestSupabaseAuthGateway:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> SupabaseAuthGateway:
|
||||
with patch("v1.auth.gateway.create_client") as mock_create:
|
||||
mock_client = MagicMock()
|
||||
mock_admin_client = MagicMock()
|
||||
mock_create.side_effect = [mock_client, mock_admin_client]
|
||||
return SupabaseAuthGateway()
|
||||
def gateway(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]:
|
||||
mock_client = MagicMock()
|
||||
mock_admin_client = MagicMock()
|
||||
monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client)
|
||||
monkeypatch.setattr(
|
||||
"v1.auth.gateway.supabase_service.get_admin_client",
|
||||
lambda: mock_admin_client,
|
||||
)
|
||||
return SupabaseAuthGateway(), mock_client, mock_admin_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_calls_email_with_string(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_with_redirect(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(
|
||||
email="test@example.com",
|
||||
redirect_to="http://localhost:3000/reset-password",
|
||||
)
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with(
|
||||
"test@example.com",
|
||||
@@ -51,64 +58,68 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_swallows_auth_error(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
from supabase import AuthError
|
||||
|
||||
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest(email="test@example.com")
|
||||
|
||||
result = await gateway.request_password_reset(request)
|
||||
result = await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_extracts_email_from_mapping(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
mock_reset_email = MagicMock()
|
||||
gateway._client.auth.reset_password_email = mock_reset_email
|
||||
mock_client.auth.reset_password_email = mock_reset_email
|
||||
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"email": "test@example.com"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
mock_reset_email.assert_called_once_with("test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_password_reset_rejects_invalid_email_shape(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, _, _ = gateway
|
||||
request = PasswordResetRequest.model_construct(
|
||||
email={"unexpected": "value"},
|
||||
redirect_to=None,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.request_password_reset(request)
|
||||
await sut.request_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Invalid email"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_updates_password_by_user_id(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, mock_admin_client = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id="user-1"),
|
||||
)
|
||||
mock_verify_otp = MagicMock(return_value=verify_response)
|
||||
gateway._client.auth.verify_otp = mock_verify_otp
|
||||
mock_client.auth.verify_otp = mock_verify_otp
|
||||
|
||||
mock_update_user_by_id = MagicMock()
|
||||
gateway._admin_client.auth.admin = SimpleNamespace(
|
||||
mock_admin_client.auth.admin = SimpleNamespace(
|
||||
update_user_by_id=mock_update_user_by_id
|
||||
)
|
||||
|
||||
@@ -118,7 +129,7 @@ class TestSupabaseAuthGateway:
|
||||
new_password="newpassword123",
|
||||
)
|
||||
|
||||
await gateway.confirm_password_reset(request)
|
||||
await sut.confirm_password_reset(request)
|
||||
|
||||
mock_verify_otp.assert_called_once_with(
|
||||
{
|
||||
@@ -134,13 +145,14 @@ class TestSupabaseAuthGateway:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confirm_password_reset_raises_when_user_id_missing(
|
||||
self, gateway: SupabaseAuthGateway
|
||||
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
|
||||
) -> None:
|
||||
sut, mock_client, _ = gateway
|
||||
verify_response = SimpleNamespace(
|
||||
session=SimpleNamespace(access_token="access"),
|
||||
user=SimpleNamespace(id=""),
|
||||
)
|
||||
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
|
||||
|
||||
request = PasswordResetConfirmRequest(
|
||||
email="test@example.com",
|
||||
@@ -149,7 +161,7 @@ class TestSupabaseAuthGateway:
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gateway.confirm_password_reset(request)
|
||||
await sut.confirm_password_reset(request)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid or expired verification code"
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
# Supabase 统一服务生命周期设计(优化版)
|
||||
|
||||
**Date:** 2026-03-06
|
||||
**Status:** Draft
|
||||
|
||||
---
|
||||
|
||||
## 0. Intake Contract
|
||||
|
||||
- Objective: 将 Supabase 客户端纳入统一服务生命周期管理,避免每次请求重复创建客户端。
|
||||
- Deliverable: 新增 `SupabaseService`,并基于 `service_interface.py` 的 `ServiceRegistry` 提供统一初始化/关闭路径,完成 auth 侧迁移。
|
||||
- Constraints:
|
||||
- 保持现有 `core.config.settings` 的配置读取行为不变。
|
||||
- 不引入 `os.environ` 直接读取。
|
||||
- 不改变现有 API 语义。
|
||||
- Verification target:
|
||||
- 通过单元测试证明 Supabase 服务初始化、关闭、健康检查行为。
|
||||
- 通过应用启动测试证明统一初始化流程可用。
|
||||
- 通过 auth 相关测试证明迁移后业务行为一致。
|
||||
|
||||
---
|
||||
|
||||
## 1. 复杂度与风险分级
|
||||
|
||||
- Complexity: `S2`
|
||||
- 原因:涉及多文件改造(`services/base`、`app.py`、`v1/auth`、测试)。
|
||||
- Risk Tier: `L1`
|
||||
- 原因:涉及应用启动链路和认证网关依赖,但不改变对外接口契约。
|
||||
|
||||
L1 Gate 要求:执行 `refactor-cleaner` 审视冗余与结构风险(`code-reviewer` 可选)。
|
||||
|
||||
---
|
||||
|
||||
## 2. 现状与问题
|
||||
|
||||
### 2.1 当前现状
|
||||
|
||||
- `SupabaseAuthGateway` 在 `__init__` 内直接 `create_client(...)`,每次实例化都会创建 anon/admin 客户端。
|
||||
- `get_auth_service()` 当前每次请求都会 new `SupabaseAuthGateway()`,导致客户端重复构造。
|
||||
- `ServiceRegistry` 已存在,但目前主要用于注册,应用启动仍是手写逐个初始化。
|
||||
|
||||
### 2.2 核心问题
|
||||
|
||||
1. 生命周期不统一:Supabase 没有接入应用启动/关闭的统一管理。
|
||||
2. 初始化代码重复趋势:服务增多后,`app.py` 的 lifespan 会继续膨胀。
|
||||
3. 网关构造时机风险:若在应用未初始化阶段取客户端,可能抛运行时异常。
|
||||
|
||||
---
|
||||
|
||||
## 3. 优化设计(推荐方案)
|
||||
|
||||
### 3.1 方案摘要
|
||||
|
||||
在 `service_interface.py` 基础上新增统一生命周期函数,按服务名列表批量初始化/关闭;`app.py` 仅声明服务顺序,减少样板代码。Supabase 使用 `config.supabase` 作为默认配置来源,保持 settings 行为一致。
|
||||
|
||||
### 3.2 目标文件结构
|
||||
|
||||
```text
|
||||
backend/src/services/base/
|
||||
├── __init__.py
|
||||
├── service_interface.py # 扩展:统一生命周期函数
|
||||
├── redis.py
|
||||
└── supabase.py # 新增
|
||||
```
|
||||
|
||||
### 3.3 service_interface 统一初始化能力(新增)
|
||||
|
||||
在 `service_interface.py` 新增以下函数(建议命名):
|
||||
|
||||
- `resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]`
|
||||
- `initialize_registered_services(service_names: list[str]) -> tuple[bool, list[BaseServiceProvider]]`
|
||||
- `close_registered_services(services: list[BaseServiceProvider]) -> bool`
|
||||
|
||||
约束与行为:
|
||||
|
||||
1. 初始化按 `service_names` 顺序执行。
|
||||
2. 任一服务初始化失败时:
|
||||
- 返回 `False`。
|
||||
- 对已成功初始化的服务按逆序执行关闭回滚。
|
||||
3. 关闭按逆序执行,最大化依赖安全性。
|
||||
4. 日志必须包含失败服务名和错误摘要。
|
||||
|
||||
这样 `app.py` 只需声明:
|
||||
|
||||
```python
|
||||
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
|
||||
```
|
||||
|
||||
并调用统一函数,减少重复初始化样板。
|
||||
|
||||
### 3.4 SupabaseService 设计
|
||||
|
||||
`supabase.py` 关键点:
|
||||
|
||||
- 继承 `BaseServiceProvider`。
|
||||
- 构造函数签名:
|
||||
- `def __init__(self, settings: SupabaseSettings | None = None) -> None`
|
||||
- 默认 `settings or config.supabase`,确保与当前配置源一致。
|
||||
- `initialize()`:创建 anon/admin 两个 client,失败返回 `False`。
|
||||
- `close()`:
|
||||
- 清空 `_client`、`_admin_client`。
|
||||
- `self._set_initialized(False)`。
|
||||
- `health_check()`:
|
||||
- 必须进行至少一个轻量真实请求验证,不仅检查本地对象存在。
|
||||
- 返回结构与 `RedisService.health_check()` 风格一致(`status + details`)。
|
||||
|
||||
注册方式:
|
||||
|
||||
```python
|
||||
supabase_service: SupabaseService = register_service_instance(
|
||||
"supabase", SupabaseService()
|
||||
)
|
||||
```
|
||||
|
||||
### 3.5 app.py 改造
|
||||
|
||||
当前手写 `redis_service.initialize()` 改为调用统一初始化函数。
|
||||
|
||||
目标行为:
|
||||
|
||||
1. 启动阶段:
|
||||
- 调用 `initialize_registered_services(["redis", "supabase"])`。
|
||||
- 失败则 `raise RuntimeError("Service initialization failed")`。
|
||||
2. 关闭阶段:
|
||||
- 调用 `close_registered_services(initialized_services)`。
|
||||
|
||||
### 3.6 AuthGateway 迁移策略(避免构造时机问题)
|
||||
|
||||
不建议在 `SupabaseAuthGateway.__init__` 里立即绑定 client;改为按需获取:
|
||||
|
||||
- 保留网关对象轻量化。
|
||||
- 在每个业务方法内部通过 `supabase_service.get_client()` / `get_admin_client()` 取实例。
|
||||
|
||||
优点:
|
||||
|
||||
1. 避免模块导入或依赖构建阶段误触未初始化 client。
|
||||
2. 对 `users/dependencies.py` 中全局缓存 gateway 的场景更安全。
|
||||
3. 不改变业务层接口。
|
||||
|
||||
---
|
||||
|
||||
## 4. 配置与兼容性保证
|
||||
|
||||
### 4.1 settings/config 行为不变
|
||||
|
||||
迁移后依然通过 `core.config.settings.config.supabase` 读取:
|
||||
|
||||
- `url`
|
||||
- `anon_key`
|
||||
- `service_role_key`
|
||||
- `jwt_secret`(JWT 校验现有逻辑继续使用)
|
||||
|
||||
### 4.2 环境变量兼容
|
||||
|
||||
由于 `Settings` + `env_nested_delimiter` 机制不变,现有环境变量命名与 `.env` 内容无需修改。
|
||||
|
||||
### 4.3 对现有代码影响
|
||||
|
||||
- API 层 schema/路由不变。
|
||||
- 认证行为不变。
|
||||
- 仅优化客户端生命周期与启动流程。
|
||||
|
||||
---
|
||||
|
||||
## 5. 实施计划(可执行)
|
||||
|
||||
### Task 1: 扩展统一生命周期接口
|
||||
|
||||
**Files**
|
||||
- Modify: `backend/src/services/base/service_interface.py`
|
||||
- Test: `backend/tests/unit/services/base/test_service_interface.py`(新增)
|
||||
|
||||
**Steps**
|
||||
1. 写失败测试:初始化顺序、失败回滚、关闭逆序。
|
||||
2. 实现生命周期函数。
|
||||
3. 跑单测确认通过。
|
||||
|
||||
### Task 2: 新增 SupabaseService
|
||||
|
||||
**Files**
|
||||
- Create: `backend/src/services/base/supabase.py`
|
||||
- Modify: `backend/src/services/base/__init__.py`
|
||||
- Test: `backend/tests/unit/services/base/test_supabase.py`
|
||||
|
||||
**Steps**
|
||||
1. 写失败测试(init success/fail、close、health_check)。
|
||||
2. 实现 `SupabaseService` 与实例注册。
|
||||
3. 跑单测。
|
||||
|
||||
### Task 3: 接入 app lifespan 统一初始化
|
||||
|
||||
**Files**
|
||||
- Modify: `backend/src/app.py`
|
||||
- Test: `backend/tests/integration/test_app_lifespan.py`(新增或扩展)
|
||||
|
||||
**Steps**
|
||||
1. 写失败测试(supabase init fail 时应用启动失败)。
|
||||
2. 替换手写初始化为统一函数。
|
||||
3. 跑集成测试。
|
||||
|
||||
### Task 4: 迁移 AuthGateway 获取 client 方式
|
||||
|
||||
**Files**
|
||||
- Modify: `backend/src/v1/auth/gateway.py`
|
||||
- Optional Modify: `backend/src/v1/auth/dependencies.py`
|
||||
- Optional Modify: `backend/src/v1/users/dependencies.py`
|
||||
- Test: `backend/tests/unit/v1/auth/test_gateway.py`(扩展)
|
||||
|
||||
**Steps**
|
||||
1. 写失败测试(未初始化时错误、初始化后正常调用)。
|
||||
2. 改为方法内按需取 client。
|
||||
3. 跑 auth 相关单测。
|
||||
|
||||
### Task 5: 全量验证与门禁
|
||||
|
||||
**Commands**
|
||||
- `uv run ruff check backend/src backend/tests`
|
||||
- `uv run basedpyright`
|
||||
- `uv run pytest backend/tests/unit/services/base -q`
|
||||
- `uv run pytest backend/tests/unit/v1/auth -q`
|
||||
- `uv run pytest backend/tests/integration -q`
|
||||
|
||||
输出要求:记录每条命令 pass/fail 与关键摘要。
|
||||
|
||||
---
|
||||
|
||||
## 6. 验收标准(更新)
|
||||
|
||||
- [ ] `SupabaseService` 继承 `BaseServiceProvider` 并注册到 `ServiceRegistry`
|
||||
- [ ] `service_interface.py` 提供统一初始化/关闭函数
|
||||
- [ ] `app.py` 通过统一函数初始化 `redis + supabase`
|
||||
- [ ] Supabase 配置读取仍仅来自 `core.config.settings.config`
|
||||
- [ ] `auth/gateway.py` 不再在 `__init__` 新建客户端
|
||||
- [ ] 初始化失败具备回滚关闭逻辑
|
||||
- [ ] 单元/集成测试覆盖核心迁移路径并通过
|
||||
|
||||
---
|
||||
|
||||
## 7. 风险与缓解
|
||||
|
||||
| 风险 | 级别 | 缓解 |
|
||||
|---|---|---|
|
||||
| 统一初始化函数引入顺序错误 | 中 | 显式 `SERVICE_STARTUP_ORDER` + 顺序测试 |
|
||||
| Supabase 健康检查误报 | 中 | 使用真实轻量请求,不只做对象检查 |
|
||||
| gateway 与生命周期耦合导致运行时错误 | 中 | 改为方法内按需取 client,并覆盖未初始化测试 |
|
||||
| 迁移影响现有 auth 行为 | 中 | 保持 service 接口不变,补充回归测试 |
|
||||
|
||||
---
|
||||
|
||||
## 8. 完成定义(Completion Contract)
|
||||
|
||||
1. Complexity: `S2`
|
||||
2. Risk Tier: `L1`
|
||||
3. Gates:
|
||||
- 必需:`refactor-cleaner`
|
||||
- 可选:`code-reviewer`(建议在合并前执行)
|
||||
4. Verification evidence:
|
||||
- 提供 lint/typecheck/unit/integration 命令结果
|
||||
5. Remaining risks/follow-ups:
|
||||
- 若后续新增第三方服务,沿用 `ServiceRegistry + 统一生命周期函数` 接入,不再在 `app.py` 手写初始化。
|
||||
@@ -0,0 +1,359 @@
|
||||
# Celery To Taskiq One-Shot Migration Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 在当前早期项目中一次性移除 Celery,并以 Taskiq 替换异步任务基础设施,保持 agent runtime 行为不变。
|
||||
|
||||
**Architecture:** 复用现有 `AgentService -> QueueClientLike` 抽象,仅替换基础设施层实现(任务声明、入队调用、worker 启动、配置与依赖)。保持 Redis 作为 broker/result 存储与事件流通道,避免改动业务服务层语义。
|
||||
|
||||
**Tech Stack:** FastAPI, Taskiq, taskiq-redis, Redis, pytest, uv
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 依赖与配置切换(先 RED 后 GREEN)
|
||||
|
||||
**Files:**
|
||||
- Modify: `pyproject.toml`
|
||||
- Modify: `backend/src/core/config/settings.py`
|
||||
- Test: `backend/tests/unit/core/config/test_taskiq_settings.py` (new)
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_taskiq_uses_redis_url_by_default() -> None:
|
||||
settings = Settings()
|
||||
assert settings.taskiq_broker_url.startswith("redis://")
|
||||
|
||||
|
||||
def test_taskiq_queue_default_value() -> None:
|
||||
settings = Settings()
|
||||
assert settings.taskiq.default_queue == "default"
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/config/test_taskiq_settings.py -v`
|
||||
Expected: FAIL(`taskiq_broker_url` / `taskiq` 字段不存在)
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
class TaskiqSettings(BaseModel):
|
||||
broker_url: str | None = None
|
||||
result_backend_url: str | None = None
|
||||
default_queue: str = "default"
|
||||
|
||||
class Settings(BaseSettings):
|
||||
taskiq: TaskiqSettings = TaskiqSettings()
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def taskiq_broker_url(self) -> str:
|
||||
return self.taskiq.broker_url or self.redis.url
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def taskiq_result_backend_url(self) -> str:
|
||||
return self.taskiq.result_backend_url or self.redis.url
|
||||
```
|
||||
|
||||
`pyproject.toml` 同步变更:
|
||||
- 删除 `celery>=...`
|
||||
- 增加 `taskiq>=...`
|
||||
- 增加 `taskiq-redis>=...`
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/config/test_taskiq_settings.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add pyproject.toml backend/src/core/config/settings.py backend/tests/unit/core/config/test_taskiq_settings.py
|
||||
git commit -m "refactor(queue): replace celery config with taskiq settings"
|
||||
```
|
||||
|
||||
### Task 2: 新建 Taskiq broker 与 worker 启动入口
|
||||
|
||||
**Files:**
|
||||
- Create: `backend/src/core/taskiq/app.py`
|
||||
- Create: `backend/tests/unit/core/taskiq/test_app.py`
|
||||
- Delete: `backend/src/core/celery/app.py`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
from core.taskiq.app import broker
|
||||
|
||||
|
||||
def test_taskiq_broker_is_configured() -> None:
|
||||
assert broker is not None
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/taskiq/test_app.py -v`
|
||||
Expected: FAIL(模块不存在)
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
|
||||
from core.config.settings import config
|
||||
|
||||
broker = ListQueueBroker(url=config.taskiq_broker_url).with_result_backend(
|
||||
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
|
||||
)
|
||||
```
|
||||
|
||||
说明:若当前 `taskiq-redis` 版本 API 名称有差异,以该版本官方 API 为准做等价实现。
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/taskiq/test_app.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/core/taskiq/app.py backend/tests/unit/core/taskiq/test_app.py backend/src/core/celery/app.py
|
||||
git commit -m "feat(queue): add taskiq broker app and remove celery app"
|
||||
```
|
||||
|
||||
### Task 3: 迁移任务定义(Celery task -> Taskiq task)
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/agent/infrastructure/queue/tasks.py`
|
||||
- Test: `backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py` (new)
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
|
||||
|
||||
async def test_run_agent_task_invalid_command_raises() -> None:
|
||||
try:
|
||||
await run_agent_task({"command": "unknown", "session_id": "00000000-0000-0000-0000-000000000001"})
|
||||
raise AssertionError("expected ValueError")
|
||||
except ValueError as exc:
|
||||
assert "invalid command type" in str(exc)
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py -v`
|
||||
Expected: FAIL(测试文件不存在或导入失败)
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
from core.taskiq.app import broker
|
||||
|
||||
@broker.task(task_name="tasks.agent.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
```
|
||||
|
||||
并移除:
|
||||
- `from core.celery.app import celery_app`
|
||||
- `@celery_app.task(...)`
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/core/agent/infrastructure/queue/tasks.py backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py
|
||||
git commit -m "refactor(agent): migrate run command task to taskiq"
|
||||
```
|
||||
|
||||
### Task 4: 迁移 API 入队客户端(.delay -> .kiq)
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/dependencies.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_dependencies_queue.py` (new)
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
class _FakeTask:
|
||||
async def kiq(self, payload: dict[str, object]):
|
||||
class _Result:
|
||||
task_id = "task-123"
|
||||
return _Result()
|
||||
|
||||
|
||||
async def test_enqueue_returns_task_id(monkeypatch):
|
||||
from v1.agent.dependencies import CeleryQueueClient
|
||||
client = CeleryQueueClient() # 迁移后应重命名为 TaskiqQueueClient
|
||||
monkeypatch.setattr("v1.agent.dependencies.run_command_task", _FakeTask())
|
||||
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
|
||||
assert task_id == "task-123"
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_dependencies_queue.py -v`
|
||||
Expected: FAIL(类型/方法不匹配)
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
class TaskiqQueueClient:
|
||||
async def enqueue(self, *, command: dict[str, object], dedup_key: str | None) -> str:
|
||||
payload = dict(command)
|
||||
if dedup_key:
|
||||
payload["dedup_key"] = dedup_key
|
||||
|
||||
result = await run_command_task.kiq(payload)
|
||||
task_id = str(result.task_id)
|
||||
return task_id
|
||||
```
|
||||
|
||||
并替换 DI:
|
||||
|
||||
```python
|
||||
queue=TaskiqQueueClient()
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_dependencies_queue.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/dependencies.py backend/tests/unit/v1/agent/test_dependencies_queue.py
|
||||
git commit -m "refactor(api): switch agent enqueue client from celery to taskiq"
|
||||
```
|
||||
|
||||
### Task 5: 运维脚本与日志测试清理(一次性删除 Celery)
|
||||
|
||||
**Files:**
|
||||
- Modify: `infra/scripts/app.sh`
|
||||
- Delete: `backend/tests/unit/test_celery_logging.py`
|
||||
- Modify/Create: `backend/tests/unit/core/logging/test_taskiq_logging.py` (if taskiq logging hook implemented)
|
||||
- Modify: `backend/src/core/logging/__init__.py`(移除 celery logging export)
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
def test_worker_command_uses_taskiq() -> None:
|
||||
content = Path("infra/scripts/app.sh").read_text()
|
||||
assert "uv run taskiq worker" in content
|
||||
assert "uv run celery" not in content
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/logging/test_taskiq_logging.py -v`
|
||||
Expected: FAIL(脚本仍含 celery)
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
`infra/scripts/app.sh` worker 命令替换为 Taskiq worker,例如:
|
||||
|
||||
```bash
|
||||
uv run taskiq worker core.taskiq.app:broker core.agent.infrastructure.queue.tasks
|
||||
```
|
||||
|
||||
将进程清理逻辑中的 celery 匹配全部替换为 taskiq 匹配(不再保留任何 celery 进程匹配):
|
||||
|
||||
```bash
|
||||
pgrep -f "taskiq.*worker"
|
||||
pkill -f "taskiq.*worker"
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/logging/test_taskiq_logging.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add infra/scripts/app.sh backend/src/core/logging/__init__.py backend/tests/unit/core/logging/test_taskiq_logging.py backend/tests/unit/test_celery_logging.py
|
||||
git commit -m "chore(infra): replace celery worker scripts and remove celery-specific tests"
|
||||
```
|
||||
|
||||
### Task 6: 全量引用清理与回归验证
|
||||
|
||||
**Files:**
|
||||
- Modify: `docs/runtime/runtime-runbook.md`
|
||||
- Modify: 其他引用 Celery 的运行文档(按 `rg` 结果逐个更新)
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
# 用命令断言替代代码测试
|
||||
# rg -n "celery" backend/src infra/scripts docs/runtime pyproject.toml
|
||||
```
|
||||
|
||||
**Step 2: Run check to verify it fails**
|
||||
|
||||
Run: `rg -n "celery" backend/src infra/scripts docs/runtime pyproject.toml`
|
||||
Expected: 仍有旧引用
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
- 删除/替换剩余 Celery 代码、文档、配置。
|
||||
- 保留历史变更记录中的 Celery 字样(如 bugs 归档)可接受,但运行路径必须为 0 引用。
|
||||
|
||||
**Step 4: Run verification suite**
|
||||
|
||||
Run:
|
||||
- `uv run pytest backend/tests/unit -q`
|
||||
- `uv run pytest backend/tests/integration -q`
|
||||
- `uv run pytest backend/tests/e2e -q`(如环境不满足,记录原因)
|
||||
- `uv run ruff check backend/src backend/tests`
|
||||
- `uv run basedpyright`
|
||||
- `rg -n "celery" backend/src infra/scripts pyproject.toml`
|
||||
|
||||
Expected:
|
||||
- 测试与静态检查通过
|
||||
- 运行路径无 Celery 引用
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add docs/runtime/runtime-runbook.md pyproject.toml backend/src infra/scripts backend/tests
|
||||
git commit -m "refactor(queue): complete one-shot migration from celery to taskiq"
|
||||
```
|
||||
|
||||
### Task 7: L1 Review Gates 与交付确认
|
||||
|
||||
**Files:**
|
||||
- No code changes required by default
|
||||
|
||||
**Step 1: Run required L1 gate (`refactor-cleaner`)**
|
||||
|
||||
Run: 使用 `refactor-cleaner` 审查迁移后冗余代码、死引用、命名一致性。
|
||||
Expected: 无阻断问题。
|
||||
|
||||
**Step 2: Optional `code-reviewer` (recommended for infra switch)**
|
||||
|
||||
Run: 使用 `code-reviewer` 聚焦任务丢失、重复消费、幂等锁逻辑。
|
||||
Expected: 无 CRITICAL/HIGH 问题。
|
||||
|
||||
**Step 3: Final evidence report**
|
||||
|
||||
输出内容必须包含:
|
||||
- 执行命令列表
|
||||
- 每条命令 PASS/FAIL
|
||||
- 若有无法执行项(如 e2e 环境),给出原因与人工验证步骤
|
||||
|
||||
**Step 4: Commit review notes (optional)**
|
||||
|
||||
```bash
|
||||
git add docs/plans/2026-03-06-taskiq-migration.md
|
||||
git commit -m "docs(plan): taskiq one-shot migration execution checklist"
|
||||
```
|
||||
@@ -1,144 +0,0 @@
|
||||
# Agent LLM Config Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 将 `system_agents.config` 中的 `temperature` / `max_tokens` 以受约束方式加载到运行时,并在调用 LiteLLM 时按需透传。
|
||||
|
||||
**Architecture:** 在应用层 `RunService` 读取模型选择时同步读取并校验 `SystemAgents.config`;将校验后的 `SystemAgentLLMConfig` 传入 `CrewAIRuntime`;由 runtime 将配置转交给 LiteLLM client,client 仅在值非 `None` 时向 `completion()` 传参,避免不必要的 provider 兼容风险。
|
||||
|
||||
**Tech Stack:** FastAPI, SQLAlchemy (async), Pydantic v2, LiteLLM, pytest
|
||||
|
||||
---
|
||||
|
||||
## 背景与修正点
|
||||
|
||||
- 当前真实调用链为:`RunService._load_agent_model_selection()` -> `create_runtime()` -> `CrewAIRuntime.execute()` -> `run_completion()`,并非 `load_stage_models()`。
|
||||
- `SystemAgentLLMConfig` 已存在:`backend/src/core/agent/domain/system_agent_config.py`。
|
||||
- `system_agents.config` 目前在初始化 YAML 侧有约束,但运行时 DB 读取仍需二次校验,防止脏数据绕过。
|
||||
|
||||
## 规则约束
|
||||
|
||||
- 严格 TDD:先写失败测试,再做实现。
|
||||
- Python 命令统一使用 `uv run ...`。
|
||||
- 仅做增量改动,不回滚或覆盖与本任务无关的已有变更。
|
||||
|
||||
## 字段映射与透传策略
|
||||
|
||||
| 配置字段 | LiteLLM 参数 | 规则 |
|
||||
|---|---|---|
|
||||
| `temperature` | `temperature` | `None` 不透传;非空直接透传 |
|
||||
| `max_tokens` | `max_tokens` | `None` 不透传;非空直接透传 |
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 应用层加载并校验 Agent LLM Config
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/agent/application/run_service.py`
|
||||
- Test: `backend/tests/unit/core/agent/test_run_resume_service.py`
|
||||
|
||||
**Step 1: 写失败测试(RED)**
|
||||
|
||||
新增单测覆盖以下行为:
|
||||
1. `_load_agent_model_selection()` 返回三元组:`(model_code, provider_name, llm_config)`。
|
||||
2. 当 DB `config` 为 `{}` 时,`llm_config.temperature/max_tokens` 为 `None`。
|
||||
3. 当 DB `config` 含非法值(如 `temperature=3`)时抛 `ValueError`。
|
||||
|
||||
**Step 2: 运行测试确认失败**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
|
||||
Expected: 新增断言失败(返回值结构/异常行为不匹配)。
|
||||
|
||||
**Step 3: 最小实现(GREEN)**
|
||||
|
||||
在 `run_service.py`:
|
||||
1. 查询 `SystemAgents.config`。
|
||||
2. 用 `SystemAgentLLMConfig.model_validate(config or {})` 校验。
|
||||
3. 将 `_load_agent_model_selection()` 改为返回三元组。
|
||||
4. 在 `run()` 中把 `llm_config` 传递到 `create_runtime(...)`。
|
||||
|
||||
**Step 4: 运行测试确认通过**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
|
||||
Expected: PASS。
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Runtime 与 LiteLLM Client 支持可选参数透传
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/agent/infrastructure/crewai/factory.py`
|
||||
- Modify: `backend/src/core/agent/infrastructure/crewai/runtime.py`
|
||||
- Modify: `backend/src/core/agent/infrastructure/litellm/client.py`
|
||||
- Test: `backend/tests/unit/core/agent/test_crewai_runtime.py`
|
||||
|
||||
**Step 1: 写失败测试(RED)**
|
||||
|
||||
在 `test_crewai_runtime.py` 增加用例:
|
||||
1. 传入 `temperature/max_tokens` 时,`run_completion` 收到对应参数。
|
||||
2. 参数为 `None` 时,不应被透传到 LiteLLM。
|
||||
|
||||
必要时新增 `backend/tests/unit/core/agent/test_litellm_client.py`,单测 `run_completion` 的 kwargs 组装逻辑。
|
||||
|
||||
**Step 2: 运行测试确认失败**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/test_crewai_runtime.py -q`
|
||||
Expected: 新增断言失败(参数未透传或未过滤 `None`)。
|
||||
|
||||
**Step 3: 最小实现(GREEN)**
|
||||
|
||||
1. `create_runtime()` 增加 `llm_config` 参数并传给 `CrewAIRuntime`。
|
||||
2. `CrewAIRuntime` 保存 `llm_config`,执行时调用:
|
||||
- `run_completion(..., temperature=llm_config.temperature, max_tokens=llm_config.max_tokens)`
|
||||
3. `run_completion()` 改为支持可选 `temperature/max_tokens`,内部仅在非 `None` 时加入 kwargs 再调用 `completion()`。
|
||||
|
||||
**Step 4: 运行测试确认通过**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/test_crewai_runtime.py -q`
|
||||
Expected: PASS。
|
||||
|
||||
---
|
||||
|
||||
### Task 3: 初始化数据补齐与回归验证
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/core/config/static/database/system_agents.yaml`
|
||||
- Modify: `backend/src/core/config/initial/init_data.py`(如需补充类型兜底)
|
||||
- Test: `backend/tests/unit/core/agent/test_run_resume_service.py`
|
||||
|
||||
**Step 1: 写失败测试(RED)**
|
||||
|
||||
补充断言:YAML 读取后 `config` 可为空或包含 `max_tokens: null`,初始化逻辑不会报错,且生成结构符合 `SystemAgentLLMConfig`。
|
||||
|
||||
**Step 2: 运行测试确认失败**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
|
||||
Expected: 新增断言失败。
|
||||
|
||||
**Step 3: 最小实现(GREEN)**
|
||||
|
||||
1. 在 `system_agents.yaml` 为各 agent 配置显式补充 `max_tokens: null`。
|
||||
2. `init_data.py` 保持 `config: SystemAgentLLMConfig | None = None`,写库时统一序列化为 dict。
|
||||
|
||||
**Step 4: 运行测试确认通过**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
|
||||
Expected: PASS。
|
||||
|
||||
---
|
||||
|
||||
## 最终验证
|
||||
|
||||
1. `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py backend/tests/unit/core/agent/test_crewai_runtime.py -q`
|
||||
2. `uv run pytest backend/tests/integration/core/agent/test_queue_run_resume.py -q`
|
||||
3. `uv run ruff check backend/src backend/tests`
|
||||
4. `uv run basedpyright`
|
||||
|
||||
预期:全部通过;若集成测试依赖本地 DB 状态导致跳过/失败,需记录原因并给出手工验证步骤。
|
||||
|
||||
## 完成标准
|
||||
|
||||
- `RunService` 从 DB 读取并校验 `config`。
|
||||
- runtime 到 LiteLLM 链路支持 `temperature/max_tokens` 可选透传。
|
||||
- `None` 不透传。
|
||||
- 单测与相关集成测试通过,并给出命令级证据。
|
||||
@@ -69,7 +69,7 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml exec -T db \
|
||||
### 启动应用进程
|
||||
|
||||
```bash
|
||||
bash infra/scripts/app-up.sh
|
||||
bash infra/scripts/app.sh start
|
||||
```
|
||||
|
||||
该脚本会在 tmux `social-dev` 会话中拉起:
|
||||
@@ -172,6 +172,7 @@ curl -sS "${WEB_BASE_URL}/api/v1/profile/me" \
|
||||
- 症状:队列堆积,任务长时间 pending。
|
||||
- 定位:检查 `worker-*` tmux 窗口和对应日志文件。
|
||||
- 修复:重启 tmux 会话,确认并发配置与队列名(critical/default/bulk)。
|
||||
- 说明:Taskiq 路径当前仅消费 `SOCIAL_WORKER__GROUPS__*__CONCURRENCY`,旧 Celery 参数(prefetch/time_limit 等)已废弃。
|
||||
|
||||
### 2.1) Agent Runtime run/resume 事件不闭环
|
||||
|
||||
@@ -179,7 +180,7 @@ curl -sS "${WEB_BASE_URL}/api/v1/profile/me" \
|
||||
- 定位步骤:
|
||||
|
||||
```bash
|
||||
# 1) 检查 celery worker 是否消费 agent 任务
|
||||
# 1) 检查 taskiq worker 是否消费 agent 任务
|
||||
grep -E "tasks\.agent\.run_command|RUN_STARTED|RUN_FINISHED|RUN_ERROR" logs/worker-default.log
|
||||
|
||||
# 2) 检查 API SSE 事件读取(带 Last-Event-ID)
|
||||
@@ -192,7 +193,7 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml exec -T redis
|
||||
```
|
||||
|
||||
- 修复建议:
|
||||
- 若 worker 无消费:重启 `worker-default` 窗口并确认 `core.agent.infrastructure.queue.tasks` 已被 Celery include。
|
||||
- 若 worker 无消费:重启 `worker-default` 窗口并确认 `core.agent.infrastructure.queue.tasks` 已被 Taskiq worker 加载。
|
||||
- 若 worker 有事件但 API 无输出:排查 Redis stream 前缀配置与 session_id 是否一致。
|
||||
- 若出现 `RUN_ERROR`:按 error_id 回查后端日志,不在 API/SSE 中暴露敏感上下文。
|
||||
|
||||
@@ -270,4 +271,4 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml up -d --force-
|
||||
| 2026-02-28 | 邀请码功能:新增 invite_codes 表、profiles.referred_by,注册时可选填邀请码并记录邀请关系 |
|
||||
| 2026-03-02 | 文档整理:修正 auth 端点名称(/verifications)、补充 profile 路由文档、修复 L2/L3 验证命令 |
|
||||
| 2026-03-02 | 修正 bootstrap 命令:init-job 需要使用 `uv run python -m core.runtime.cli bootstrap` |
|
||||
| 2026-03-05 | 新增 Agent Runtime run/resume/events 运维排障流程(Celery + Redis + Last-Event-ID) |
|
||||
| 2026-03-05 | 新增 Agent Runtime run/resume/events 运维排障流程(Taskiq + Redis + Last-Event-ID) |
|
||||
|
||||
@@ -56,9 +56,9 @@ ${SOCIAL_WEB__GUNICORN__WORKER_CLASS:-uvicorn.workers.UvicornWorker} --timeout \
|
||||
${SOCIAL_WEB__GUNICORN__TIMEOUT:-60} \
|
||||
--log-level ${SOCIAL_RUNTIME__LOG_LEVEL:-info}"
|
||||
|
||||
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run celery -A core.celery.app worker --loglevel=info --queues=critical --concurrency=${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
|
||||
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run celery -A core.celery.app worker --loglevel=info --queues=default --concurrency=${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
|
||||
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run celery -A core.celery.app worker --loglevel=info --queues=bulk --concurrency=${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
|
||||
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run taskiq worker core.taskiq.app:critical_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
|
||||
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run taskiq worker core.taskiq.app:default_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
|
||||
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run taskiq worker core.taskiq.app:bulk_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
|
||||
|
||||
tmux new-session -d -s "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\""
|
||||
tmux new-window -t "$SESSION_NAME" -n worker-critical "bash -lc \"$WORKER_CRITICAL_CMD; echo '[worker-critical] exited'; exec bash\""
|
||||
@@ -92,9 +92,9 @@ stop() {
|
||||
echo "Killing orphaned gunicorn processes..."
|
||||
pkill -f "gunicorn.*app:app"
|
||||
fi
|
||||
if pgrep -f "celery.*worker" > /dev/null 2>&1; then
|
||||
echo "Killing orphaned celery processes..."
|
||||
pkill -f "celery.*worker"
|
||||
if pgrep -f "taskiq.*worker" > /dev/null 2>&1; then
|
||||
echo "Killing orphaned taskiq processes..."
|
||||
pkill -f "taskiq.*worker"
|
||||
fi
|
||||
|
||||
echo "Session stopped and cleaned up."
|
||||
|
||||
+3
-1
@@ -8,7 +8,6 @@ dependencies = [
|
||||
"alembic>=1.18.3",
|
||||
"asyncpg>=0.31.0",
|
||||
"basedpyright>=1.37.2",
|
||||
"celery>=5.6.2",
|
||||
"crewai>=1.6.1",
|
||||
"crewai-tools>=1.6.1",
|
||||
"email-validator>=2.3.0",
|
||||
@@ -22,8 +21,11 @@ dependencies = [
|
||||
"redis>=7.1.0",
|
||||
"sqlalchemy[asyncio]>=2.0.46",
|
||||
"structlog>=24.4.0",
|
||||
"taskiq>=0.11.0",
|
||||
"taskiq-redis>=1.0.0",
|
||||
"supabase>=2.27.2",
|
||||
"uvicorn[standard]>=0.40.0",
|
||||
"gunicorn>=25.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
Reference in New Issue
Block a user