diff --git a/.env.example b/.env.example index 88d3456..86fe96c 100644 --- a/.env.example +++ b/.env.example @@ -35,17 +35,16 @@ SOCIAL_REDIS__DB=0 # default: 常规异步任务 # bulk: 批处理/重计算/可延迟任务 SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY=2 -SOCIAL_WORKER__GROUPS__CRITICAL__PREFETCH_MULTIPLIER=1 -SOCIAL_WORKER__GROUPS__CRITICAL__TIME_LIMIT=300 SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY=2 -SOCIAL_WORKER__GROUPS__DEFAULT__PREFETCH_MULTIPLIER=4 -SOCIAL_WORKER__GROUPS__DEFAULT__TIME_LIMIT=600 SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY=1 -SOCIAL_WORKER__GROUPS__BULK__PREFETCH_MULTIPLIER=1 -SOCIAL_WORKER__GROUPS__BULK__TIME_LIMIT=3600 -SOCIAL_WORKER__GROUPS__BULK__MAX_TASKS_PER_CHILD=100 + +############ +# Taskiq(可选,默认回落到 Redis URL) +############ +# SOCIAL_TASKIQ__BROKER_URL=redis://:password@localhost:6379/0 +# SOCIAL_TASKIQ__RESULT_BACKEND_URL=redis://:password@localhost:6379/0 ############ # Supabase(本地 Docker 与阿里云自托管保持同一变量) diff --git a/backend/src/app.py b/backend/src/app.py index ea5a66e..898db02 100644 --- a/backend/src/app.py +++ b/backend/src/app.py @@ -13,7 +13,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from core.config.settings import config from core.http.response import build_problem_details from core.logging import configure_logging, get_logger, log_service_banner -from services.base.redis import redis_service +from services.base import close_registered_services, initialize_registered_services from v1.router import router as mobile_router @@ -29,22 +29,23 @@ log_service_banner( ) logger = get_logger("api.app") +SERVICE_STARTUP_ORDER = ["redis", "supabase"] @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - initialized = await redis_service.initialize() + initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER) if not initialized: - logger.error("Redis service failed to initialize, aborting startup") - raise RuntimeError("Redis service initialization failed") - logger.info( - "Redis service initialized", - host=config.redis.host, - db=config.redis.db, - ) - yield - await redis_service.close() - logger.info("Redis service closed") + logger.error("Service initialization failed, aborting startup") + raise RuntimeError("Service initialization failed") + logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER) + try: + yield + finally: + closed = await close_registered_services(services) + if not closed: + logger.warning("Failed to close all base services") + logger.info("Base services closed", services=SERVICE_STARTUP_ORDER) app = FastAPI(lifespan=lifespan) diff --git a/backend/src/core/agent/infrastructure/queue/tasks.py b/backend/src/core/agent/infrastructure/queue/tasks.py index 962d8d8..e38e3dd 100644 --- a/backend/src/core/agent/infrastructure/queue/tasks.py +++ b/backend/src/core/agent/infrastructure/queue/tasks.py @@ -1,15 +1,15 @@ from __future__ import annotations -from typing import Any, Protocol, cast +from typing import Any, Protocol from uuid import UUID from core.agent.application.resume_service import ResumeService from core.agent.application.run_service import RunService from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore -from core.celery.app import celery_app from core.config.settings import config from core.logging import get_logger -from services.base.redis import redis_service +from core.taskiq.app import bulk_broker, critical_broker, default_broker +from services.base.redis import get_or_init_redis_client logger = get_logger("core.agent.infrastructure.queue.tasks") @@ -29,13 +29,12 @@ class ResumeServiceLike(Protocol): async def _build_redis_publisher() -> PublishEvent: - settings = cast(Any, config) - client = redis_service.get_client() + client = await get_or_init_redis_client() event_store = RedisStreamEventStore( client=client, - stream_prefix=settings.agent_runtime.redis_stream_prefix, - read_count=settings.agent_runtime.redis_stream_read_count, - block_ms=settings.agent_runtime.redis_stream_block_ms, + stream_prefix=config.agent_runtime.redis_stream_prefix, + read_count=config.agent_runtime.redis_stream_read_count, + block_ms=config.agent_runtime.redis_stream_block_ms, ) async def _publish(event_type: str, payload: dict[str, object]) -> None: @@ -70,22 +69,27 @@ async def run_agent_task( raise ValueError("session_id is required") UUID(session_id) + tool_call_id = "" + user_input = "" + if command_type == "resume": + tool_call_id = str(command.get("tool_call_id", "")) + if not tool_call_id: + raise ValueError("tool_call_id is required") + else: + user_input = str(command.get("user_input", "")) + if not user_input: + raise ValueError("user_input is required") + start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED" await publisher(start_event, {"session_id": session_id}) try: if command_type == "resume": - tool_call_id = str(command.get("tool_call_id", "")) - if not tool_call_id: - raise ValueError("tool_call_id is required") result = await service_resume.resume( session_id=session_id, tool_call_id=tool_call_id, ) else: - user_input = str(command.get("user_input", "")) - if not user_input: - raise ValueError("user_input is required") result = await service_run.run( session_id=session_id, user_input=user_input, @@ -125,6 +129,16 @@ async def run_agent_task( raise -@celery_app.task(name="tasks.agent.run_command") +@default_broker.task(task_name="tasks.agent.run_command") async def run_command_task(command: dict[str, Any]) -> dict[str, object]: return await run_agent_task(command) + + +@critical_broker.task(task_name="tasks.agent.run_command.critical") +async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]: + return await run_agent_task(command) + + +@bulk_broker.task(task_name="tasks.agent.run_command.bulk") +async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]: + return await run_agent_task(command) diff --git a/backend/src/core/celery/app.py b/backend/src/core/celery/app.py deleted file mode 100644 index b2674cc..0000000 --- a/backend/src/core/celery/app.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import annotations - -from celery import Celery -from celery import signals as celery_signals -from kombu import Queue - -from core.config.settings import config -from core.logging import get_logger -from core.logging.celery import configure_celery_app -from services.base.redis import redis_service - -logger = get_logger("core.celery") - - -def _init_redis_on_worker_startup(**_: object) -> None: - import asyncio - - logger.info("Initializing Redis service for Celery worker") - try: - result = asyncio.run(redis_service.initialize()) - if result: - logger.info("Redis service initialized for Celery worker") - else: - logger.warning("Redis service initialization returned False") - except Exception as exc: # noqa: BLE001 - logger.error("Failed to initialize Redis for Celery worker", error=str(exc)) - - -def create_celery_app() -> Celery: - """Create and configure the Celery application.""" - celery_settings = config.celery - - app = Celery( - "social_app", - broker=config.celery_broker_url, - backend=config.celery_result_backend, - include=["core.agent.infrastructure.queue.tasks"], - ) - - app.conf.update( - task_serializer=celery_settings.task_serializer, - result_serializer=celery_settings.result_serializer, - accept_content=celery_settings.accept_content, - timezone=celery_settings.timezone, - enable_utc=celery_settings.enable_utc, - task_track_started=celery_settings.task_track_started, - task_time_limit=celery_settings.task_time_limit, - task_soft_time_limit=celery_settings.task_soft_time_limit, - task_default_retry_delay=celery_settings.task_default_retry_delay, - task_default_queue="default", - task_create_missing_queues=False, - task_queues=( - Queue("default"), - Queue("critical"), - Queue("bulk"), - ), - task_routes={ - "tasks.critical.*": {"queue": "critical"}, - "tasks.bulk.*": {"queue": "bulk"}, - }, - task_acks_late=True, - task_reject_on_worker_lost=True, - worker_prefetch_multiplier=1, - ) - - configure_celery_app(app, settings=config) - - return app - - -celery_app = create_celery_app() - -celery_signals.worker_process_init.connect(_init_redis_on_worker_startup) diff --git a/backend/src/core/config/settings.py b/backend/src/core/config/settings.py index a04025f..1138513 100644 --- a/backend/src/core/config/settings.py +++ b/backend/src/core/config/settings.py @@ -63,19 +63,9 @@ class RuntimeSettings(BaseModel): return self -class CelerySettings(BaseModel): +class TaskiqSettings(BaseModel): broker_url: str | None = None - result_backend: str | None = None - task_serializer: str = "json" - result_serializer: str = "json" - accept_content: list[str] = Field(default_factory=lambda: ["json"]) - timezone: str = "UTC" - enable_utc: bool = True - task_track_started: bool = True - task_time_limit: int = 300 - task_soft_time_limit: int = 240 - task_default_retry_delay: int = 30 - task_max_retries: int = 3 + result_backend_url: str | None = None class CorsSettings(BaseModel): @@ -189,7 +179,7 @@ class Settings(BaseSettings): storage: StorageSettings = StorageSettings() llm: LlmSettings = LlmSettings() agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings() - celery: CelerySettings = CelerySettings() + taskiq: TaskiqSettings = TaskiqSettings() database: DatabaseSettings = DatabaseSettings() @computed_field @@ -199,13 +189,13 @@ class Settings(BaseSettings): @computed_field @property - def celery_broker_url(self) -> str: - return self.celery.broker_url or self.redis.url + def taskiq_broker_url(self) -> str: + return self.taskiq.broker_url or self.redis.url @computed_field @property - def celery_result_backend(self) -> str: - return self.celery.result_backend or self.redis.url + def taskiq_result_backend_url(self) -> str: + return self.taskiq.result_backend_url or self.redis.url model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict( env_file=_resolve_env_file(), diff --git a/backend/src/core/logging/__init__.py b/backend/src/core/logging/__init__.py index 5de63fd..5546157 100644 --- a/backend/src/core/logging/__init__.py +++ b/backend/src/core/logging/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -from core.logging import celery from core.logging.banner import log_service_banner from core.logging.config import configure_logging from core.logging.context import bind_context, clear_context, get_context @@ -8,7 +7,6 @@ from core.logging.logger import get_logger __all__ = [ "bind_context", - "celery", "clear_context", "configure_logging", "get_context", diff --git a/backend/src/core/logging/celery.py b/backend/src/core/logging/celery.py deleted file mode 100644 index 4ceba4a..0000000 --- a/backend/src/core/logging/celery.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass -from typing import cast - -from celery import Celery, signals - -from core.config.settings import Settings -from core.logging.banner import log_service_banner -from core.logging.config import configure_logging -from core.logging.context import bind_context, clear_context - - -@dataclass(frozen=True) -class CelerySignalHandlers: - on_setup_logging: Callable[..., None] - on_after_setup_task_logger: Callable[..., None] - on_task_prerun: Callable[..., None] - on_task_postrun: Callable[..., None] - - -def build_celery_signal_handlers( - settings: Settings | None = None, -) -> CelerySignalHandlers: - active_settings = settings or Settings() - - def on_setup_logging(*_args: object, **_kwargs: object) -> None: - configure_logging(settings) - log_service_banner( - service_name=active_settings.runtime.service_name, - environment=active_settings.runtime.environment, - ) - - def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None: - configure_logging(settings) - - def on_task_prerun(*_args: object, **kwargs: object) -> None: - task_id = cast(str | None, kwargs.get("task_id")) - task = kwargs.get("task") - task_name = getattr(task, "name", None) - bind_context(task_id=task_id, task_name=task_name) - - def on_task_postrun(*_args: object, **_kwargs: object) -> None: - clear_context() - - return CelerySignalHandlers( - on_setup_logging=on_setup_logging, - on_after_setup_task_logger=on_after_setup_task_logger, - on_task_prerun=on_task_prerun, - on_task_postrun=on_task_postrun, - ) - - -def configure_celery_app(app: Celery, settings: Settings | None = None) -> None: - app.conf.worker_hijack_root_logger = False - - handlers = build_celery_signal_handlers(settings) - signals.setup_logging.connect(handlers.on_setup_logging, weak=False) - signals.after_setup_task_logger.connect( - handlers.on_after_setup_task_logger, weak=False - ) - signals.task_prerun.connect(handlers.on_task_prerun, weak=False) - signals.task_postrun.connect(handlers.on_task_postrun, weak=False) diff --git a/backend/src/core/taskiq/__init__.py b/backend/src/core/taskiq/__init__.py new file mode 100644 index 0000000..7a93800 --- /dev/null +++ b/backend/src/core/taskiq/__init__.py @@ -0,0 +1,3 @@ +from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker + +__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"] diff --git a/backend/src/core/taskiq/app.py b/backend/src/core/taskiq/app.py new file mode 100644 index 0000000..0570f34 --- /dev/null +++ b/backend/src/core/taskiq/app.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend + +from core.config.settings import config +from core.logging import configure_logging + + +configure_logging(config) + +def _build_broker(queue_name: str) -> ListQueueBroker: + return ListQueueBroker( + url=config.taskiq_broker_url, + queue_name=queue_name, + ).with_result_backend( + RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url) + ) + + +default_broker = _build_broker("default") +critical_broker = _build_broker("critical") +bulk_broker = _build_broker("bulk") + +# Backward-compatible export name for existing imports/tests. +broker = default_broker + +__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"] diff --git a/backend/src/services/base/__init__.py b/backend/src/services/base/__init__.py index 7369fd5..d115d21 100644 --- a/backend/src/services/base/__init__.py +++ b/backend/src/services/base/__init__.py @@ -1,18 +1,28 @@ from __future__ import annotations -from services.base.redis import RedisService, redis_service +from services.base.redis import RedisService, get_or_init_redis_client, redis_service from services.base.service_interface import ( BaseServiceProvider, ServiceRegistry, + close_registered_services, + initialize_registered_services, register_service, register_service_instance, + resolve_registered_services, ) +from services.base.supabase import SupabaseService, supabase_service __all__ = [ "BaseServiceProvider", "RedisService", "ServiceRegistry", + "SupabaseService", + "close_registered_services", + "get_or_init_redis_client", + "initialize_registered_services", "redis_service", "register_service", "register_service_instance", + "resolve_registered_services", + "supabase_service", ] diff --git a/backend/src/services/base/redis.py b/backend/src/services/base/redis.py index 8d0cf79..d9b32fa 100644 --- a/backend/src/services/base/redis.py +++ b/backend/src/services/base/redis.py @@ -92,6 +92,14 @@ class RedisService(BaseServiceProvider): return self._require_client() +async def get_or_init_redis_client() -> redis.Redis: + if not redis_service.is_initialized: + initialized = await redis_service.initialize() + if not initialized: + raise RuntimeError("Redis service initialization failed") + return redis_service.get_client() + + redis_service: RedisService = register_service_instance("redis", RedisService()) -__all__ = ["RedisService", "redis_service"] +__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"] diff --git a/backend/src/services/base/service_interface.py b/backend/src/services/base/service_interface.py index 0e96aa5..b516e8e 100644 --- a/backend/src/services/base/service_interface.py +++ b/backend/src/services/base/service_interface.py @@ -61,6 +61,12 @@ class ServiceRegistry: @classmethod def create_service( cls, service_name: str, **kwargs: Any + ) -> Optional[BaseServiceProvider]: + return cls.get_service(service_name, **kwargs) + + @classmethod + def get_service( + cls, service_name: str, **kwargs: Any ) -> Optional[BaseServiceProvider]: factory = cls.get_service_factory(service_name) if not factory: @@ -82,3 +88,71 @@ TService = TypeVar("TService", bound=BaseServiceProvider) def register_service_instance(service_name: str, service: TService) -> TService: ServiceRegistry.register(service_name, lambda: service) return service + + +def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]: + services: list[BaseServiceProvider] = [] + for service_name in service_names: + service = ServiceRegistry.get_service(service_name) + if service is None: + raise RuntimeError(f"Service is not registered: {service_name}") + services.append(service) + return services + + +async def close_registered_services(services: list[BaseServiceProvider]) -> bool: + lifecycle_logger = get_logger("services.base.lifecycle") + all_closed = True + for service in reversed(services): + try: + closed = await service.close() + except Exception as exc: # noqa: BLE001 + lifecycle_logger.warning( + "Failed to close service", + service=service.service_name, + error=str(exc), + ) + all_closed = False + continue + if not closed: + lifecycle_logger.warning( + "Service close returned false", + service=service.service_name, + ) + all_closed = False + return all_closed + + +async def initialize_registered_services( + service_names: list[str], +) -> tuple[bool, list[BaseServiceProvider]]: + lifecycle_logger = get_logger("services.base.lifecycle") + initialized_services: list[BaseServiceProvider] = [] + try: + services = resolve_registered_services(service_names) + except RuntimeError as exc: + lifecycle_logger.error("Failed to resolve registered services", error=str(exc)) + return False, [] + + for service in services: + try: + initialized = await service.initialize() + except Exception as exc: # noqa: BLE001 + lifecycle_logger.warning( + "Service initialization raised exception", + service=service.service_name, + error=str(exc), + ) + initialized = False + + if not initialized: + lifecycle_logger.error( + "Service initialization failed, rolling back", + service=service.service_name, + ) + await close_registered_services(initialized_services) + return False, [] + + initialized_services.append(service) + + return True, initialized_services diff --git a/backend/src/services/base/supabase.py b/backend/src/services/base/supabase.py new file mode 100644 index 0000000..640e8c7 --- /dev/null +++ b/backend/src/services/base/supabase.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +from supabase import create_client + +from core.config.settings import SupabaseSettings, config + +from .service_interface import BaseServiceProvider, register_service_instance + + +class SupabaseService(BaseServiceProvider): + def __init__(self, settings: SupabaseSettings | None = None) -> None: + super().__init__("supabase") + self._settings = settings or config.supabase + self._client: Any = None + self._admin_client: Any = None + + async def initialize(self, **_: Any) -> bool: + try: + self._client = create_client( + self._settings.url, + self._settings.anon_key, + ) + self._admin_client = create_client( + self._settings.url, + self._settings.service_role_key, + ) + self._set_initialized(True) + self.logger.info("Supabase service initialized") + return True + except Exception as exc: # noqa: BLE001 + self.logger.warning("Supabase service initialization failed", error=str(exc)) + self._client = None + self._admin_client = None + self._set_initialized(False) + return False + + async def close(self) -> bool: + self._client = None + self._admin_client = None + self._set_initialized(False) + self.logger.info("Supabase service closed") + return True + + async def health_check(self) -> dict[str, Any]: + client = self._client + admin_client = self._admin_client + if client is None or admin_client is None: + return {"status": "unhealthy", "details": {"error": "not initialized"}} + try: + await asyncio.to_thread(client.auth.get_session) + await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1) + return { + "status": "healthy", + "details": { + "anon_client": "ready", + "admin_client": "ready", + }, + } + except Exception as exc: # noqa: BLE001 + self.logger.warning("Supabase health check failed", error=str(exc)) + return {"status": "unhealthy", "details": {"error": str(exc)}} + + def get_client(self) -> Any: + return self._require_client() + + def get_admin_client(self) -> Any: + return self._require_admin_client() + + def _require_client(self) -> Any: + client = self._client + if client is None: + raise RuntimeError("Supabase client is not initialized") + return client + + def _require_admin_client(self) -> Any: + admin_client = self._admin_client + if admin_client is None: + raise RuntimeError("Supabase admin client is not initialized") + return admin_client + + +supabase_service: SupabaseService = register_service_instance( + "supabase", SupabaseService() +) + +__all__ = ["SupabaseService", "supabase_service"] diff --git a/backend/src/v1/agent/dependencies.py b/backend/src/v1/agent/dependencies.py index eb45b8a..acdd732 100644 --- a/backend/src/v1/agent/dependencies.py +++ b/backend/src/v1/agent/dependencies.py @@ -1,57 +1,98 @@ from __future__ import annotations -from typing import Any, cast +import asyncio +from typing import Any from uuid import UUID from fastapi import Depends +from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncSession from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore -from core.agent.infrastructure.queue.tasks import run_command_task +from core.agent.infrastructure.queue.tasks import ( + run_command_task, + run_command_task_bulk, + run_command_task_critical, +) from core.config.settings import config from core.db import get_db -from services.base.redis import redis_service +from services.base.redis import get_or_init_redis_client from v1.agent.repository import AgentRepository from v1.agent.service import AgentService +DEDUP_WAIT_RETRIES = 20 +DEDUP_WAIT_SECONDS = 0.05 +DEDUP_LOCK_SECONDS = 300 +DEDUP_INFLIGHT_MARKER = "__inflight__" -class CeleryQueueClient: + +class TaskiqQueueClient: def __init__(self) -> None: - self._redis = redis_service.get_client() + self._redis: Redis | None = None + + async def _get_redis(self) -> Redis: + if self._redis is None: + self._redis = await get_or_init_redis_client() + return self._redis + + @staticmethod + def _select_queue_task(command: dict[str, object]) -> Any: + queue = str(command.get("queue", "default")).strip().lower() + if queue == "critical": + return run_command_task_critical + if queue == "bulk": + return run_command_task_bulk + return run_command_task async def enqueue( self, *, command: dict[str, object], dedup_key: str | None ) -> str: + redis_client = await self._get_redis() redis_key = None if dedup_key: redis_key = f"agent:dedup:{dedup_key}" - locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300) + locked = await redis_client.set( + redis_key, + DEDUP_INFLIGHT_MARKER, + nx=True, + ex=DEDUP_LOCK_SECONDS, + ) if not locked: - existing = await self._redis.get(redis_key) - if existing and existing != "__inflight__": - return existing + for _ in range(DEDUP_WAIT_RETRIES): + existing = await redis_client.get(redis_key) + if existing and existing != DEDUP_INFLIGHT_MARKER: + return existing + await asyncio.sleep(DEDUP_WAIT_SECONDS) + raise RuntimeError("duplicate request is still in progress") payload = dict(command) - if dedup_key: - payload["dedup_key"] = dedup_key - delay = getattr(run_command_task, "delay") - result = delay(payload) - task_id = str(result.id) - if redis_key is not None: - await self._redis.set(redis_key, task_id, ex=300) - return task_id + queue_task = self._select_queue_task(payload) + try: + result = await queue_task.kiq(payload) + task_id = str(result.task_id) + if redis_key is not None: + await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS) + return task_id + except Exception: + if redis_key is not None: + await redis_client.delete(redis_key) + raise class RedisEventStream: def __init__(self) -> None: - settings = cast(Any, config) - client = redis_service.get_client() - self._store = RedisStreamEventStore( - client=client, - stream_prefix=settings.agent_runtime.redis_stream_prefix, - read_count=settings.agent_runtime.redis_stream_read_count, - block_ms=settings.agent_runtime.redis_stream_block_ms, - ) + self._store: RedisStreamEventStore | None = None + + async def _get_store(self) -> RedisStreamEventStore: + if self._store is None: + client = await get_or_init_redis_client() + self._store = RedisStreamEventStore( + client=client, + stream_prefix=config.agent_runtime.redis_stream_prefix, + read_count=config.agent_runtime.redis_stream_read_count, + block_ms=config.agent_runtime.redis_stream_block_ms, + ) + return self._store async def read( self, @@ -59,7 +100,8 @@ class RedisEventStream: session_id: str, last_event_id: str | None, ) -> list[dict[str, Any]]: - rows = await self._store.read_events( + store = await self._get_store() + rows = await store.read_events( session_id=UUID(session_id), last_event_id=last_event_id, ) @@ -69,6 +111,6 @@ class RedisEventStream: def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService: return AgentService( repository=AgentRepository(session), - queue=CeleryQueueClient(), + queue=TaskiqQueueClient(), stream=RedisEventStream(), ) diff --git a/backend/src/v1/auth/gateway.py b/backend/src/v1/auth/gateway.py index efac7ba..521bc7f 100644 --- a/backend/src/v1/auth/gateway.py +++ b/backend/src/v1/auth/gateway.py @@ -5,10 +5,10 @@ from collections.abc import Mapping from typing import Any, cast from fastapi import HTTPException -from supabase import AuthError, create_client +from supabase import AuthError -from core.config.settings import SupabaseSettings, config from core.logging import get_logger +from services.base.supabase import supabase_service from v1.auth.schemas import ( AuthUser, PasswordResetConfirmRequest, @@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway") class SupabaseAuthGateway(AuthServiceGateway): - _client: Any - _admin_client: Any + def _get_client(self) -> Any: + return supabase_service.get_client() - def __init__(self) -> None: - settings: SupabaseSettings = config.supabase - self._client = create_client(settings.url, settings.anon_key) - self._admin_client = create_client(settings.url, settings.service_role_key) + def _get_admin_client(self) -> Any: + return supabase_service.get_admin_client() async def create_verification( self, request: VerificationCreateRequest ) -> VerificationCreateResponse: + client = self._get_client() metadata: dict[str, Any] = {"username": request.username} if request.invite_code: metadata["invite_code"] = request.invite_code @@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway): if request.redirect_to: payload["options"] = {"email_redirect_to": request.redirect_to} try: - sign_up = cast(Any, self._client.auth.sign_up) + sign_up = cast(Any, client.auth.sign_up) await asyncio.to_thread(sign_up, payload) return VerificationCreateResponse(email=request.email) except AuthError as exc: @@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway): async def verify_verification( self, request: VerificationVerifyRequest ) -> SessionResponse: + client = self._get_client() payload: dict[str, Any] = { "type": "signup", "email": request.email, "token": request.token, } try: - verify_otp = cast(Any, self._client.auth.verify_otp) + verify_otp = cast(Any, client.auth.verify_otp) response = await asyncio.to_thread(verify_otp, payload) return _map_auth_response(response, "Invalid verification code") except AuthError as exc: @@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway): ) from exc async def resend_verification(self, request: VerificationResendRequest) -> None: + client = self._get_client() payload: dict[str, Any] = {"type": "signup", "email": request.email} try: - resend = cast(Any, self._client.auth.resend) + resend = cast(Any, client.auth.resend) await asyncio.to_thread(resend, payload) except AuthError as exc: logger.warning("Signup resend failed", error_type=type(exc).__name__) async def create_session(self, request: SessionCreateRequest) -> SessionResponse: + client = self._get_client() payload: dict[str, Any] = {"email": request.email, "password": request.password} try: - sign_in = cast(Any, self._client.auth.sign_in_with_password) + sign_in = cast(Any, client.auth.sign_in_with_password) response = await asyncio.to_thread(sign_in, payload) return _map_auth_response(response, "Invalid credentials") except AuthError as exc: @@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway): raise HTTPException(status_code=401, detail="Invalid credentials") from exc async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse: + client = self._get_client() try: response = await asyncio.to_thread( - self._client.auth.refresh_session, + client.auth.refresh_session, request.refresh_token, ) return _map_auth_response(response, "Invalid refresh token") @@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway): async def delete_session(self, refresh_token: str | None) -> None: if not refresh_token: raise HTTPException(status_code=401, detail="Missing refresh token") + client = self._get_client() try: response = await asyncio.to_thread( - self._client.auth.refresh_session, + client.auth.refresh_session, refresh_token, ) session = getattr(response, "session", None) if session is None: raise HTTPException(status_code=401, detail="Invalid refresh token") await asyncio.to_thread( - self._client.auth.set_session, + client.auth.set_session, str(session.access_token), str(session.refresh_token), ) - await asyncio.to_thread(self._client.auth.sign_out) + await asyncio.to_thread(client.auth.sign_out) except AuthError as exc: logger.warning("Logout failed", error_type=type(exc).__name__) raise HTTPException( @@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway): ) from exc async def get_user_by_email(self, email: str) -> UserByEmailResponse: - users = await asyncio.to_thread(_list_auth_users, self._admin_client) + admin_client = self._get_admin_client() + users = await asyncio.to_thread(_list_auth_users, admin_client) normalized_email = email.lower() user = next( ( @@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway): ) async def request_password_reset(self, request: PasswordResetRequest) -> None: + client = self._get_client() try: - reset_email = cast(Any, self._client.auth.reset_password_email) + reset_email = cast(Any, client.auth.reset_password_email) email = _coerce_reset_email(request.email) if request.redirect_to: options: dict[str, str] = {"redirect_to": request.redirect_to} @@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway): async def confirm_password_reset( self, request: PasswordResetConfirmRequest ) -> None: + client = self._get_client() + admin_client = self._get_admin_client() verify_payload: dict[str, Any] = { "type": "recovery", "email": request.email, "token": request.token, } try: - verify_otp = cast(Any, self._client.auth.verify_otp) + verify_otp = cast(Any, client.auth.verify_otp) response = await asyncio.to_thread(verify_otp, verify_payload) session = getattr(response, "session", None) user = getattr(response, "user", None) @@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway): status_code=401, detail="Invalid or expired verification code" ) await asyncio.to_thread( - self._admin_client.auth.admin.update_user_by_id, + admin_client.auth.admin.update_user_by_id, user_id, {"password": request.new_password}, ) diff --git a/backend/tests/unit/core/agent/test_queue_tasks.py b/backend/tests/unit/core/agent/test_queue_tasks.py index c55813d..9c89f37 100644 --- a/backend/tests/unit/core/agent/test_queue_tasks.py +++ b/backend/tests/unit/core/agent/test_queue_tasks.py @@ -2,7 +2,7 @@ from __future__ import annotations import pytest -from core.agent.infrastructure.queue.tasks import run_agent_task +from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task class _FakeRunService: @@ -67,3 +67,35 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None: ) assert events == ["RUN_STARTED", "RUN_ERROR"] + + +@pytest.mark.asyncio +async def test_run_agent_task_rejects_invalid_command() -> None: + with pytest.raises(ValueError, match="invalid command type"): + await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"}) + + +@pytest.mark.asyncio +async def test_run_agent_task_resume_requires_tool_call_id() -> None: + with pytest.raises(ValueError, match="tool_call_id is required"): + await run_agent_task( + { + "command": "resume", + "session_id": "00000000-0000-0000-0000-000000000001", + } + ) + + +@pytest.mark.asyncio +async def test_build_redis_publisher_init_fail_raises_runtime_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from core.agent.infrastructure.queue import tasks + + async def _fake_get_client() -> object: + raise RuntimeError("Redis service initialization failed") + + monkeypatch.setattr(tasks, "get_or_init_redis_client", _fake_get_client) + + with pytest.raises(RuntimeError, match="Redis service initialization failed"): + await _build_redis_publisher() diff --git a/backend/tests/unit/core/config/test_taskiq_settings.py b/backend/tests/unit/core/config/test_taskiq_settings.py new file mode 100644 index 0000000..2ae1155 --- /dev/null +++ b/backend/tests/unit/core/config/test_taskiq_settings.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from core.config.settings import Settings + + +def test_taskiq_uses_redis_url_by_default() -> None: + settings = Settings() + + assert settings.taskiq_broker_url.startswith("redis://") + assert settings.taskiq_result_backend_url.startswith("redis://") diff --git a/backend/tests/unit/core/taskiq/test_app.py b/backend/tests/unit/core/taskiq/test_app.py new file mode 100644 index 0000000..ffe32d8 --- /dev/null +++ b/backend/tests/unit/core/taskiq/test_app.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import importlib +import sys + +import pytest + +from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker + + +def test_taskiq_broker_is_configured() -> None: + assert broker is not None + assert default_broker is broker + assert critical_broker is not None + assert bulk_broker is not None + + +def test_taskiq_app_configures_logging_on_import( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sys.modules.pop("core.taskiq.app", None) + sys.modules.pop("core.taskiq", None) + + called = {"count": 0, "args": None} + + def _fake_configure_logging(*args: object, **__: object) -> None: + called["count"] += 1 + called["args"] = args + + monkeypatch.setattr("core.logging.configure_logging", _fake_configure_logging) + + importlib.import_module("core.taskiq.app") + + from core.config.settings import config + + assert called["count"] == 1 + assert called["args"] == (config,) diff --git a/backend/tests/unit/infra/test_worker_runtime_script.py b/backend/tests/unit/infra/test_worker_runtime_script.py new file mode 100644 index 0000000..4ca858d --- /dev/null +++ b/backend/tests/unit/infra/test_worker_runtime_script.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from pathlib import Path + + +ROOT_DIR = Path(__file__).resolve().parents[4] +APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh" + + +def test_worker_commands_use_taskiq() -> None: + content = APP_SCRIPT.read_text(encoding="utf-8") + removed_runner = "uv run c" "elery" + + assert "uv run taskiq worker" in content + assert "core.taskiq.app:critical_broker" in content + assert "core.taskiq.app:default_broker" in content + assert "core.taskiq.app:bulk_broker" in content + assert 'pgrep -f "taskiq.*worker"' in content + assert 'pkill -f "taskiq.*worker"' in content + assert removed_runner not in content diff --git a/backend/tests/unit/services/base/test_redis_service.py b/backend/tests/unit/services/base/test_redis_service.py index 706d0fd..d6725f1 100644 --- a/backend/tests/unit/services/base/test_redis_service.py +++ b/backend/tests/unit/services/base/test_redis_service.py @@ -3,7 +3,7 @@ from __future__ import annotations import pytest from core.config.settings import RedisSettings -from services.base.redis import RedisService +from services.base.redis import RedisService, get_or_init_redis_client, redis_service class _FakeRedisClient: @@ -96,3 +96,35 @@ def test_get_client_raises_before_init() -> None: with pytest.raises(RuntimeError): service.get_client() + + +@pytest.mark.asyncio +async def test_get_or_init_redis_client_initializes_when_needed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake_client = _FakeRedisClient() + + async def _fake_initialize() -> bool: + return True + + monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False)) + monkeypatch.setattr(redis_service, "initialize", _fake_initialize) + monkeypatch.setattr(redis_service, "get_client", lambda: fake_client) + + client = await get_or_init_redis_client() + + assert client is fake_client + + +@pytest.mark.asyncio +async def test_get_or_init_redis_client_raises_when_init_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def _fake_initialize() -> bool: + return False + + monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False)) + monkeypatch.setattr(redis_service, "initialize", _fake_initialize) + + with pytest.raises(RuntimeError, match="Redis service initialization failed"): + await get_or_init_redis_client() diff --git a/backend/tests/unit/services/base/test_service_registry.py b/backend/tests/unit/services/base/test_service_registry.py index f95250d..373d644 100644 --- a/backend/tests/unit/services/base/test_service_registry.py +++ b/backend/tests/unit/services/base/test_service_registry.py @@ -1,8 +1,12 @@ from __future__ import annotations +import pytest + from services.base.service_interface import ( BaseServiceProvider, ServiceRegistry, + close_registered_services, + initialize_registered_services, register_service, register_service_instance, ) @@ -35,6 +39,17 @@ def test_register_service_and_create_service() -> None: assert created.get_service_info()["name"] == "dummy" +def test_register_service_and_get_service() -> None: + @register_service("dummy-service-get") + class _RegisteredService(_DummyService): + pass + + resolved = ServiceRegistry.get_service("dummy-service-get") + + assert resolved is not None + assert resolved.get_service_info()["name"] == "dummy" + + def test_register_service_instance_returns_same_instance() -> None: instance = _DummyService("singleton") @@ -47,3 +62,77 @@ def test_register_service_instance_returns_same_instance() -> None: def test_create_service_returns_none_for_missing() -> None: assert ServiceRegistry.create_service("missing-service") is None + + +def test_get_service_returns_none_for_missing() -> None: + assert ServiceRegistry.get_service("missing-service") is None + + +class _LifecycleService(BaseServiceProvider): + def __init__(self, name: str, recorder: list[str], fail_on_init: bool = False) -> None: + super().__init__(name) + self._recorder = recorder + self._fail_on_init = fail_on_init + + async def initialize(self, **_: object) -> bool: + self._recorder.append(f"init:{self.service_name}") + if self._fail_on_init: + return False + self._set_initialized(True) + return True + + async def close(self) -> bool: + self._recorder.append(f"close:{self.service_name}") + self._set_initialized(False) + return True + + async def health_check(self) -> dict[str, object]: + return {"status": "healthy", "details": {}} + + +@pytest.mark.asyncio +async def test_initialize_registered_services_success() -> None: + recorder: list[str] = [] + first = register_service_instance( + "lifecycle-success-first", _LifecycleService("first", recorder) + ) + second = register_service_instance( + "lifecycle-success-second", _LifecycleService("second", recorder) + ) + + initialized, services = await initialize_registered_services( + ["lifecycle-success-first", "lifecycle-success-second"] + ) + + assert initialized is True + assert services == [first, second] + assert recorder == ["init:first", "init:second"] + + +@pytest.mark.asyncio +async def test_initialize_registered_services_failure_rolls_back() -> None: + recorder: list[str] = [] + register_service_instance("lifecycle-fail-first", _LifecycleService("first", recorder)) + register_service_instance( + "lifecycle-fail-second", _LifecycleService("second", recorder, fail_on_init=True) + ) + + initialized, services = await initialize_registered_services( + ["lifecycle-fail-first", "lifecycle-fail-second"] + ) + + assert initialized is False + assert services == [] + assert recorder == ["init:first", "init:second", "close:first"] + + +@pytest.mark.asyncio +async def test_close_registered_services_closes_in_reverse_order() -> None: + recorder: list[str] = [] + first = _LifecycleService("first", recorder) + second = _LifecycleService("second", recorder) + + closed = await close_registered_services([first, second]) + + assert closed is True + assert recorder == ["close:second", "close:first"] diff --git a/backend/tests/unit/services/base/test_supabase_service.py b/backend/tests/unit/services/base/test_supabase_service.py new file mode 100644 index 0000000..d0db5d2 --- /dev/null +++ b/backend/tests/unit/services/base/test_supabase_service.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.config.settings import SupabaseSettings +from services.base.supabase import SupabaseService + + +@pytest.mark.asyncio +async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None: + service = SupabaseService(settings=SupabaseSettings()) + anon_client = MagicMock() + admin_client = MagicMock() + + create_calls: list[tuple[str, str]] = [] + + def _fake_create_client(url: str, key: str) -> object: + create_calls.append((url, key)) + return anon_client if len(create_calls) == 1 else admin_client + + monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client) + + result = await service.initialize() + + assert result is True + assert service.is_initialized is True + assert service.get_client() is anon_client + assert service.get_admin_client() is admin_client + assert len(create_calls) == 2 + + +@pytest.mark.asyncio +async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None: + service = SupabaseService(settings=SupabaseSettings()) + + def _fake_create_client(_: str, __: str) -> object: + raise RuntimeError("boom") + + monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client) + + result = await service.initialize() + + assert result is False + assert service.is_initialized is False + with pytest.raises(RuntimeError): + service.get_client() + + +@pytest.mark.asyncio +async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None: + service = SupabaseService(settings=SupabaseSettings()) + + def _fake_create_client(_: str, __: str) -> object: + return MagicMock() + + monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client) + + assert await service.initialize() is True + assert await service.close() is True + assert service.is_initialized is False + with pytest.raises(RuntimeError): + service.get_client() + with pytest.raises(RuntimeError): + service.get_admin_client() + + +@pytest.mark.asyncio +async def test_health_check_uninitialized() -> None: + service = SupabaseService(settings=SupabaseSettings()) + + health = await service.health_check() + + assert health["status"] == "unhealthy" + + +@pytest.mark.asyncio +async def test_health_check_initialized(monkeypatch: pytest.MonkeyPatch) -> None: + service = SupabaseService(settings=SupabaseSettings()) + + anon_client = MagicMock() + anon_client.auth.get_session = MagicMock(return_value=None) + + admin_list_users = MagicMock(return_value=SimpleNamespace(users=[])) + admin_client = MagicMock() + admin_client.auth.admin = SimpleNamespace(list_users=admin_list_users) + + create_sequence = [anon_client, admin_client] + + def _fake_create_client(_: str, __: str) -> object: + return create_sequence.pop(0) + + monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client) + + assert await service.initialize() is True + + health = await service.health_check() + + assert health["status"] == "healthy" + admin_list_users.assert_called_once_with(page=1, per_page=1) + + +def test_get_client_raises_before_init() -> None: + service = SupabaseService(settings=SupabaseSettings()) + + with pytest.raises(RuntimeError): + service.get_client() + with pytest.raises(RuntimeError): + service.get_admin_client() diff --git a/backend/tests/unit/test_app_lifespan.py b/backend/tests/unit/test_app_lifespan.py new file mode 100644 index 0000000..34ff4f8 --- /dev/null +++ b/backend/tests/unit/test_app_lifespan.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import pytest + +import app as app_module + + +@pytest.mark.asyncio +async def test_lifespan_uses_registered_services(monkeypatch: pytest.MonkeyPatch) -> None: + initialized_services = [object(), object()] + calls: dict[str, object] = {} + + async def _fake_initialize(service_names: list[str]) -> tuple[bool, list[object]]: + calls["init_names"] = service_names + return True, initialized_services + + async def _fake_close(services: list[object]) -> bool: + calls["close_services"] = services + return True + + monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize) + monkeypatch.setattr(app_module, "close_registered_services", _fake_close) + + context = app_module.lifespan(app_module.app) + await context.__aenter__() + await context.__aexit__(None, None, None) + + assert calls["init_names"] == ["redis", "supabase"] + assert calls["close_services"] == initialized_services + + +@pytest.mark.asyncio +async def test_lifespan_raises_when_initialization_failed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def _fake_initialize(_: list[str]) -> tuple[bool, list[object]]: + return False, [] + + monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize) + + context = app_module.lifespan(app_module.app) + with pytest.raises(RuntimeError, match="Service initialization failed"): + await context.__aenter__() diff --git a/backend/tests/unit/test_celery_logging.py b/backend/tests/unit/test_celery_logging.py deleted file mode 100644 index d58158a..0000000 --- a/backend/tests/unit/test_celery_logging.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from celery import Celery -from pytest import MonkeyPatch - -from core.logging import celery as celery_logging -from core.logging.context import clear_context, get_context - - -class DummyTask: - name: str = "tasks.sample" - - -def test_celery_prerun_binds_task_context() -> None: - handlers = celery_logging.build_celery_signal_handlers() - - handlers.on_task_prerun(task_id="task-123", task=DummyTask()) - context = get_context() - - assert context["task_id"] == "task-123" - assert context["task_name"] == "tasks.sample" - - clear_context() - - -def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None: - called = {"value": False} - - def fake_configure_logging(settings: object | None = None) -> None: - called["value"] = True - - monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging) - handlers = celery_logging.build_celery_signal_handlers() - - handlers.on_setup_logging() - - assert called["value"] is True - - -def test_configure_celery_app_disables_hijack() -> None: - app = Celery("test") - - celery_logging.configure_celery_app(app) - - assert app.conf.worker_hijack_root_logger is False diff --git a/backend/tests/unit/v1/agent/test_dependencies_queue.py b/backend/tests/unit/v1/agent/test_dependencies_queue.py new file mode 100644 index 0000000..7d6a342 --- /dev/null +++ b/backend/tests/unit/v1/agent/test_dependencies_queue.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +import pytest + +from v1.agent.dependencies import TaskiqQueueClient + + +class _FakeRedis: + def __init__(self) -> None: + self.store: dict[str, str] = {} + self.delete_calls: list[str] = [] + + async def set( + self, + key: str, + value: str, + *, + nx: bool = False, + ex: int | None = None, + ) -> bool: + del ex + if nx and key in self.store: + return False + self.store[key] = value + return True + + async def get(self, key: str) -> str | None: + return self.store.get(key) + + async def delete(self, key: str) -> int: + self.delete_calls.append(key) + existed = 1 if key in self.store else 0 + self.store.pop(key, None) + return existed + + +class _FakeAsyncResult: + def __init__(self, task_id: str) -> None: + self.task_id = task_id + + +@pytest.mark.asyncio +async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None: + from v1.agent import dependencies as deps + + fake_redis = _FakeRedis() + resolved_client = {"value": False} + + async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult: + assert payload["command"] == "run" + return _FakeAsyncResult("task-123") + + async def _fake_get_or_init_client() -> _FakeRedis: + resolved_client["value"] = True + return fake_redis + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq) + + client = TaskiqQueueClient() + task_id = await client.enqueue(command={"command": "run"}, dedup_key=None) + + assert resolved_client["value"] is True + assert task_id == "task-123" + + +@pytest.mark.asyncio +async def test_enqueue_resume_dedup_returns_existing_task_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from v1.agent import dependencies as deps + + fake_redis = _FakeRedis() + resolved_client = {"value": False} + + async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult: + del payload + return _FakeAsyncResult("new-task-id") + + async def _fake_get_or_init_client() -> _FakeRedis: + resolved_client["value"] = True + return fake_redis + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq) + + dedup_key = "resume:session-1:call-1" + fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id" + + client = TaskiqQueueClient() + task_id = await client.enqueue( + command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"}, + dedup_key=dedup_key, + ) + + assert resolved_client["value"] is True + assert task_id == "existing-task-id" + + +@pytest.mark.asyncio +async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from v1.agent import dependencies as deps + + fake_redis = _FakeRedis() + dedup_key = "resume:session-1:call-1" + redis_key = f"agent:dedup:{dedup_key}" + fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER + attempts = {"count": 0} + + async def _fake_get_or_init_client() -> _FakeRedis: + return fake_redis + + async def _fake_get(key: str) -> str | None: + attempts["count"] += 1 + if attempts["count"] > 1: + fake_redis.store[key] = "existing-task-id" + return fake_redis.store.get(key) + + async def _fake_sleep(_: float) -> None: + return None + + async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult: + del payload + raise AssertionError("should not enqueue when dedup task id appears") + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(fake_redis, "get", _fake_get) + monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep) + monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq) + + client = TaskiqQueueClient() + task_id = await client.enqueue( + command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"}, + dedup_key=dedup_key, + ) + + assert task_id == "existing-task-id" + + +@pytest.mark.asyncio +async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None: + from v1.agent import dependencies as deps + + fake_redis = _FakeRedis() + dedup_key = "resume:session-1:call-1" + redis_key = f"agent:dedup:{dedup_key}" + + async def _fake_get_or_init_client() -> _FakeRedis: + return fake_redis + + async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult: + del payload + raise RuntimeError("enqueue failed") + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq) + + client = TaskiqQueueClient() + with pytest.raises(RuntimeError, match="enqueue failed"): + await client.enqueue( + command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"}, + dedup_key=dedup_key, + ) + + assert redis_key in fake_redis.delete_calls + + +@pytest.mark.asyncio +async def test_enqueue_uses_critical_queue_when_requested( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from v1.agent import dependencies as deps + + fake_redis = _FakeRedis() + + async def _fake_get_or_init_client() -> _FakeRedis: + return fake_redis + + async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult: + raise AssertionError("default queue should not be selected") + + async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult: + return _FakeAsyncResult("critical-task-id") + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq) + monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq) + + client = TaskiqQueueClient() + task_id = await client.enqueue( + command={"command": "run", "queue": "critical"}, + dedup_key=None, + ) + + assert task_id == "critical-task-id" + + +@pytest.mark.asyncio +async def test_enqueue_uses_bulk_queue_when_requested( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from v1.agent import dependencies as deps + + fake_redis = _FakeRedis() + + async def _fake_get_or_init_client() -> _FakeRedis: + return fake_redis + + async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult: + raise AssertionError("default queue should not be selected") + + async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult: + return _FakeAsyncResult("bulk-task-id") + + monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client) + monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq) + monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq) + + client = TaskiqQueueClient() + task_id = await client.enqueue( + command={"command": "run", "queue": "bulk"}, + dedup_key=None, + ) + + assert task_id == "bulk-task-id" diff --git a/backend/tests/unit/v1/auth/test_auth_gateway.py b/backend/tests/unit/v1/auth/test_auth_gateway.py index ec2b64c..e53be50 100644 --- a/backend/tests/unit/v1/auth/test_auth_gateway.py +++ b/backend/tests/unit/v1/auth/test_auth_gateway.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from fastapi import HTTPException @@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest class TestSupabaseAuthGateway: @pytest.fixture - def gateway(self) -> SupabaseAuthGateway: - with patch("v1.auth.gateway.create_client") as mock_create: - mock_client = MagicMock() - mock_admin_client = MagicMock() - mock_create.side_effect = [mock_client, mock_admin_client] - return SupabaseAuthGateway() + def gateway( + self, monkeypatch: pytest.MonkeyPatch + ) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]: + mock_client = MagicMock() + mock_admin_client = MagicMock() + monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client) + monkeypatch.setattr( + "v1.auth.gateway.supabase_service.get_admin_client", + lambda: mock_admin_client, + ) + return SupabaseAuthGateway(), mock_client, mock_admin_client @pytest.mark.asyncio async def test_request_password_reset_calls_email_with_string( - self, gateway: SupabaseAuthGateway + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: + sut, mock_client, _ = gateway mock_reset_email = MagicMock() - gateway._client.auth.reset_password_email = mock_reset_email + mock_client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest(email="test@example.com") - await gateway.request_password_reset(request) + await sut.request_password_reset(request) mock_reset_email.assert_called_once_with("test@example.com") @pytest.mark.asyncio async def test_request_password_reset_with_redirect( - self, gateway: SupabaseAuthGateway + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: + sut, mock_client, _ = gateway mock_reset_email = MagicMock() - gateway._client.auth.reset_password_email = mock_reset_email + mock_client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest( email="test@example.com", redirect_to="http://localhost:3000/reset-password", ) - await gateway.request_password_reset(request) + await sut.request_password_reset(request) mock_reset_email.assert_called_once_with( "test@example.com", @@ -51,64 +58,68 @@ class TestSupabaseAuthGateway: @pytest.mark.asyncio async def test_request_password_reset_swallows_auth_error( - self, gateway: SupabaseAuthGateway + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: + sut, mock_client, _ = gateway from supabase import AuthError mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None)) - gateway._client.auth.reset_password_email = mock_reset_email + mock_client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest(email="test@example.com") - result = await gateway.request_password_reset(request) + result = await sut.request_password_reset(request) mock_reset_email.assert_called_once() assert result is None @pytest.mark.asyncio async def test_request_password_reset_extracts_email_from_mapping( - self, gateway: SupabaseAuthGateway + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: + sut, mock_client, _ = gateway mock_reset_email = MagicMock() - gateway._client.auth.reset_password_email = mock_reset_email + mock_client.auth.reset_password_email = mock_reset_email request = PasswordResetRequest.model_construct( email={"email": "test@example.com"}, redirect_to=None, ) - await gateway.request_password_reset(request) + await sut.request_password_reset(request) mock_reset_email.assert_called_once_with("test@example.com") @pytest.mark.asyncio async def test_request_password_reset_rejects_invalid_email_shape( - self, gateway: SupabaseAuthGateway + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: + sut, _, _ = gateway request = PasswordResetRequest.model_construct( email={"unexpected": "value"}, redirect_to=None, ) with pytest.raises(HTTPException) as exc_info: - await gateway.request_password_reset(request) + await sut.request_password_reset(request) assert exc_info.value.status_code == 422 assert exc_info.value.detail == "Invalid email" @pytest.mark.asyncio async def test_confirm_password_reset_updates_password_by_user_id( - self, gateway: SupabaseAuthGateway + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: + sut, mock_client, mock_admin_client = gateway verify_response = SimpleNamespace( session=SimpleNamespace(access_token="access"), user=SimpleNamespace(id="user-1"), ) mock_verify_otp = MagicMock(return_value=verify_response) - gateway._client.auth.verify_otp = mock_verify_otp + mock_client.auth.verify_otp = mock_verify_otp mock_update_user_by_id = MagicMock() - gateway._admin_client.auth.admin = SimpleNamespace( + mock_admin_client.auth.admin = SimpleNamespace( update_user_by_id=mock_update_user_by_id ) @@ -118,7 +129,7 @@ class TestSupabaseAuthGateway: new_password="newpassword123", ) - await gateway.confirm_password_reset(request) + await sut.confirm_password_reset(request) mock_verify_otp.assert_called_once_with( { @@ -134,13 +145,14 @@ class TestSupabaseAuthGateway: @pytest.mark.asyncio async def test_confirm_password_reset_raises_when_user_id_missing( - self, gateway: SupabaseAuthGateway + self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock] ) -> None: + sut, mock_client, _ = gateway verify_response = SimpleNamespace( session=SimpleNamespace(access_token="access"), user=SimpleNamespace(id=""), ) - gateway._client.auth.verify_otp = MagicMock(return_value=verify_response) + mock_client.auth.verify_otp = MagicMock(return_value=verify_response) request = PasswordResetConfirmRequest( email="test@example.com", @@ -149,7 +161,7 @@ class TestSupabaseAuthGateway: ) with pytest.raises(HTTPException) as exc_info: - await gateway.confirm_password_reset(request) + await sut.confirm_password_reset(request) assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid or expired verification code" diff --git a/docs/plans/2026-03-06-supabase-service-design.md b/docs/plans/2026-03-06-supabase-service-design.md new file mode 100644 index 0000000..d46a2aa --- /dev/null +++ b/docs/plans/2026-03-06-supabase-service-design.md @@ -0,0 +1,260 @@ +# Supabase 统一服务生命周期设计(优化版) + +**Date:** 2026-03-06 +**Status:** Draft + +--- + +## 0. Intake Contract + +- Objective: 将 Supabase 客户端纳入统一服务生命周期管理,避免每次请求重复创建客户端。 +- Deliverable: 新增 `SupabaseService`,并基于 `service_interface.py` 的 `ServiceRegistry` 提供统一初始化/关闭路径,完成 auth 侧迁移。 +- Constraints: + - 保持现有 `core.config.settings` 的配置读取行为不变。 + - 不引入 `os.environ` 直接读取。 + - 不改变现有 API 语义。 +- Verification target: + - 通过单元测试证明 Supabase 服务初始化、关闭、健康检查行为。 + - 通过应用启动测试证明统一初始化流程可用。 + - 通过 auth 相关测试证明迁移后业务行为一致。 + +--- + +## 1. 复杂度与风险分级 + +- Complexity: `S2` + - 原因:涉及多文件改造(`services/base`、`app.py`、`v1/auth`、测试)。 +- Risk Tier: `L1` + - 原因:涉及应用启动链路和认证网关依赖,但不改变对外接口契约。 + +L1 Gate 要求:执行 `refactor-cleaner` 审视冗余与结构风险(`code-reviewer` 可选)。 + +--- + +## 2. 现状与问题 + +### 2.1 当前现状 + +- `SupabaseAuthGateway` 在 `__init__` 内直接 `create_client(...)`,每次实例化都会创建 anon/admin 客户端。 +- `get_auth_service()` 当前每次请求都会 new `SupabaseAuthGateway()`,导致客户端重复构造。 +- `ServiceRegistry` 已存在,但目前主要用于注册,应用启动仍是手写逐个初始化。 + +### 2.2 核心问题 + +1. 生命周期不统一:Supabase 没有接入应用启动/关闭的统一管理。 +2. 初始化代码重复趋势:服务增多后,`app.py` 的 lifespan 会继续膨胀。 +3. 网关构造时机风险:若在应用未初始化阶段取客户端,可能抛运行时异常。 + +--- + +## 3. 优化设计(推荐方案) + +### 3.1 方案摘要 + +在 `service_interface.py` 基础上新增统一生命周期函数,按服务名列表批量初始化/关闭;`app.py` 仅声明服务顺序,减少样板代码。Supabase 使用 `config.supabase` 作为默认配置来源,保持 settings 行为一致。 + +### 3.2 目标文件结构 + +```text +backend/src/services/base/ +├── __init__.py +├── service_interface.py # 扩展:统一生命周期函数 +├── redis.py +└── supabase.py # 新增 +``` + +### 3.3 service_interface 统一初始化能力(新增) + +在 `service_interface.py` 新增以下函数(建议命名): + +- `resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]` +- `initialize_registered_services(service_names: list[str]) -> tuple[bool, list[BaseServiceProvider]]` +- `close_registered_services(services: list[BaseServiceProvider]) -> bool` + +约束与行为: + +1. 初始化按 `service_names` 顺序执行。 +2. 任一服务初始化失败时: + - 返回 `False`。 + - 对已成功初始化的服务按逆序执行关闭回滚。 +3. 关闭按逆序执行,最大化依赖安全性。 +4. 日志必须包含失败服务名和错误摘要。 + +这样 `app.py` 只需声明: + +```python +SERVICE_STARTUP_ORDER = ["redis", "supabase"] +``` + +并调用统一函数,减少重复初始化样板。 + +### 3.4 SupabaseService 设计 + +`supabase.py` 关键点: + +- 继承 `BaseServiceProvider`。 +- 构造函数签名: + - `def __init__(self, settings: SupabaseSettings | None = None) -> None` + - 默认 `settings or config.supabase`,确保与当前配置源一致。 +- `initialize()`:创建 anon/admin 两个 client,失败返回 `False`。 +- `close()`: + - 清空 `_client`、`_admin_client`。 + - `self._set_initialized(False)`。 +- `health_check()`: + - 必须进行至少一个轻量真实请求验证,不仅检查本地对象存在。 + - 返回结构与 `RedisService.health_check()`风格一致(`status + details`)。 + +注册方式: + +```python +supabase_service: SupabaseService = register_service_instance( + "supabase", SupabaseService() +) +``` + +### 3.5 app.py 改造 + +当前手写 `redis_service.initialize()` 改为调用统一初始化函数。 + +目标行为: + +1. 启动阶段: + - 调用 `initialize_registered_services(["redis", "supabase"])`。 + - 失败则 `raise RuntimeError("Service initialization failed")`。 +2. 关闭阶段: + - 调用 `close_registered_services(initialized_services)`。 + +### 3.6 AuthGateway 迁移策略(避免构造时机问题) + +不建议在 `SupabaseAuthGateway.__init__` 里立即绑定 client;改为按需获取: + +- 保留网关对象轻量化。 +- 在每个业务方法内部通过 `supabase_service.get_client()` / `get_admin_client()` 取实例。 + +优点: + +1. 避免模块导入或依赖构建阶段误触未初始化 client。 +2. 对 `users/dependencies.py` 中全局缓存 gateway 的场景更安全。 +3. 不改变业务层接口。 + +--- + +## 4. 配置与兼容性保证 + +### 4.1 settings/config 行为不变 + +迁移后依然通过 `core.config.settings.config.supabase` 读取: + +- `url` +- `anon_key` +- `service_role_key` +- `jwt_secret`(JWT 校验现有逻辑继续使用) + +### 4.2 环境变量兼容 + +由于 `Settings` + `env_nested_delimiter` 机制不变,现有环境变量命名与 `.env` 内容无需修改。 + +### 4.3 对现有代码影响 + +- API 层 schema/路由不变。 +- 认证行为不变。 +- 仅优化客户端生命周期与启动流程。 + +--- + +## 5. 实施计划(可执行) + +### Task 1: 扩展统一生命周期接口 + +**Files** +- Modify: `backend/src/services/base/service_interface.py` +- Test: `backend/tests/unit/services/base/test_service_interface.py`(新增) + +**Steps** +1. 写失败测试:初始化顺序、失败回滚、关闭逆序。 +2. 实现生命周期函数。 +3. 跑单测确认通过。 + +### Task 2: 新增 SupabaseService + +**Files** +- Create: `backend/src/services/base/supabase.py` +- Modify: `backend/src/services/base/__init__.py` +- Test: `backend/tests/unit/services/base/test_supabase.py` + +**Steps** +1. 写失败测试(init success/fail、close、health_check)。 +2. 实现 `SupabaseService` 与实例注册。 +3. 跑单测。 + +### Task 3: 接入 app lifespan 统一初始化 + +**Files** +- Modify: `backend/src/app.py` +- Test: `backend/tests/integration/test_app_lifespan.py`(新增或扩展) + +**Steps** +1. 写失败测试(supabase init fail 时应用启动失败)。 +2. 替换手写初始化为统一函数。 +3. 跑集成测试。 + +### Task 4: 迁移 AuthGateway 获取 client 方式 + +**Files** +- Modify: `backend/src/v1/auth/gateway.py` +- Optional Modify: `backend/src/v1/auth/dependencies.py` +- Optional Modify: `backend/src/v1/users/dependencies.py` +- Test: `backend/tests/unit/v1/auth/test_gateway.py`(扩展) + +**Steps** +1. 写失败测试(未初始化时错误、初始化后正常调用)。 +2. 改为方法内按需取 client。 +3. 跑 auth 相关单测。 + +### Task 5: 全量验证与门禁 + +**Commands** +- `uv run ruff check backend/src backend/tests` +- `uv run basedpyright` +- `uv run pytest backend/tests/unit/services/base -q` +- `uv run pytest backend/tests/unit/v1/auth -q` +- `uv run pytest backend/tests/integration -q` + +输出要求:记录每条命令 pass/fail 与关键摘要。 + +--- + +## 6. 验收标准(更新) + +- [ ] `SupabaseService` 继承 `BaseServiceProvider` 并注册到 `ServiceRegistry` +- [ ] `service_interface.py` 提供统一初始化/关闭函数 +- [ ] `app.py` 通过统一函数初始化 `redis + supabase` +- [ ] Supabase 配置读取仍仅来自 `core.config.settings.config` +- [ ] `auth/gateway.py` 不再在 `__init__` 新建客户端 +- [ ] 初始化失败具备回滚关闭逻辑 +- [ ] 单元/集成测试覆盖核心迁移路径并通过 + +--- + +## 7. 风险与缓解 + +| 风险 | 级别 | 缓解 | +|---|---|---| +| 统一初始化函数引入顺序错误 | 中 | 显式 `SERVICE_STARTUP_ORDER` + 顺序测试 | +| Supabase 健康检查误报 | 中 | 使用真实轻量请求,不只做对象检查 | +| gateway 与生命周期耦合导致运行时错误 | 中 | 改为方法内按需取 client,并覆盖未初始化测试 | +| 迁移影响现有 auth 行为 | 中 | 保持 service 接口不变,补充回归测试 | + +--- + +## 8. 完成定义(Completion Contract) + +1. Complexity: `S2` +2. Risk Tier: `L1` +3. Gates: + - 必需:`refactor-cleaner` + - 可选:`code-reviewer`(建议在合并前执行) +4. Verification evidence: + - 提供 lint/typecheck/unit/integration 命令结果 +5. Remaining risks/follow-ups: + - 若后续新增第三方服务,沿用 `ServiceRegistry + 统一生命周期函数` 接入,不再在 `app.py` 手写初始化。 diff --git a/docs/plans/2026-03-06-taskiq-migration.md b/docs/plans/2026-03-06-taskiq-migration.md new file mode 100644 index 0000000..0c8e1a6 --- /dev/null +++ b/docs/plans/2026-03-06-taskiq-migration.md @@ -0,0 +1,359 @@ +# Celery To Taskiq One-Shot Migration Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 在当前早期项目中一次性移除 Celery,并以 Taskiq 替换异步任务基础设施,保持 agent runtime 行为不变。 + +**Architecture:** 复用现有 `AgentService -> QueueClientLike` 抽象,仅替换基础设施层实现(任务声明、入队调用、worker 启动、配置与依赖)。保持 Redis 作为 broker/result 存储与事件流通道,避免改动业务服务层语义。 + +**Tech Stack:** FastAPI, Taskiq, taskiq-redis, Redis, pytest, uv + +--- + +### Task 1: 依赖与配置切换(先 RED 后 GREEN) + +**Files:** +- Modify: `pyproject.toml` +- Modify: `backend/src/core/config/settings.py` +- Test: `backend/tests/unit/core/config/test_taskiq_settings.py` (new) + +**Step 1: Write the failing test** + +```python +from core.config.settings import Settings + + +def test_taskiq_uses_redis_url_by_default() -> None: + settings = Settings() + assert settings.taskiq_broker_url.startswith("redis://") + + +def test_taskiq_queue_default_value() -> None: + settings = Settings() + assert settings.taskiq.default_queue == "default" +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/core/config/test_taskiq_settings.py -v` +Expected: FAIL(`taskiq_broker_url` / `taskiq` 字段不存在) + +**Step 3: Write minimal implementation** + +```python +class TaskiqSettings(BaseModel): + broker_url: str | None = None + result_backend_url: str | None = None + default_queue: str = "default" + +class Settings(BaseSettings): + taskiq: TaskiqSettings = TaskiqSettings() + + @computed_field + @property + def taskiq_broker_url(self) -> str: + return self.taskiq.broker_url or self.redis.url + + @computed_field + @property + def taskiq_result_backend_url(self) -> str: + return self.taskiq.result_backend_url or self.redis.url +``` + +`pyproject.toml` 同步变更: +- 删除 `celery>=...` +- 增加 `taskiq>=...` +- 增加 `taskiq-redis>=...` + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/core/config/test_taskiq_settings.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add pyproject.toml backend/src/core/config/settings.py backend/tests/unit/core/config/test_taskiq_settings.py +git commit -m "refactor(queue): replace celery config with taskiq settings" +``` + +### Task 2: 新建 Taskiq broker 与 worker 启动入口 + +**Files:** +- Create: `backend/src/core/taskiq/app.py` +- Create: `backend/tests/unit/core/taskiq/test_app.py` +- Delete: `backend/src/core/celery/app.py` + +**Step 1: Write the failing test** + +```python +from core.taskiq.app import broker + + +def test_taskiq_broker_is_configured() -> None: + assert broker is not None +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/core/taskiq/test_app.py -v` +Expected: FAIL(模块不存在) + +**Step 3: Write minimal implementation** + +```python +from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend +from core.config.settings import config + +broker = ListQueueBroker(url=config.taskiq_broker_url).with_result_backend( + RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url) +) +``` + +说明:若当前 `taskiq-redis` 版本 API 名称有差异,以该版本官方 API 为准做等价实现。 + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/core/taskiq/test_app.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/core/taskiq/app.py backend/tests/unit/core/taskiq/test_app.py backend/src/core/celery/app.py +git commit -m "feat(queue): add taskiq broker app and remove celery app" +``` + +### Task 3: 迁移任务定义(Celery task -> Taskiq task) + +**Files:** +- Modify: `backend/src/core/agent/infrastructure/queue/tasks.py` +- Test: `backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py` (new) + +**Step 1: Write the failing test** + +```python +from core.agent.infrastructure.queue.tasks import run_agent_task + + +async def test_run_agent_task_invalid_command_raises() -> None: + try: + await run_agent_task({"command": "unknown", "session_id": "00000000-0000-0000-0000-000000000001"}) + raise AssertionError("expected ValueError") + except ValueError as exc: + assert "invalid command type" in str(exc) +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py -v` +Expected: FAIL(测试文件不存在或导入失败) + +**Step 3: Write minimal implementation** + +```python +from core.taskiq.app import broker + +@broker.task(task_name="tasks.agent.run_command") +async def run_command_task(command: dict[str, Any]) -> dict[str, object]: + return await run_agent_task(command) +``` + +并移除: +- `from core.celery.app import celery_app` +- `@celery_app.task(...)` + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/core/agent/infrastructure/queue/tasks.py backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py +git commit -m "refactor(agent): migrate run command task to taskiq" +``` + +### Task 4: 迁移 API 入队客户端(.delay -> .kiq) + +**Files:** +- Modify: `backend/src/v1/agent/dependencies.py` +- Test: `backend/tests/unit/v1/agent/test_dependencies_queue.py` (new) + +**Step 1: Write the failing test** + +```python +class _FakeTask: + async def kiq(self, payload: dict[str, object]): + class _Result: + task_id = "task-123" + return _Result() + + +async def test_enqueue_returns_task_id(monkeypatch): + from v1.agent.dependencies import CeleryQueueClient + client = CeleryQueueClient() # 迁移后应重命名为 TaskiqQueueClient + monkeypatch.setattr("v1.agent.dependencies.run_command_task", _FakeTask()) + task_id = await client.enqueue(command={"command": "run"}, dedup_key=None) + assert task_id == "task-123" +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_dependencies_queue.py -v` +Expected: FAIL(类型/方法不匹配) + +**Step 3: Write minimal implementation** + +```python +class TaskiqQueueClient: + async def enqueue(self, *, command: dict[str, object], dedup_key: str | None) -> str: + payload = dict(command) + if dedup_key: + payload["dedup_key"] = dedup_key + + result = await run_command_task.kiq(payload) + task_id = str(result.task_id) + return task_id +``` + +并替换 DI: + +```python +queue=TaskiqQueueClient() +``` + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_dependencies_queue.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/v1/agent/dependencies.py backend/tests/unit/v1/agent/test_dependencies_queue.py +git commit -m "refactor(api): switch agent enqueue client from celery to taskiq" +``` + +### Task 5: 运维脚本与日志测试清理(一次性删除 Celery) + +**Files:** +- Modify: `infra/scripts/app.sh` +- Delete: `backend/tests/unit/test_celery_logging.py` +- Modify/Create: `backend/tests/unit/core/logging/test_taskiq_logging.py` (if taskiq logging hook implemented) +- Modify: `backend/src/core/logging/__init__.py`(移除 celery logging export) + +**Step 1: Write the failing test** + +```python +def test_worker_command_uses_taskiq() -> None: + content = Path("infra/scripts/app.sh").read_text() + assert "uv run taskiq worker" in content + assert "uv run celery" not in content +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/core/logging/test_taskiq_logging.py -v` +Expected: FAIL(脚本仍含 celery) + +**Step 3: Write minimal implementation** + +`infra/scripts/app.sh` worker 命令替换为 Taskiq worker,例如: + +```bash +uv run taskiq worker core.taskiq.app:broker core.agent.infrastructure.queue.tasks +``` + +删除所有 celery 进程清理匹配: + +```bash +pgrep -f "taskiq.*worker" +pkill -f "taskiq.*worker" +``` + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/core/logging/test_taskiq_logging.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add infra/scripts/app.sh backend/src/core/logging/__init__.py backend/tests/unit/core/logging/test_taskiq_logging.py backend/tests/unit/test_celery_logging.py +git commit -m "chore(infra): replace celery worker scripts and remove celery-specific tests" +``` + +### Task 6: 全量引用清理与回归验证 + +**Files:** +- Modify: `docs/runtime/runtime-runbook.md` +- Modify: 其他引用 Celery 的运行文档(按 `rg` 结果逐个更新) + +**Step 1: Write the failing test** + +```python +# 用命令断言替代代码测试 +# rg -n "celery" backend/src infra/scripts docs/runtime pyproject.toml +``` + +**Step 2: Run check to verify it fails** + +Run: `rg -n "celery" backend/src infra/scripts docs/runtime pyproject.toml` +Expected: 仍有旧引用 + +**Step 3: Write minimal implementation** + +- 删除/替换剩余 Celery 代码、文档、配置。 +- 保留历史变更记录中的 Celery 字样(如 bugs 归档)可接受,但运行路径必须为 0 引用。 + +**Step 4: Run verification suite** + +Run: +- `uv run pytest backend/tests/unit -q` +- `uv run pytest backend/tests/integration -q` +- `uv run pytest backend/tests/e2e -q`(如环境不满足,记录原因) +- `uv run ruff check backend/src backend/tests` +- `uv run basedpyright` +- `rg -n "celery" backend/src infra/scripts pyproject.toml` + +Expected: +- 测试与静态检查通过 +- 运行路径无 Celery 引用 + +**Step 5: Commit** + +```bash +git add docs/runtime/runtime-runbook.md pyproject.toml backend/src infra/scripts backend/tests +git commit -m "refactor(queue): complete one-shot migration from celery to taskiq" +``` + +### Task 7: L1 Review Gates 与交付确认 + +**Files:** +- No code changes required by default + +**Step 1: Run required L1 gate (`refactor-cleaner`)** + +Run: 使用 `refactor-cleaner` 审查迁移后冗余代码、死引用、命名一致性。 +Expected: 无阻断问题。 + +**Step 2: Optional `code-reviewer` (recommended for infra switch)** + +Run: 使用 `code-reviewer` 聚焦任务丢失、重复消费、幂等锁逻辑。 +Expected: 无 CRITICAL/HIGH 问题。 + +**Step 3: Final evidence report** + +输出内容必须包含: +- 执行命令列表 +- 每条命令 PASS/FAIL +- 若有无法执行项(如 e2e 环境),给出原因与人工验证步骤 + +**Step 4: Commit review notes (optional)** + +```bash +git add docs/plans/2026-03-06-taskiq-migration.md +git commit -m "docs(plan): taskiq one-shot migration execution checklist" +``` diff --git a/docs/plans/agent-llm-config.md b/docs/plans/agent-llm-config.md deleted file mode 100644 index b961b2d..0000000 --- a/docs/plans/agent-llm-config.md +++ /dev/null @@ -1,144 +0,0 @@ -# Agent LLM Config Implementation Plan - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** 将 `system_agents.config` 中的 `temperature` / `max_tokens` 以受约束方式加载到运行时,并在调用 LiteLLM 时按需透传。 - -**Architecture:** 在应用层 `RunService` 读取模型选择时同步读取并校验 `SystemAgents.config`;将校验后的 `SystemAgentLLMConfig` 传入 `CrewAIRuntime`;由 runtime 将配置转交给 LiteLLM client,client 仅在值非 `None` 时向 `completion()` 传参,避免不必要的 provider 兼容风险。 - -**Tech Stack:** FastAPI, SQLAlchemy (async), Pydantic v2, LiteLLM, pytest - ---- - -## 背景与修正点 - -- 当前真实调用链为:`RunService._load_agent_model_selection()` -> `create_runtime()` -> `CrewAIRuntime.execute()` -> `run_completion()`,并非 `load_stage_models()`。 -- `SystemAgentLLMConfig` 已存在:`backend/src/core/agent/domain/system_agent_config.py`。 -- `system_agents.config` 目前在初始化 YAML 侧有约束,但运行时 DB 读取仍需二次校验,防止脏数据绕过。 - -## 规则约束 - -- 严格 TDD:先写失败测试,再做实现。 -- Python 命令统一使用 `uv run ...`。 -- 仅做增量改动,不回滚或覆盖与本任务无关的已有变更。 - -## 字段映射与透传策略 - -| 配置字段 | LiteLLM 参数 | 规则 | -|---|---|---| -| `temperature` | `temperature` | `None` 不透传;非空直接透传 | -| `max_tokens` | `max_tokens` | `None` 不透传;非空直接透传 | - ---- - -### Task 1: 应用层加载并校验 Agent LLM Config - -**Files:** -- Modify: `backend/src/core/agent/application/run_service.py` -- Test: `backend/tests/unit/core/agent/test_run_resume_service.py` - -**Step 1: 写失败测试(RED)** - -新增单测覆盖以下行为: -1. `_load_agent_model_selection()` 返回三元组:`(model_code, provider_name, llm_config)`。 -2. 当 DB `config` 为 `{}` 时,`llm_config.temperature/max_tokens` 为 `None`。 -3. 当 DB `config` 含非法值(如 `temperature=3`)时抛 `ValueError`。 - -**Step 2: 运行测试确认失败** - -Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q` -Expected: 新增断言失败(返回值结构/异常行为不匹配)。 - -**Step 3: 最小实现(GREEN)** - -在 `run_service.py`: -1. 查询 `SystemAgents.config`。 -2. 用 `SystemAgentLLMConfig.model_validate(config or {})` 校验。 -3. 将 `_load_agent_model_selection()` 改为返回三元组。 -4. 在 `run()` 中把 `llm_config` 传递到 `create_runtime(...)`。 - -**Step 4: 运行测试确认通过** - -Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q` -Expected: PASS。 - ---- - -### Task 2: Runtime 与 LiteLLM Client 支持可选参数透传 - -**Files:** -- Modify: `backend/src/core/agent/infrastructure/crewai/factory.py` -- Modify: `backend/src/core/agent/infrastructure/crewai/runtime.py` -- Modify: `backend/src/core/agent/infrastructure/litellm/client.py` -- Test: `backend/tests/unit/core/agent/test_crewai_runtime.py` - -**Step 1: 写失败测试(RED)** - -在 `test_crewai_runtime.py` 增加用例: -1. 传入 `temperature/max_tokens` 时,`run_completion` 收到对应参数。 -2. 参数为 `None` 时,不应被透传到 LiteLLM。 - -必要时新增 `backend/tests/unit/core/agent/test_litellm_client.py`,单测 `run_completion` 的 kwargs 组装逻辑。 - -**Step 2: 运行测试确认失败** - -Run: `uv run pytest backend/tests/unit/core/agent/test_crewai_runtime.py -q` -Expected: 新增断言失败(参数未透传或未过滤 `None`)。 - -**Step 3: 最小实现(GREEN)** - -1. `create_runtime()` 增加 `llm_config` 参数并传给 `CrewAIRuntime`。 -2. `CrewAIRuntime` 保存 `llm_config`,执行时调用: - - `run_completion(..., temperature=llm_config.temperature, max_tokens=llm_config.max_tokens)` -3. `run_completion()` 改为支持可选 `temperature/max_tokens`,内部仅在非 `None` 时加入 kwargs 再调用 `completion()`。 - -**Step 4: 运行测试确认通过** - -Run: `uv run pytest backend/tests/unit/core/agent/test_crewai_runtime.py -q` -Expected: PASS。 - ---- - -### Task 3: 初始化数据补齐与回归验证 - -**Files:** -- Modify: `backend/src/core/config/static/database/system_agents.yaml` -- Modify: `backend/src/core/config/initial/init_data.py`(如需补充类型兜底) -- Test: `backend/tests/unit/core/agent/test_run_resume_service.py` - -**Step 1: 写失败测试(RED)** - -补充断言:YAML 读取后 `config` 可为空或包含 `max_tokens: null`,初始化逻辑不会报错,且生成结构符合 `SystemAgentLLMConfig`。 - -**Step 2: 运行测试确认失败** - -Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q` -Expected: 新增断言失败。 - -**Step 3: 最小实现(GREEN)** - -1. 在 `system_agents.yaml` 为各 agent 配置显式补充 `max_tokens: null`。 -2. `init_data.py` 保持 `config: SystemAgentLLMConfig | None = None`,写库时统一序列化为 dict。 - -**Step 4: 运行测试确认通过** - -Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q` -Expected: PASS。 - ---- - -## 最终验证 - -1. `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py backend/tests/unit/core/agent/test_crewai_runtime.py -q` -2. `uv run pytest backend/tests/integration/core/agent/test_queue_run_resume.py -q` -3. `uv run ruff check backend/src backend/tests` -4. `uv run basedpyright` - -预期:全部通过;若集成测试依赖本地 DB 状态导致跳过/失败,需记录原因并给出手工验证步骤。 - -## 完成标准 - -- `RunService` 从 DB 读取并校验 `config`。 -- runtime 到 LiteLLM 链路支持 `temperature/max_tokens` 可选透传。 -- `None` 不透传。 -- 单测与相关集成测试通过,并给出命令级证据。 diff --git a/docs/runtime/runtime-runbook.md b/docs/runtime/runtime-runbook.md index 1faf124..7e0f38b 100644 --- a/docs/runtime/runtime-runbook.md +++ b/docs/runtime/runtime-runbook.md @@ -69,7 +69,7 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml exec -T db \ ### 启动应用进程 ```bash -bash infra/scripts/app-up.sh +bash infra/scripts/app.sh start ``` 该脚本会在 tmux `social-dev` 会话中拉起: @@ -172,6 +172,7 @@ curl -sS "${WEB_BASE_URL}/api/v1/profile/me" \ - 症状:队列堆积,任务长时间 pending。 - 定位:检查 `worker-*` tmux 窗口和对应日志文件。 - 修复:重启 tmux 会话,确认并发配置与队列名(critical/default/bulk)。 + - 说明:Taskiq 路径当前仅消费 `SOCIAL_WORKER__GROUPS__*__CONCURRENCY`,旧 Celery 参数(prefetch/time_limit 等)已废弃。 ### 2.1) Agent Runtime run/resume 事件不闭环 @@ -179,7 +180,7 @@ curl -sS "${WEB_BASE_URL}/api/v1/profile/me" \ - 定位步骤: ```bash -# 1) 检查 celery worker 是否消费 agent 任务 +# 1) 检查 taskiq worker 是否消费 agent 任务 grep -E "tasks\.agent\.run_command|RUN_STARTED|RUN_FINISHED|RUN_ERROR" logs/worker-default.log # 2) 检查 API SSE 事件读取(带 Last-Event-ID) @@ -192,7 +193,7 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml exec -T redis ``` - 修复建议: - - 若 worker 无消费:重启 `worker-default` 窗口并确认 `core.agent.infrastructure.queue.tasks` 已被 Celery include。 + - 若 worker 无消费:重启 `worker-default` 窗口并确认 `core.agent.infrastructure.queue.tasks` 已被 Taskiq worker 加载。 - 若 worker 有事件但 API 无输出:排查 Redis stream 前缀配置与 session_id 是否一致。 - 若出现 `RUN_ERROR`:按 error_id 回查后端日志,不在 API/SSE 中暴露敏感上下文。 @@ -270,4 +271,4 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml up -d --force- | 2026-02-28 | 邀请码功能:新增 invite_codes 表、profiles.referred_by,注册时可选填邀请码并记录邀请关系 | | 2026-03-02 | 文档整理:修正 auth 端点名称(/verifications)、补充 profile 路由文档、修复 L2/L3 验证命令 | | 2026-03-02 | 修正 bootstrap 命令:init-job 需要使用 `uv run python -m core.runtime.cli bootstrap` | -| 2026-03-05 | 新增 Agent Runtime run/resume/events 运维排障流程(Celery + Redis + Last-Event-ID) | +| 2026-03-05 | 新增 Agent Runtime run/resume/events 运维排障流程(Taskiq + Redis + Last-Event-ID) | diff --git a/infra/scripts/app.sh b/infra/scripts/app.sh index 61b668f..20f7524 100755 --- a/infra/scripts/app.sh +++ b/infra/scripts/app.sh @@ -56,9 +56,9 @@ ${SOCIAL_WEB__GUNICORN__WORKER_CLASS:-uvicorn.workers.UvicornWorker} --timeout \ ${SOCIAL_WEB__GUNICORN__TIMEOUT:-60} \ --log-level ${SOCIAL_RUNTIME__LOG_LEVEL:-info}" - WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run celery -A core.celery.app worker --loglevel=info --queues=critical --concurrency=${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}" - WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run celery -A core.celery.app worker --loglevel=info --queues=default --concurrency=${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}" - WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run celery -A core.celery.app worker --loglevel=info --queues=bulk --concurrency=${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}" + WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run taskiq worker core.taskiq.app:critical_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}" + WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run taskiq worker core.taskiq.app:default_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}" + WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run taskiq worker core.taskiq.app:bulk_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}" tmux new-session -d -s "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\"" tmux new-window -t "$SESSION_NAME" -n worker-critical "bash -lc \"$WORKER_CRITICAL_CMD; echo '[worker-critical] exited'; exec bash\"" @@ -92,9 +92,9 @@ stop() { echo "Killing orphaned gunicorn processes..." pkill -f "gunicorn.*app:app" fi - if pgrep -f "celery.*worker" > /dev/null 2>&1; then - echo "Killing orphaned celery processes..." - pkill -f "celery.*worker" + if pgrep -f "taskiq.*worker" > /dev/null 2>&1; then + echo "Killing orphaned taskiq processes..." + pkill -f "taskiq.*worker" fi echo "Session stopped and cleaned up." diff --git a/pyproject.toml b/pyproject.toml index adf7173..b2c54f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ dependencies = [ "alembic>=1.18.3", "asyncpg>=0.31.0", "basedpyright>=1.37.2", - "celery>=5.6.2", "crewai>=1.6.1", "crewai-tools>=1.6.1", "email-validator>=2.3.0", @@ -22,8 +21,11 @@ dependencies = [ "redis>=7.1.0", "sqlalchemy[asyncio]>=2.0.46", "structlog>=24.4.0", + "taskiq>=0.11.0", + "taskiq-redis>=1.0.0", "supabase>=2.27.2", "uvicorn[standard]>=0.40.0", + "gunicorn>=25.1.0", ] [project.optional-dependencies]