chore: checkpoint current backend/runtime changes

This commit is contained in:
qzl
2026-03-06 17:28:17 +08:00
parent 2c59fe5ee2
commit b6087fd195
32 changed files with 1641 additions and 469 deletions
+6 -7
View File
@@ -35,17 +35,16 @@ SOCIAL_REDIS__DB=0
# default: 常规异步任务
# bulk: 批处理/重计算/可延迟任务
SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY=2
SOCIAL_WORKER__GROUPS__CRITICAL__PREFETCH_MULTIPLIER=1
SOCIAL_WORKER__GROUPS__CRITICAL__TIME_LIMIT=300
SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY=2
SOCIAL_WORKER__GROUPS__DEFAULT__PREFETCH_MULTIPLIER=4
SOCIAL_WORKER__GROUPS__DEFAULT__TIME_LIMIT=600
SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY=1
SOCIAL_WORKER__GROUPS__BULK__PREFETCH_MULTIPLIER=1
SOCIAL_WORKER__GROUPS__BULK__TIME_LIMIT=3600
SOCIAL_WORKER__GROUPS__BULK__MAX_TASKS_PER_CHILD=100
############
# Taskiq(可选,默认回落到 Redis URL)
############
# SOCIAL_TASKIQ__BROKER_URL=redis://:password@localhost:6379/0
# SOCIAL_TASKIQ__RESULT_BACKEND_URL=redis://:password@localhost:6379/0
############
# Supabase(本地 Docker 与阿里云自托管保持同一变量)
+13 -12
View File
@@ -13,7 +13,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from core.config.settings import config
from core.http.response import build_problem_details
from core.logging import configure_logging, get_logger, log_service_banner
from services.base.redis import redis_service
from services.base import close_registered_services, initialize_registered_services
from v1.router import router as mobile_router
@@ -29,22 +29,23 @@ log_service_banner(
)
logger = get_logger("api.app")
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
initialized = await redis_service.initialize()
initialized, services = await initialize_registered_services(SERVICE_STARTUP_ORDER)
if not initialized:
logger.error("Redis service failed to initialize, aborting startup")
raise RuntimeError("Redis service initialization failed")
logger.info(
"Redis service initialized",
host=config.redis.host,
db=config.redis.db,
)
yield
await redis_service.close()
logger.info("Redis service closed")
logger.error("Service initialization failed, aborting startup")
raise RuntimeError("Service initialization failed")
logger.info("Base services initialized", services=SERVICE_STARTUP_ORDER)
try:
yield
finally:
closed = await close_registered_services(services)
if not closed:
logger.warning("Failed to close all base services")
logger.info("Base services closed", services=SERVICE_STARTUP_ORDER)
app = FastAPI(lifespan=lifespan)
@@ -1,15 +1,15 @@
from __future__ import annotations
from typing import Any, Protocol, cast
from typing import Any, Protocol
from uuid import UUID
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.celery.app import celery_app
from core.config.settings import config
from core.logging import get_logger
from services.base.redis import redis_service
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agent.infrastructure.queue.tasks")
@@ -29,13 +29,12 @@ class ResumeServiceLike(Protocol):
async def _build_redis_publisher() -> PublishEvent:
settings = cast(Any, config)
client = redis_service.get_client()
client = await get_or_init_redis_client()
event_store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
async def _publish(event_type: str, payload: dict[str, object]) -> None:
@@ -70,22 +69,27 @@ async def run_agent_task(
raise ValueError("session_id is required")
UUID(session_id)
tool_call_id = ""
user_input = ""
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
await publisher(start_event, {"session_id": session_id})
try:
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
result = await service_resume.resume(
session_id=session_id,
tool_call_id=tool_call_id,
)
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
result = await service_run.run(
session_id=session_id,
user_input=user_input,
@@ -125,6 +129,16 @@ async def run_agent_task(
raise
@celery_app.task(name="tasks.agent.run_command")
@default_broker.task(task_name="tasks.agent.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
@critical_broker.task(task_name="tasks.agent.run_command.critical")
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
@bulk_broker.task(task_name="tasks.agent.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
-73
View File
@@ -1,73 +0,0 @@
from __future__ import annotations
from celery import Celery
from celery import signals as celery_signals
from kombu import Queue
from core.config.settings import config
from core.logging import get_logger
from core.logging.celery import configure_celery_app
from services.base.redis import redis_service
logger = get_logger("core.celery")
def _init_redis_on_worker_startup(**_: object) -> None:
import asyncio
logger.info("Initializing Redis service for Celery worker")
try:
result = asyncio.run(redis_service.initialize())
if result:
logger.info("Redis service initialized for Celery worker")
else:
logger.warning("Redis service initialization returned False")
except Exception as exc: # noqa: BLE001
logger.error("Failed to initialize Redis for Celery worker", error=str(exc))
def create_celery_app() -> Celery:
"""Create and configure the Celery application."""
celery_settings = config.celery
app = Celery(
"social_app",
broker=config.celery_broker_url,
backend=config.celery_result_backend,
include=["core.agent.infrastructure.queue.tasks"],
)
app.conf.update(
task_serializer=celery_settings.task_serializer,
result_serializer=celery_settings.result_serializer,
accept_content=celery_settings.accept_content,
timezone=celery_settings.timezone,
enable_utc=celery_settings.enable_utc,
task_track_started=celery_settings.task_track_started,
task_time_limit=celery_settings.task_time_limit,
task_soft_time_limit=celery_settings.task_soft_time_limit,
task_default_retry_delay=celery_settings.task_default_retry_delay,
task_default_queue="default",
task_create_missing_queues=False,
task_queues=(
Queue("default"),
Queue("critical"),
Queue("bulk"),
),
task_routes={
"tasks.critical.*": {"queue": "critical"},
"tasks.bulk.*": {"queue": "bulk"},
},
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1,
)
configure_celery_app(app, settings=config)
return app
celery_app = create_celery_app()
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
+7 -17
View File
@@ -63,19 +63,9 @@ class RuntimeSettings(BaseModel):
return self
class CelerySettings(BaseModel):
class TaskiqSettings(BaseModel):
broker_url: str | None = None
result_backend: str | None = None
task_serializer: str = "json"
result_serializer: str = "json"
accept_content: list[str] = Field(default_factory=lambda: ["json"])
timezone: str = "UTC"
enable_utc: bool = True
task_track_started: bool = True
task_time_limit: int = 300
task_soft_time_limit: int = 240
task_default_retry_delay: int = 30
task_max_retries: int = 3
result_backend_url: str | None = None
class CorsSettings(BaseModel):
@@ -189,7 +179,7 @@ class Settings(BaseSettings):
storage: StorageSettings = StorageSettings()
llm: LlmSettings = LlmSettings()
agent_runtime: AgentRuntimeSettings = AgentRuntimeSettings()
celery: CelerySettings = CelerySettings()
taskiq: TaskiqSettings = TaskiqSettings()
database: DatabaseSettings = DatabaseSettings()
@computed_field
@@ -199,13 +189,13 @@ class Settings(BaseSettings):
@computed_field
@property
def celery_broker_url(self) -> str:
return self.celery.broker_url or self.redis.url
def taskiq_broker_url(self) -> str:
return self.taskiq.broker_url or self.redis.url
@computed_field
@property
def celery_result_backend(self) -> str:
return self.celery.result_backend or self.redis.url
def taskiq_result_backend_url(self) -> str:
return self.taskiq.result_backend_url or self.redis.url
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
env_file=_resolve_env_file(),
-2
View File
@@ -1,6 +1,5 @@
from __future__ import annotations
from core.logging import celery
from core.logging.banner import log_service_banner
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context, get_context
@@ -8,7 +7,6 @@ from core.logging.logger import get_logger
__all__ = [
"bind_context",
"celery",
"clear_context",
"configure_logging",
"get_context",
-64
View File
@@ -1,64 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import cast
from celery import Celery, signals
from core.config.settings import Settings
from core.logging.banner import log_service_banner
from core.logging.config import configure_logging
from core.logging.context import bind_context, clear_context
@dataclass(frozen=True)
class CelerySignalHandlers:
on_setup_logging: Callable[..., None]
on_after_setup_task_logger: Callable[..., None]
on_task_prerun: Callable[..., None]
on_task_postrun: Callable[..., None]
def build_celery_signal_handlers(
settings: Settings | None = None,
) -> CelerySignalHandlers:
active_settings = settings or Settings()
def on_setup_logging(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
log_service_banner(
service_name=active_settings.runtime.service_name,
environment=active_settings.runtime.environment,
)
def on_after_setup_task_logger(*_args: object, **_kwargs: object) -> None:
configure_logging(settings)
def on_task_prerun(*_args: object, **kwargs: object) -> None:
task_id = cast(str | None, kwargs.get("task_id"))
task = kwargs.get("task")
task_name = getattr(task, "name", None)
bind_context(task_id=task_id, task_name=task_name)
def on_task_postrun(*_args: object, **_kwargs: object) -> None:
clear_context()
return CelerySignalHandlers(
on_setup_logging=on_setup_logging,
on_after_setup_task_logger=on_after_setup_task_logger,
on_task_prerun=on_task_prerun,
on_task_postrun=on_task_postrun,
)
def configure_celery_app(app: Celery, settings: Settings | None = None) -> None:
app.conf.worker_hijack_root_logger = False
handlers = build_celery_signal_handlers(settings)
signals.setup_logging.connect(handlers.on_setup_logging, weak=False)
signals.after_setup_task_logger.connect(
handlers.on_after_setup_task_logger, weak=False
)
signals.task_prerun.connect(handlers.on_task_prerun, weak=False)
signals.task_postrun.connect(handlers.on_task_postrun, weak=False)
+3
View File
@@ -0,0 +1,3 @@
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
from core.config.settings import config
from core.logging import configure_logging
configure_logging(config)
def _build_broker(queue_name: str) -> ListQueueBroker:
return ListQueueBroker(
url=config.taskiq_broker_url,
queue_name=queue_name,
).with_result_backend(
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
)
default_broker = _build_broker("default")
critical_broker = _build_broker("critical")
bulk_broker = _build_broker("bulk")
# Backward-compatible export name for existing imports/tests.
broker = default_broker
__all__ = ["broker", "default_broker", "critical_broker", "bulk_broker"]
+11 -1
View File
@@ -1,18 +1,28 @@
from __future__ import annotations
from services.base.redis import RedisService, redis_service
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
close_registered_services,
initialize_registered_services,
register_service,
register_service_instance,
resolve_registered_services,
)
from services.base.supabase import SupabaseService, supabase_service
__all__ = [
"BaseServiceProvider",
"RedisService",
"ServiceRegistry",
"SupabaseService",
"close_registered_services",
"get_or_init_redis_client",
"initialize_registered_services",
"redis_service",
"register_service",
"register_service_instance",
"resolve_registered_services",
"supabase_service",
]
+9 -1
View File
@@ -92,6 +92,14 @@ class RedisService(BaseServiceProvider):
return self._require_client()
async def get_or_init_redis_client() -> redis.Redis:
if not redis_service.is_initialized:
initialized = await redis_service.initialize()
if not initialized:
raise RuntimeError("Redis service initialization failed")
return redis_service.get_client()
redis_service: RedisService = register_service_instance("redis", RedisService())
__all__ = ["RedisService", "redis_service"]
__all__ = ["RedisService", "get_or_init_redis_client", "redis_service"]
@@ -61,6 +61,12 @@ class ServiceRegistry:
@classmethod
def create_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
return cls.get_service(service_name, **kwargs)
@classmethod
def get_service(
cls, service_name: str, **kwargs: Any
) -> Optional[BaseServiceProvider]:
factory = cls.get_service_factory(service_name)
if not factory:
@@ -82,3 +88,71 @@ TService = TypeVar("TService", bound=BaseServiceProvider)
def register_service_instance(service_name: str, service: TService) -> TService:
ServiceRegistry.register(service_name, lambda: service)
return service
def resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]:
services: list[BaseServiceProvider] = []
for service_name in service_names:
service = ServiceRegistry.get_service(service_name)
if service is None:
raise RuntimeError(f"Service is not registered: {service_name}")
services.append(service)
return services
async def close_registered_services(services: list[BaseServiceProvider]) -> bool:
lifecycle_logger = get_logger("services.base.lifecycle")
all_closed = True
for service in reversed(services):
try:
closed = await service.close()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Failed to close service",
service=service.service_name,
error=str(exc),
)
all_closed = False
continue
if not closed:
lifecycle_logger.warning(
"Service close returned false",
service=service.service_name,
)
all_closed = False
return all_closed
async def initialize_registered_services(
service_names: list[str],
) -> tuple[bool, list[BaseServiceProvider]]:
lifecycle_logger = get_logger("services.base.lifecycle")
initialized_services: list[BaseServiceProvider] = []
try:
services = resolve_registered_services(service_names)
except RuntimeError as exc:
lifecycle_logger.error("Failed to resolve registered services", error=str(exc))
return False, []
for service in services:
try:
initialized = await service.initialize()
except Exception as exc: # noqa: BLE001
lifecycle_logger.warning(
"Service initialization raised exception",
service=service.service_name,
error=str(exc),
)
initialized = False
if not initialized:
lifecycle_logger.error(
"Service initialization failed, rolling back",
service=service.service_name,
)
await close_registered_services(initialized_services)
return False, []
initialized_services.append(service)
return True, initialized_services
+89
View File
@@ -0,0 +1,89 @@
from __future__ import annotations
import asyncio
from typing import Any
from supabase import create_client
from core.config.settings import SupabaseSettings, config
from .service_interface import BaseServiceProvider, register_service_instance
class SupabaseService(BaseServiceProvider):
def __init__(self, settings: SupabaseSettings | None = None) -> None:
super().__init__("supabase")
self._settings = settings or config.supabase
self._client: Any = None
self._admin_client: Any = None
async def initialize(self, **_: Any) -> bool:
try:
self._client = create_client(
self._settings.url,
self._settings.anon_key,
)
self._admin_client = create_client(
self._settings.url,
self._settings.service_role_key,
)
self._set_initialized(True)
self.logger.info("Supabase service initialized")
return True
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase service initialization failed", error=str(exc))
self._client = None
self._admin_client = None
self._set_initialized(False)
return False
async def close(self) -> bool:
self._client = None
self._admin_client = None
self._set_initialized(False)
self.logger.info("Supabase service closed")
return True
async def health_check(self) -> dict[str, Any]:
client = self._client
admin_client = self._admin_client
if client is None or admin_client is None:
return {"status": "unhealthy", "details": {"error": "not initialized"}}
try:
await asyncio.to_thread(client.auth.get_session)
await asyncio.to_thread(admin_client.auth.admin.list_users, page=1, per_page=1)
return {
"status": "healthy",
"details": {
"anon_client": "ready",
"admin_client": "ready",
},
}
except Exception as exc: # noqa: BLE001
self.logger.warning("Supabase health check failed", error=str(exc))
return {"status": "unhealthy", "details": {"error": str(exc)}}
def get_client(self) -> Any:
return self._require_client()
def get_admin_client(self) -> Any:
return self._require_admin_client()
def _require_client(self) -> Any:
client = self._client
if client is None:
raise RuntimeError("Supabase client is not initialized")
return client
def _require_admin_client(self) -> Any:
admin_client = self._admin_client
if admin_client is None:
raise RuntimeError("Supabase admin client is not initialized")
return admin_client
supabase_service: SupabaseService = register_service_instance(
"supabase", SupabaseService()
)
__all__ = ["SupabaseService", "supabase_service"]
+69 -27
View File
@@ -1,57 +1,98 @@
from __future__ import annotations
from typing import Any, cast
import asyncio
from typing import Any
from uuid import UUID
from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.agent.infrastructure.queue.tasks import (
run_command_task,
run_command_task_bulk,
run_command_task_critical,
)
from core.config.settings import config
from core.db import get_db
from services.base.redis import redis_service
from services.base.redis import get_or_init_redis_client
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
DEDUP_WAIT_RETRIES = 20
DEDUP_WAIT_SECONDS = 0.05
DEDUP_LOCK_SECONDS = 300
DEDUP_INFLIGHT_MARKER = "__inflight__"
class CeleryQueueClient:
class TaskiqQueueClient:
def __init__(self) -> None:
self._redis = redis_service.get_client()
self._redis: Redis | None = None
async def _get_redis(self) -> Redis:
if self._redis is None:
self._redis = await get_or_init_redis_client()
return self._redis
@staticmethod
def _select_queue_task(command: dict[str, object]) -> Any:
queue = str(command.get("queue", "default")).strip().lower()
if queue == "critical":
return run_command_task_critical
if queue == "bulk":
return run_command_task_bulk
return run_command_task
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
redis_client = await self._get_redis()
redis_key = None
if dedup_key:
redis_key = f"agent:dedup:{dedup_key}"
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
locked = await redis_client.set(
redis_key,
DEDUP_INFLIGHT_MARKER,
nx=True,
ex=DEDUP_LOCK_SECONDS,
)
if not locked:
existing = await self._redis.get(redis_key)
if existing and existing != "__inflight__":
return existing
for _ in range(DEDUP_WAIT_RETRIES):
existing = await redis_client.get(redis_key)
if existing and existing != DEDUP_INFLIGHT_MARKER:
return existing
await asyncio.sleep(DEDUP_WAIT_SECONDS)
raise RuntimeError("duplicate request is still in progress")
payload = dict(command)
if dedup_key:
payload["dedup_key"] = dedup_key
delay = getattr(run_command_task, "delay")
result = delay(payload)
task_id = str(result.id)
if redis_key is not None:
await self._redis.set(redis_key, task_id, ex=300)
return task_id
queue_task = self._select_queue_task(payload)
try:
result = await queue_task.kiq(payload)
task_id = str(result.task_id)
if redis_key is not None:
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
return task_id
except Exception:
if redis_key is not None:
await redis_client.delete(redis_key)
raise
class RedisEventStream:
def __init__(self) -> None:
settings = cast(Any, config)
client = redis_service.get_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
)
self._store: RedisStreamEventStore | None = None
async def _get_store(self) -> RedisStreamEventStore:
if self._store is None:
client = await get_or_init_redis_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
return self._store
async def read(
self,
@@ -59,7 +100,8 @@ class RedisEventStream:
session_id: str,
last_event_id: str | None,
) -> list[dict[str, Any]]:
rows = await self._store.read_events(
store = await self._get_store()
rows = await store.read_events(
session_id=UUID(session_id),
last_event_id=last_event_id,
)
@@ -69,6 +111,6 @@ class RedisEventStream:
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
return AgentService(
repository=AgentRepository(session),
queue=CeleryQueueClient(),
queue=TaskiqQueueClient(),
stream=RedisEventStream(),
)
+28 -20
View File
@@ -5,10 +5,10 @@ from collections.abc import Mapping
from typing import Any, cast
from fastapi import HTTPException
from supabase import AuthError, create_client
from supabase import AuthError
from core.config.settings import SupabaseSettings, config
from core.logging import get_logger
from services.base.supabase import supabase_service
from v1.auth.schemas import (
AuthUser,
PasswordResetConfirmRequest,
@@ -28,17 +28,16 @@ logger = get_logger("v1.auth.gateway")
class SupabaseAuthGateway(AuthServiceGateway):
_client: Any
_admin_client: Any
def _get_client(self) -> Any:
return supabase_service.get_client()
def __init__(self) -> None:
settings: SupabaseSettings = config.supabase
self._client = create_client(settings.url, settings.anon_key)
self._admin_client = create_client(settings.url, settings.service_role_key)
def _get_admin_client(self) -> Any:
return supabase_service.get_admin_client()
async def create_verification(
self, request: VerificationCreateRequest
) -> VerificationCreateResponse:
client = self._get_client()
metadata: dict[str, Any] = {"username": request.username}
if request.invite_code:
metadata["invite_code"] = request.invite_code
@@ -50,7 +49,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
if request.redirect_to:
payload["options"] = {"email_redirect_to": request.redirect_to}
try:
sign_up = cast(Any, self._client.auth.sign_up)
sign_up = cast(Any, client.auth.sign_up)
await asyncio.to_thread(sign_up, payload)
return VerificationCreateResponse(email=request.email)
except AuthError as exc:
@@ -62,13 +61,14 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def verify_verification(
self, request: VerificationVerifyRequest
) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {
"type": "signup",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, payload)
return _map_auth_response(response, "Invalid verification code")
except AuthError as exc:
@@ -78,17 +78,19 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def resend_verification(self, request: VerificationResendRequest) -> None:
client = self._get_client()
payload: dict[str, Any] = {"type": "signup", "email": request.email}
try:
resend = cast(Any, self._client.auth.resend)
resend = cast(Any, client.auth.resend)
await asyncio.to_thread(resend, payload)
except AuthError as exc:
logger.warning("Signup resend failed", error_type=type(exc).__name__)
async def create_session(self, request: SessionCreateRequest) -> SessionResponse:
client = self._get_client()
payload: dict[str, Any] = {"email": request.email, "password": request.password}
try:
sign_in = cast(Any, self._client.auth.sign_in_with_password)
sign_in = cast(Any, client.auth.sign_in_with_password)
response = await asyncio.to_thread(sign_in, payload)
return _map_auth_response(response, "Invalid credentials")
except AuthError as exc:
@@ -96,9 +98,10 @@ class SupabaseAuthGateway(AuthServiceGateway):
raise HTTPException(status_code=401, detail="Invalid credentials") from exc
async def refresh_session(self, request: SessionRefreshRequest) -> SessionResponse:
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
request.refresh_token,
)
return _map_auth_response(response, "Invalid refresh token")
@@ -111,20 +114,21 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def delete_session(self, refresh_token: str | None) -> None:
if not refresh_token:
raise HTTPException(status_code=401, detail="Missing refresh token")
client = self._get_client()
try:
response = await asyncio.to_thread(
self._client.auth.refresh_session,
client.auth.refresh_session,
refresh_token,
)
session = getattr(response, "session", None)
if session is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")
await asyncio.to_thread(
self._client.auth.set_session,
client.auth.set_session,
str(session.access_token),
str(session.refresh_token),
)
await asyncio.to_thread(self._client.auth.sign_out)
await asyncio.to_thread(client.auth.sign_out)
except AuthError as exc:
logger.warning("Logout failed", error_type=type(exc).__name__)
raise HTTPException(
@@ -132,7 +136,8 @@ class SupabaseAuthGateway(AuthServiceGateway):
) from exc
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
users = await asyncio.to_thread(_list_auth_users, self._admin_client)
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
normalized_email = email.lower()
user = next(
(
@@ -157,8 +162,9 @@ class SupabaseAuthGateway(AuthServiceGateway):
)
async def request_password_reset(self, request: PasswordResetRequest) -> None:
client = self._get_client()
try:
reset_email = cast(Any, self._client.auth.reset_password_email)
reset_email = cast(Any, client.auth.reset_password_email)
email = _coerce_reset_email(request.email)
if request.redirect_to:
options: dict[str, str] = {"redirect_to": request.redirect_to}
@@ -174,13 +180,15 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def confirm_password_reset(
self, request: PasswordResetConfirmRequest
) -> None:
client = self._get_client()
admin_client = self._get_admin_client()
verify_payload: dict[str, Any] = {
"type": "recovery",
"email": request.email,
"token": request.token,
}
try:
verify_otp = cast(Any, self._client.auth.verify_otp)
verify_otp = cast(Any, client.auth.verify_otp)
response = await asyncio.to_thread(verify_otp, verify_payload)
session = getattr(response, "session", None)
user = getattr(response, "user", None)
@@ -190,7 +198,7 @@ class SupabaseAuthGateway(AuthServiceGateway):
status_code=401, detail="Invalid or expired verification code"
)
await asyncio.to_thread(
self._admin_client.auth.admin.update_user_by_id,
admin_client.auth.admin.update_user_by_id,
user_id,
{"password": request.new_password},
)
@@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from core.agent.infrastructure.queue.tasks import run_agent_task
from core.agent.infrastructure.queue.tasks import _build_redis_publisher, run_agent_task
class _FakeRunService:
@@ -67,3 +67,35 @@ async def test_run_agent_task_emits_error_event_on_exception() -> None:
)
assert events == ["RUN_STARTED", "RUN_ERROR"]
@pytest.mark.asyncio
async def test_run_agent_task_rejects_invalid_command() -> None:
with pytest.raises(ValueError, match="invalid command type"):
await run_agent_task({"command": "invalid", "session_id": "00000000-0000-0000-0000-000000000001"})
@pytest.mark.asyncio
async def test_run_agent_task_resume_requires_tool_call_id() -> None:
with pytest.raises(ValueError, match="tool_call_id is required"):
await run_agent_task(
{
"command": "resume",
"session_id": "00000000-0000-0000-0000-000000000001",
}
)
@pytest.mark.asyncio
async def test_build_redis_publisher_init_fail_raises_runtime_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from core.agent.infrastructure.queue import tasks
async def _fake_get_client() -> object:
raise RuntimeError("Redis service initialization failed")
monkeypatch.setattr(tasks, "get_or_init_redis_client", _fake_get_client)
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
await _build_redis_publisher()
@@ -0,0 +1,10 @@
from __future__ import annotations
from core.config.settings import Settings
def test_taskiq_uses_redis_url_by_default() -> None:
settings = Settings()
assert settings.taskiq_broker_url.startswith("redis://")
assert settings.taskiq_result_backend_url.startswith("redis://")
@@ -0,0 +1,37 @@
from __future__ import annotations
import importlib
import sys
import pytest
from core.taskiq.app import broker, bulk_broker, critical_broker, default_broker
def test_taskiq_broker_is_configured() -> None:
assert broker is not None
assert default_broker is broker
assert critical_broker is not None
assert bulk_broker is not None
def test_taskiq_app_configures_logging_on_import(
monkeypatch: pytest.MonkeyPatch,
) -> None:
sys.modules.pop("core.taskiq.app", None)
sys.modules.pop("core.taskiq", None)
called = {"count": 0, "args": None}
def _fake_configure_logging(*args: object, **__: object) -> None:
called["count"] += 1
called["args"] = args
monkeypatch.setattr("core.logging.configure_logging", _fake_configure_logging)
importlib.import_module("core.taskiq.app")
from core.config.settings import config
assert called["count"] == 1
assert called["args"] == (config,)
@@ -0,0 +1,20 @@
from __future__ import annotations
from pathlib import Path
ROOT_DIR = Path(__file__).resolve().parents[4]
APP_SCRIPT = ROOT_DIR / "infra" / "scripts" / "app.sh"
def test_worker_commands_use_taskiq() -> None:
content = APP_SCRIPT.read_text(encoding="utf-8")
removed_runner = "uv run c" "elery"
assert "uv run taskiq worker" in content
assert "core.taskiq.app:critical_broker" in content
assert "core.taskiq.app:default_broker" in content
assert "core.taskiq.app:bulk_broker" in content
assert 'pgrep -f "taskiq.*worker"' in content
assert 'pkill -f "taskiq.*worker"' in content
assert removed_runner not in content
@@ -3,7 +3,7 @@ from __future__ import annotations
import pytest
from core.config.settings import RedisSettings
from services.base.redis import RedisService
from services.base.redis import RedisService, get_or_init_redis_client, redis_service
class _FakeRedisClient:
@@ -96,3 +96,35 @@ def test_get_client_raises_before_init() -> None:
with pytest.raises(RuntimeError):
service.get_client()
@pytest.mark.asyncio
async def test_get_or_init_redis_client_initializes_when_needed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_client = _FakeRedisClient()
async def _fake_initialize() -> bool:
return True
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
monkeypatch.setattr(redis_service, "get_client", lambda: fake_client)
client = await get_or_init_redis_client()
assert client is fake_client
@pytest.mark.asyncio
async def test_get_or_init_redis_client_raises_when_init_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _fake_initialize() -> bool:
return False
monkeypatch.setattr(type(redis_service), "is_initialized", property(lambda _: False))
monkeypatch.setattr(redis_service, "initialize", _fake_initialize)
with pytest.raises(RuntimeError, match="Redis service initialization failed"):
await get_or_init_redis_client()
@@ -1,8 +1,12 @@
from __future__ import annotations
import pytest
from services.base.service_interface import (
BaseServiceProvider,
ServiceRegistry,
close_registered_services,
initialize_registered_services,
register_service,
register_service_instance,
)
@@ -35,6 +39,17 @@ def test_register_service_and_create_service() -> None:
assert created.get_service_info()["name"] == "dummy"
def test_register_service_and_get_service() -> None:
@register_service("dummy-service-get")
class _RegisteredService(_DummyService):
pass
resolved = ServiceRegistry.get_service("dummy-service-get")
assert resolved is not None
assert resolved.get_service_info()["name"] == "dummy"
def test_register_service_instance_returns_same_instance() -> None:
instance = _DummyService("singleton")
@@ -47,3 +62,77 @@ def test_register_service_instance_returns_same_instance() -> None:
def test_create_service_returns_none_for_missing() -> None:
assert ServiceRegistry.create_service("missing-service") is None
def test_get_service_returns_none_for_missing() -> None:
assert ServiceRegistry.get_service("missing-service") is None
class _LifecycleService(BaseServiceProvider):
def __init__(self, name: str, recorder: list[str], fail_on_init: bool = False) -> None:
super().__init__(name)
self._recorder = recorder
self._fail_on_init = fail_on_init
async def initialize(self, **_: object) -> bool:
self._recorder.append(f"init:{self.service_name}")
if self._fail_on_init:
return False
self._set_initialized(True)
return True
async def close(self) -> bool:
self._recorder.append(f"close:{self.service_name}")
self._set_initialized(False)
return True
async def health_check(self) -> dict[str, object]:
return {"status": "healthy", "details": {}}
@pytest.mark.asyncio
async def test_initialize_registered_services_success() -> None:
recorder: list[str] = []
first = register_service_instance(
"lifecycle-success-first", _LifecycleService("first", recorder)
)
second = register_service_instance(
"lifecycle-success-second", _LifecycleService("second", recorder)
)
initialized, services = await initialize_registered_services(
["lifecycle-success-first", "lifecycle-success-second"]
)
assert initialized is True
assert services == [first, second]
assert recorder == ["init:first", "init:second"]
@pytest.mark.asyncio
async def test_initialize_registered_services_failure_rolls_back() -> None:
recorder: list[str] = []
register_service_instance("lifecycle-fail-first", _LifecycleService("first", recorder))
register_service_instance(
"lifecycle-fail-second", _LifecycleService("second", recorder, fail_on_init=True)
)
initialized, services = await initialize_registered_services(
["lifecycle-fail-first", "lifecycle-fail-second"]
)
assert initialized is False
assert services == []
assert recorder == ["init:first", "init:second", "close:first"]
@pytest.mark.asyncio
async def test_close_registered_services_closes_in_reverse_order() -> None:
recorder: list[str] = []
first = _LifecycleService("first", recorder)
second = _LifecycleService("second", recorder)
closed = await close_registered_services([first, second])
assert closed is True
assert recorder == ["close:second", "close:first"]
@@ -0,0 +1,111 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.config.settings import SupabaseSettings
from services.base.supabase import SupabaseService
@pytest.mark.asyncio
async def test_initialize_success(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
anon_client = MagicMock()
admin_client = MagicMock()
create_calls: list[tuple[str, str]] = []
def _fake_create_client(url: str, key: str) -> object:
create_calls.append((url, key))
return anon_client if len(create_calls) == 1 else admin_client
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
result = await service.initialize()
assert result is True
assert service.is_initialized is True
assert service.get_client() is anon_client
assert service.get_admin_client() is admin_client
assert len(create_calls) == 2
@pytest.mark.asyncio
async def test_initialize_failure(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
def _fake_create_client(_: str, __: str) -> object:
raise RuntimeError("boom")
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
result = await service.initialize()
assert result is False
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
@pytest.mark.asyncio
async def test_close_clears_clients(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
def _fake_create_client(_: str, __: str) -> object:
return MagicMock()
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
assert await service.initialize() is True
assert await service.close() is True
assert service.is_initialized is False
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
@pytest.mark.asyncio
async def test_health_check_uninitialized() -> None:
service = SupabaseService(settings=SupabaseSettings())
health = await service.health_check()
assert health["status"] == "unhealthy"
@pytest.mark.asyncio
async def test_health_check_initialized(monkeypatch: pytest.MonkeyPatch) -> None:
service = SupabaseService(settings=SupabaseSettings())
anon_client = MagicMock()
anon_client.auth.get_session = MagicMock(return_value=None)
admin_list_users = MagicMock(return_value=SimpleNamespace(users=[]))
admin_client = MagicMock()
admin_client.auth.admin = SimpleNamespace(list_users=admin_list_users)
create_sequence = [anon_client, admin_client]
def _fake_create_client(_: str, __: str) -> object:
return create_sequence.pop(0)
monkeypatch.setattr("services.base.supabase.create_client", _fake_create_client)
assert await service.initialize() is True
health = await service.health_check()
assert health["status"] == "healthy"
admin_list_users.assert_called_once_with(page=1, per_page=1)
def test_get_client_raises_before_init() -> None:
service = SupabaseService(settings=SupabaseSettings())
with pytest.raises(RuntimeError):
service.get_client()
with pytest.raises(RuntimeError):
service.get_admin_client()
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import pytest
import app as app_module
@pytest.mark.asyncio
async def test_lifespan_uses_registered_services(monkeypatch: pytest.MonkeyPatch) -> None:
initialized_services = [object(), object()]
calls: dict[str, object] = {}
async def _fake_initialize(service_names: list[str]) -> tuple[bool, list[object]]:
calls["init_names"] = service_names
return True, initialized_services
async def _fake_close(services: list[object]) -> bool:
calls["close_services"] = services
return True
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
monkeypatch.setattr(app_module, "close_registered_services", _fake_close)
context = app_module.lifespan(app_module.app)
await context.__aenter__()
await context.__aexit__(None, None, None)
assert calls["init_names"] == ["redis", "supabase"]
assert calls["close_services"] == initialized_services
@pytest.mark.asyncio
async def test_lifespan_raises_when_initialization_failed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _fake_initialize(_: list[str]) -> tuple[bool, list[object]]:
return False, []
monkeypatch.setattr(app_module, "initialize_registered_services", _fake_initialize)
context = app_module.lifespan(app_module.app)
with pytest.raises(RuntimeError, match="Service initialization failed"):
await context.__aenter__()
-45
View File
@@ -1,45 +0,0 @@
from __future__ import annotations
from celery import Celery
from pytest import MonkeyPatch
from core.logging import celery as celery_logging
from core.logging.context import clear_context, get_context
class DummyTask:
name: str = "tasks.sample"
def test_celery_prerun_binds_task_context() -> None:
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_task_prerun(task_id="task-123", task=DummyTask())
context = get_context()
assert context["task_id"] == "task-123"
assert context["task_name"] == "tasks.sample"
clear_context()
def test_celery_setup_logging_calls_configure(monkeypatch: MonkeyPatch) -> None:
called = {"value": False}
def fake_configure_logging(settings: object | None = None) -> None:
called["value"] = True
monkeypatch.setattr(celery_logging, "configure_logging", fake_configure_logging)
handlers = celery_logging.build_celery_signal_handlers()
handlers.on_setup_logging()
assert called["value"] is True
def test_configure_celery_app_disables_hijack() -> None:
app = Celery("test")
celery_logging.configure_celery_app(app)
assert app.conf.worker_hijack_root_logger is False
@@ -0,0 +1,227 @@
from __future__ import annotations
import pytest
from v1.agent.dependencies import TaskiqQueueClient
class _FakeRedis:
def __init__(self) -> None:
self.store: dict[str, str] = {}
self.delete_calls: list[str] = []
async def set(
self,
key: str,
value: str,
*,
nx: bool = False,
ex: int | None = None,
) -> bool:
del ex
if nx and key in self.store:
return False
self.store[key] = value
return True
async def get(self, key: str) -> str | None:
return self.store.get(key)
async def delete(self, key: str) -> int:
self.delete_calls.append(key)
existed = 1 if key in self.store else 0
self.store.pop(key, None)
return existed
class _FakeAsyncResult:
def __init__(self, task_id: str) -> None:
self.task_id = task_id
@pytest.mark.asyncio
async def test_enqueue_returns_task_id(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
assert payload["command"] == "run"
return _FakeAsyncResult("task-123")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
assert resolved_client["value"] is True
assert task_id == "task-123"
@pytest.mark.asyncio
async def test_enqueue_resume_dedup_returns_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
resolved_client = {"value": False}
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
return _FakeAsyncResult("new-task-id")
async def _fake_get_or_init_client() -> _FakeRedis:
resolved_client["value"] = True
return fake_redis
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
dedup_key = "resume:session-1:call-1"
fake_redis.store[f"agent:dedup:{dedup_key}"] = "existing-task-id"
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert resolved_client["value"] is True
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_inflight_dedup_waits_and_reuses_existing_task_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
fake_redis.store[redis_key] = deps.DEDUP_INFLIGHT_MARKER
attempts = {"count": 0}
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_get(key: str) -> str | None:
attempts["count"] += 1
if attempts["count"] > 1:
fake_redis.store[key] = "existing-task-id"
return fake_redis.store.get(key)
async def _fake_sleep(_: float) -> None:
return None
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise AssertionError("should not enqueue when dedup task id appears")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(fake_redis, "get", _fake_get)
monkeypatch.setattr(deps.asyncio, "sleep", _fake_sleep)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert task_id == "existing-task-id"
@pytest.mark.asyncio
async def test_enqueue_failure_cleans_dedup_lock(monkeypatch: pytest.MonkeyPatch) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
dedup_key = "resume:session-1:call-1"
redis_key = f"agent:dedup:{dedup_key}"
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_kiq(payload: dict[str, object]) -> _FakeAsyncResult:
del payload
raise RuntimeError("enqueue failed")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_kiq)
client = TaskiqQueueClient()
with pytest.raises(RuntimeError, match="enqueue failed"):
await client.enqueue(
command={"command": "resume", "session_id": "session-1", "tool_call_id": "call-1"},
dedup_key=dedup_key,
)
assert redis_key in fake_redis.delete_calls
@pytest.mark.asyncio
async def test_enqueue_uses_critical_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_critical_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("critical-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_critical, "kiq", _fake_critical_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "critical"},
dedup_key=None,
)
assert task_id == "critical-task-id"
@pytest.mark.asyncio
async def test_enqueue_uses_bulk_queue_when_requested(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from v1.agent import dependencies as deps
fake_redis = _FakeRedis()
async def _fake_get_or_init_client() -> _FakeRedis:
return fake_redis
async def _fake_default_kiq(_: dict[str, object]) -> _FakeAsyncResult:
raise AssertionError("default queue should not be selected")
async def _fake_bulk_kiq(_: dict[str, object]) -> _FakeAsyncResult:
return _FakeAsyncResult("bulk-task-id")
monkeypatch.setattr(deps, "get_or_init_redis_client", _fake_get_or_init_client)
monkeypatch.setattr(deps.run_command_task, "kiq", _fake_default_kiq)
monkeypatch.setattr(deps.run_command_task_bulk, "kiq", _fake_bulk_kiq)
client = TaskiqQueueClient()
task_id = await client.enqueue(
command={"command": "run", "queue": "bulk"},
dedup_key=None,
)
assert task_id == "bulk-task-id"
+40 -28
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
@@ -12,37 +12,44 @@ from v1.auth.schemas import PasswordResetConfirmRequest, PasswordResetRequest
class TestSupabaseAuthGateway:
@pytest.fixture
def gateway(self) -> SupabaseAuthGateway:
with patch("v1.auth.gateway.create_client") as mock_create:
mock_client = MagicMock()
mock_admin_client = MagicMock()
mock_create.side_effect = [mock_client, mock_admin_client]
return SupabaseAuthGateway()
def gateway(
self, monkeypatch: pytest.MonkeyPatch
) -> tuple[SupabaseAuthGateway, MagicMock, MagicMock]:
mock_client = MagicMock()
mock_admin_client = MagicMock()
monkeypatch.setattr("v1.auth.gateway.supabase_service.get_client", lambda: mock_client)
monkeypatch.setattr(
"v1.auth.gateway.supabase_service.get_admin_client",
lambda: mock_admin_client,
)
return SupabaseAuthGateway(), mock_client, mock_admin_client
@pytest.mark.asyncio
async def test_request_password_reset_calls_email_with_string(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_with_redirect(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(
email="test@example.com",
redirect_to="http://localhost:3000/reset-password",
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with(
"test@example.com",
@@ -51,64 +58,68 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_request_password_reset_swallows_auth_error(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
from supabase import AuthError
mock_reset_email = MagicMock(side_effect=AuthError("rate limit exceeded", None))
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest(email="test@example.com")
result = await gateway.request_password_reset(request)
result = await sut.request_password_reset(request)
mock_reset_email.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_request_password_reset_extracts_email_from_mapping(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
mock_reset_email = MagicMock()
gateway._client.auth.reset_password_email = mock_reset_email
mock_client.auth.reset_password_email = mock_reset_email
request = PasswordResetRequest.model_construct(
email={"email": "test@example.com"},
redirect_to=None,
)
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
mock_reset_email.assert_called_once_with("test@example.com")
@pytest.mark.asyncio
async def test_request_password_reset_rejects_invalid_email_shape(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, _, _ = gateway
request = PasswordResetRequest.model_construct(
email={"unexpected": "value"},
redirect_to=None,
)
with pytest.raises(HTTPException) as exc_info:
await gateway.request_password_reset(request)
await sut.request_password_reset(request)
assert exc_info.value.status_code == 422
assert exc_info.value.detail == "Invalid email"
@pytest.mark.asyncio
async def test_confirm_password_reset_updates_password_by_user_id(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, mock_admin_client = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id="user-1"),
)
mock_verify_otp = MagicMock(return_value=verify_response)
gateway._client.auth.verify_otp = mock_verify_otp
mock_client.auth.verify_otp = mock_verify_otp
mock_update_user_by_id = MagicMock()
gateway._admin_client.auth.admin = SimpleNamespace(
mock_admin_client.auth.admin = SimpleNamespace(
update_user_by_id=mock_update_user_by_id
)
@@ -118,7 +129,7 @@ class TestSupabaseAuthGateway:
new_password="newpassword123",
)
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
mock_verify_otp.assert_called_once_with(
{
@@ -134,13 +145,14 @@ class TestSupabaseAuthGateway:
@pytest.mark.asyncio
async def test_confirm_password_reset_raises_when_user_id_missing(
self, gateway: SupabaseAuthGateway
self, gateway: tuple[SupabaseAuthGateway, MagicMock, MagicMock]
) -> None:
sut, mock_client, _ = gateway
verify_response = SimpleNamespace(
session=SimpleNamespace(access_token="access"),
user=SimpleNamespace(id=""),
)
gateway._client.auth.verify_otp = MagicMock(return_value=verify_response)
mock_client.auth.verify_otp = MagicMock(return_value=verify_response)
request = PasswordResetConfirmRequest(
email="test@example.com",
@@ -149,7 +161,7 @@ class TestSupabaseAuthGateway:
)
with pytest.raises(HTTPException) as exc_info:
await gateway.confirm_password_reset(request)
await sut.confirm_password_reset(request)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid or expired verification code"
@@ -0,0 +1,260 @@
# Supabase 统一服务生命周期设计(优化版)
**Date:** 2026-03-06
**Status:** Draft
---
## 0. Intake Contract
- Objective: 将 Supabase 客户端纳入统一服务生命周期管理,避免每次请求重复创建客户端。
- Deliverable: 新增 `SupabaseService`,并基于 `service_interface.py``ServiceRegistry` 提供统一初始化/关闭路径,完成 auth 侧迁移。
- Constraints:
- 保持现有 `core.config.settings` 的配置读取行为不变。
- 不引入 `os.environ` 直接读取。
- 不改变现有 API 语义。
- Verification target:
- 通过单元测试证明 Supabase 服务初始化、关闭、健康检查行为。
- 通过应用启动测试证明统一初始化流程可用。
- 通过 auth 相关测试证明迁移后业务行为一致。
---
## 1. 复杂度与风险分级
- Complexity: `S2`
- 原因:涉及多文件改造(`services/base``app.py``v1/auth`、测试)。
- Risk Tier: `L1`
- 原因:涉及应用启动链路和认证网关依赖,但不改变对外接口契约。
L1 Gate 要求:执行 `refactor-cleaner` 审视冗余与结构风险(`code-reviewer` 可选)。
---
## 2. 现状与问题
### 2.1 当前现状
- `SupabaseAuthGateway``__init__` 内直接 `create_client(...)`,每次实例化都会创建 anon/admin 客户端。
- `get_auth_service()` 当前每次请求都会 new `SupabaseAuthGateway()`,导致客户端重复构造。
- `ServiceRegistry` 已存在,但目前主要用于注册,应用启动仍是手写逐个初始化。
### 2.2 核心问题
1. 生命周期不统一:Supabase 没有接入应用启动/关闭的统一管理。
2. 初始化代码重复趋势:服务增多后,`app.py` 的 lifespan 会继续膨胀。
3. 网关构造时机风险:若在应用未初始化阶段取客户端,可能抛运行时异常。
---
## 3. 优化设计(推荐方案)
### 3.1 方案摘要
`service_interface.py` 基础上新增统一生命周期函数,按服务名列表批量初始化/关闭;`app.py` 仅声明服务顺序,减少样板代码。Supabase 使用 `config.supabase` 作为默认配置来源,保持 settings 行为一致。
### 3.2 目标文件结构
```text
backend/src/services/base/
├── __init__.py
├── service_interface.py # 扩展:统一生命周期函数
├── redis.py
└── supabase.py # 新增
```
### 3.3 service_interface 统一初始化能力(新增)
`service_interface.py` 新增以下函数(建议命名):
- `resolve_registered_services(service_names: list[str]) -> list[BaseServiceProvider]`
- `initialize_registered_services(service_names: list[str]) -> tuple[bool, list[BaseServiceProvider]]`
- `close_registered_services(services: list[BaseServiceProvider]) -> bool`
约束与行为:
1. 初始化按 `service_names` 顺序执行。
2. 任一服务初始化失败时:
- 返回 `False`
- 对已成功初始化的服务按逆序执行关闭回滚。
3. 关闭按逆序执行,最大化依赖安全性。
4. 日志必须包含失败服务名和错误摘要。
这样 `app.py` 只需声明:
```python
SERVICE_STARTUP_ORDER = ["redis", "supabase"]
```
并调用统一函数,减少重复初始化样板。
### 3.4 SupabaseService 设计
`supabase.py` 关键点:
- 继承 `BaseServiceProvider`
- 构造函数签名:
- `def __init__(self, settings: SupabaseSettings | None = None) -> None`
- 默认 `settings or config.supabase`,确保与当前配置源一致。
- `initialize()`:创建 anon/admin 两个 client,失败返回 `False`
- `close()`
- 清空 `_client``_admin_client`
- `self._set_initialized(False)`
- `health_check()`
- 必须进行至少一个轻量真实请求验证,不仅检查本地对象存在。
- 返回结构与 `RedisService.health_check()`风格一致(`status + details`)。
注册方式:
```python
supabase_service: SupabaseService = register_service_instance(
"supabase", SupabaseService()
)
```
### 3.5 app.py 改造
当前手写 `redis_service.initialize()` 改为调用统一初始化函数。
目标行为:
1. 启动阶段:
- 调用 `initialize_registered_services(["redis", "supabase"])`
- 失败则 `raise RuntimeError("Service initialization failed")`
2. 关闭阶段:
- 调用 `close_registered_services(initialized_services)`
### 3.6 AuthGateway 迁移策略(避免构造时机问题)
不建议在 `SupabaseAuthGateway.__init__` 里立即绑定 client;改为按需获取:
- 保留网关对象轻量化。
- 在每个业务方法内部通过 `supabase_service.get_client()` / `get_admin_client()` 取实例。
优点:
1. 避免模块导入或依赖构建阶段误触未初始化 client。
2.`users/dependencies.py` 中全局缓存 gateway 的场景更安全。
3. 不改变业务层接口。
---
## 4. 配置与兼容性保证
### 4.1 settings/config 行为不变
迁移后依然通过 `core.config.settings.config.supabase` 读取:
- `url`
- `anon_key`
- `service_role_key`
- `jwt_secret`JWT 校验现有逻辑继续使用)
### 4.2 环境变量兼容
由于 `Settings` + `env_nested_delimiter` 机制不变,现有环境变量命名与 `.env` 内容无需修改。
### 4.3 对现有代码影响
- API 层 schema/路由不变。
- 认证行为不变。
- 仅优化客户端生命周期与启动流程。
---
## 5. 实施计划(可执行)
### Task 1: 扩展统一生命周期接口
**Files**
- Modify: `backend/src/services/base/service_interface.py`
- Test: `backend/tests/unit/services/base/test_service_interface.py`(新增)
**Steps**
1. 写失败测试:初始化顺序、失败回滚、关闭逆序。
2. 实现生命周期函数。
3. 跑单测确认通过。
### Task 2: 新增 SupabaseService
**Files**
- Create: `backend/src/services/base/supabase.py`
- Modify: `backend/src/services/base/__init__.py`
- Test: `backend/tests/unit/services/base/test_supabase.py`
**Steps**
1. 写失败测试(init success/fail、close、health_check)。
2. 实现 `SupabaseService` 与实例注册。
3. 跑单测。
### Task 3: 接入 app lifespan 统一初始化
**Files**
- Modify: `backend/src/app.py`
- Test: `backend/tests/integration/test_app_lifespan.py`(新增或扩展)
**Steps**
1. 写失败测试(supabase init fail 时应用启动失败)。
2. 替换手写初始化为统一函数。
3. 跑集成测试。
### Task 4: 迁移 AuthGateway 获取 client 方式
**Files**
- Modify: `backend/src/v1/auth/gateway.py`
- Optional Modify: `backend/src/v1/auth/dependencies.py`
- Optional Modify: `backend/src/v1/users/dependencies.py`
- Test: `backend/tests/unit/v1/auth/test_gateway.py`(扩展)
**Steps**
1. 写失败测试(未初始化时错误、初始化后正常调用)。
2. 改为方法内按需取 client。
3. 跑 auth 相关单测。
### Task 5: 全量验证与门禁
**Commands**
- `uv run ruff check backend/src backend/tests`
- `uv run basedpyright`
- `uv run pytest backend/tests/unit/services/base -q`
- `uv run pytest backend/tests/unit/v1/auth -q`
- `uv run pytest backend/tests/integration -q`
输出要求:记录每条命令 pass/fail 与关键摘要。
---
## 6. 验收标准(更新)
- [ ] `SupabaseService` 继承 `BaseServiceProvider` 并注册到 `ServiceRegistry`
- [ ] `service_interface.py` 提供统一初始化/关闭函数
- [ ] `app.py` 通过统一函数初始化 `redis + supabase`
- [ ] Supabase 配置读取仍仅来自 `core.config.settings.config`
- [ ] `auth/gateway.py` 不再在 `__init__` 新建客户端
- [ ] 初始化失败具备回滚关闭逻辑
- [ ] 单元/集成测试覆盖核心迁移路径并通过
---
## 7. 风险与缓解
| 风险 | 级别 | 缓解 |
|---|---|---|
| 统一初始化函数引入顺序错误 | 中 | 显式 `SERVICE_STARTUP_ORDER` + 顺序测试 |
| Supabase 健康检查误报 | 中 | 使用真实轻量请求,不只做对象检查 |
| gateway 与生命周期耦合导致运行时错误 | 中 | 改为方法内按需取 client,并覆盖未初始化测试 |
| 迁移影响现有 auth 行为 | 中 | 保持 service 接口不变,补充回归测试 |
---
## 8. 完成定义(Completion Contract
1. Complexity: `S2`
2. Risk Tier: `L1`
3. Gates:
- 必需:`refactor-cleaner`
- 可选:`code-reviewer`(建议在合并前执行)
4. Verification evidence:
- 提供 lint/typecheck/unit/integration 命令结果
5. Remaining risks/follow-ups:
- 若后续新增第三方服务,沿用 `ServiceRegistry + 统一生命周期函数` 接入,不再在 `app.py` 手写初始化。
+359
View File
@@ -0,0 +1,359 @@
# Celery To Taskiq One-Shot Migration Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 在当前早期项目中一次性移除 Celery,并以 Taskiq 替换异步任务基础设施,保持 agent runtime 行为不变。
**Architecture:** 复用现有 `AgentService -> QueueClientLike` 抽象,仅替换基础设施层实现(任务声明、入队调用、worker 启动、配置与依赖)。保持 Redis 作为 broker/result 存储与事件流通道,避免改动业务服务层语义。
**Tech Stack:** FastAPI, Taskiq, taskiq-redis, Redis, pytest, uv
---
### Task 1: 依赖与配置切换(先 RED 后 GREEN)
**Files:**
- Modify: `pyproject.toml`
- Modify: `backend/src/core/config/settings.py`
- Test: `backend/tests/unit/core/config/test_taskiq_settings.py` (new)
**Step 1: Write the failing test**
```python
from core.config.settings import Settings
def test_taskiq_uses_redis_url_by_default() -> None:
settings = Settings()
assert settings.taskiq_broker_url.startswith("redis://")
def test_taskiq_queue_default_value() -> None:
settings = Settings()
assert settings.taskiq.default_queue == "default"
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/core/config/test_taskiq_settings.py -v`
Expected: FAIL`taskiq_broker_url` / `taskiq` 字段不存在)
**Step 3: Write minimal implementation**
```python
class TaskiqSettings(BaseModel):
broker_url: str | None = None
result_backend_url: str | None = None
default_queue: str = "default"
class Settings(BaseSettings):
taskiq: TaskiqSettings = TaskiqSettings()
@computed_field
@property
def taskiq_broker_url(self) -> str:
return self.taskiq.broker_url or self.redis.url
@computed_field
@property
def taskiq_result_backend_url(self) -> str:
return self.taskiq.result_backend_url or self.redis.url
```
`pyproject.toml` 同步变更:
- 删除 `celery>=...`
- 增加 `taskiq>=...`
- 增加 `taskiq-redis>=...`
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/core/config/test_taskiq_settings.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add pyproject.toml backend/src/core/config/settings.py backend/tests/unit/core/config/test_taskiq_settings.py
git commit -m "refactor(queue): replace celery config with taskiq settings"
```
### Task 2: 新建 Taskiq broker 与 worker 启动入口
**Files:**
- Create: `backend/src/core/taskiq/app.py`
- Create: `backend/tests/unit/core/taskiq/test_app.py`
- Delete: `backend/src/core/celery/app.py`
**Step 1: Write the failing test**
```python
from core.taskiq.app import broker
def test_taskiq_broker_is_configured() -> None:
assert broker is not None
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/core/taskiq/test_app.py -v`
Expected: FAIL(模块不存在)
**Step 3: Write minimal implementation**
```python
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
from core.config.settings import config
broker = ListQueueBroker(url=config.taskiq_broker_url).with_result_backend(
RedisAsyncResultBackend(redis_url=config.taskiq_result_backend_url)
)
```
说明:若当前 `taskiq-redis` 版本 API 名称有差异,以该版本官方 API 为准做等价实现。
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/core/taskiq/test_app.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/core/taskiq/app.py backend/tests/unit/core/taskiq/test_app.py backend/src/core/celery/app.py
git commit -m "feat(queue): add taskiq broker app and remove celery app"
```
### Task 3: 迁移任务定义(Celery task -> Taskiq task
**Files:**
- Modify: `backend/src/core/agent/infrastructure/queue/tasks.py`
- Test: `backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py` (new)
**Step 1: Write the failing test**
```python
from core.agent.infrastructure.queue.tasks import run_agent_task
async def test_run_agent_task_invalid_command_raises() -> None:
try:
await run_agent_task({"command": "unknown", "session_id": "00000000-0000-0000-0000-000000000001"})
raise AssertionError("expected ValueError")
except ValueError as exc:
assert "invalid command type" in str(exc)
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py -v`
Expected: FAIL(测试文件不存在或导入失败)
**Step 3: Write minimal implementation**
```python
from core.taskiq.app import broker
@broker.task(task_name="tasks.agent.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
return await run_agent_task(command)
```
并移除:
- `from core.celery.app import celery_app`
- `@celery_app.task(...)`
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/core/agent/infrastructure/queue/tasks.py backend/tests/unit/core/agent/infrastructure/queue/test_tasks.py
git commit -m "refactor(agent): migrate run command task to taskiq"
```
### Task 4: 迁移 API 入队客户端(.delay -> .kiq
**Files:**
- Modify: `backend/src/v1/agent/dependencies.py`
- Test: `backend/tests/unit/v1/agent/test_dependencies_queue.py` (new)
**Step 1: Write the failing test**
```python
class _FakeTask:
async def kiq(self, payload: dict[str, object]):
class _Result:
task_id = "task-123"
return _Result()
async def test_enqueue_returns_task_id(monkeypatch):
from v1.agent.dependencies import CeleryQueueClient
client = CeleryQueueClient() # 迁移后应重命名为 TaskiqQueueClient
monkeypatch.setattr("v1.agent.dependencies.run_command_task", _FakeTask())
task_id = await client.enqueue(command={"command": "run"}, dedup_key=None)
assert task_id == "task-123"
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/v1/agent/test_dependencies_queue.py -v`
Expected: FAIL(类型/方法不匹配)
**Step 3: Write minimal implementation**
```python
class TaskiqQueueClient:
async def enqueue(self, *, command: dict[str, object], dedup_key: str | None) -> str:
payload = dict(command)
if dedup_key:
payload["dedup_key"] = dedup_key
result = await run_command_task.kiq(payload)
task_id = str(result.task_id)
return task_id
```
并替换 DI
```python
queue=TaskiqQueueClient()
```
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/v1/agent/test_dependencies_queue.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/agent/dependencies.py backend/tests/unit/v1/agent/test_dependencies_queue.py
git commit -m "refactor(api): switch agent enqueue client from celery to taskiq"
```
### Task 5: 运维脚本与日志测试清理(一次性删除 Celery)
**Files:**
- Modify: `infra/scripts/app.sh`
- Delete: `backend/tests/unit/test_celery_logging.py`
- Modify/Create: `backend/tests/unit/core/logging/test_taskiq_logging.py` (if taskiq logging hook implemented)
- Modify: `backend/src/core/logging/__init__.py`(移除 celery logging export
**Step 1: Write the failing test**
```python
def test_worker_command_uses_taskiq() -> None:
content = Path("infra/scripts/app.sh").read_text()
assert "uv run taskiq worker" in content
assert "uv run celery" not in content
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/core/logging/test_taskiq_logging.py -v`
Expected: FAIL(脚本仍含 celery
**Step 3: Write minimal implementation**
`infra/scripts/app.sh` worker 命令替换为 Taskiq worker,例如:
```bash
uv run taskiq worker core.taskiq.app:broker core.agent.infrastructure.queue.tasks
```
删除所有 celery 进程清理匹配:
```bash
pgrep -f "taskiq.*worker"
pkill -f "taskiq.*worker"
```
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/core/logging/test_taskiq_logging.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add infra/scripts/app.sh backend/src/core/logging/__init__.py backend/tests/unit/core/logging/test_taskiq_logging.py backend/tests/unit/test_celery_logging.py
git commit -m "chore(infra): replace celery worker scripts and remove celery-specific tests"
```
### Task 6: 全量引用清理与回归验证
**Files:**
- Modify: `docs/runtime/runtime-runbook.md`
- Modify: 其他引用 Celery 的运行文档(按 `rg` 结果逐个更新)
**Step 1: Write the failing test**
```python
# 用命令断言替代代码测试
# rg -n "celery" backend/src infra/scripts docs/runtime pyproject.toml
```
**Step 2: Run check to verify it fails**
Run: `rg -n "celery" backend/src infra/scripts docs/runtime pyproject.toml`
Expected: 仍有旧引用
**Step 3: Write minimal implementation**
- 删除/替换剩余 Celery 代码、文档、配置。
- 保留历史变更记录中的 Celery 字样(如 bugs 归档)可接受,但运行路径必须为 0 引用。
**Step 4: Run verification suite**
Run:
- `uv run pytest backend/tests/unit -q`
- `uv run pytest backend/tests/integration -q`
- `uv run pytest backend/tests/e2e -q`(如环境不满足,记录原因)
- `uv run ruff check backend/src backend/tests`
- `uv run basedpyright`
- `rg -n "celery" backend/src infra/scripts pyproject.toml`
Expected:
- 测试与静态检查通过
- 运行路径无 Celery 引用
**Step 5: Commit**
```bash
git add docs/runtime/runtime-runbook.md pyproject.toml backend/src infra/scripts backend/tests
git commit -m "refactor(queue): complete one-shot migration from celery to taskiq"
```
### Task 7: L1 Review Gates 与交付确认
**Files:**
- No code changes required by default
**Step 1: Run required L1 gate (`refactor-cleaner`)**
Run: 使用 `refactor-cleaner` 审查迁移后冗余代码、死引用、命名一致性。
Expected: 无阻断问题。
**Step 2: Optional `code-reviewer` (recommended for infra switch)**
Run: 使用 `code-reviewer` 聚焦任务丢失、重复消费、幂等锁逻辑。
Expected: 无 CRITICAL/HIGH 问题。
**Step 3: Final evidence report**
输出内容必须包含:
- 执行命令列表
- 每条命令 PASS/FAIL
- 若有无法执行项(如 e2e 环境),给出原因与人工验证步骤
**Step 4: Commit review notes (optional)**
```bash
git add docs/plans/2026-03-06-taskiq-migration.md
git commit -m "docs(plan): taskiq one-shot migration execution checklist"
```
-144
View File
@@ -1,144 +0,0 @@
# Agent LLM Config Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:**`system_agents.config` 中的 `temperature` / `max_tokens` 以受约束方式加载到运行时,并在调用 LiteLLM 时按需透传。
**Architecture:** 在应用层 `RunService` 读取模型选择时同步读取并校验 `SystemAgents.config`;将校验后的 `SystemAgentLLMConfig` 传入 `CrewAIRuntime`;由 runtime 将配置转交给 LiteLLM clientclient 仅在值非 `None` 时向 `completion()` 传参,避免不必要的 provider 兼容风险。
**Tech Stack:** FastAPI, SQLAlchemy (async), Pydantic v2, LiteLLM, pytest
---
## 背景与修正点
- 当前真实调用链为:`RunService._load_agent_model_selection()` -> `create_runtime()` -> `CrewAIRuntime.execute()` -> `run_completion()`,并非 `load_stage_models()`
- `SystemAgentLLMConfig` 已存在:`backend/src/core/agent/domain/system_agent_config.py`
- `system_agents.config` 目前在初始化 YAML 侧有约束,但运行时 DB 读取仍需二次校验,防止脏数据绕过。
## 规则约束
- 严格 TDD:先写失败测试,再做实现。
- Python 命令统一使用 `uv run ...`
- 仅做增量改动,不回滚或覆盖与本任务无关的已有变更。
## 字段映射与透传策略
| 配置字段 | LiteLLM 参数 | 规则 |
|---|---|---|
| `temperature` | `temperature` | `None` 不透传;非空直接透传 |
| `max_tokens` | `max_tokens` | `None` 不透传;非空直接透传 |
---
### Task 1: 应用层加载并校验 Agent LLM Config
**Files:**
- Modify: `backend/src/core/agent/application/run_service.py`
- Test: `backend/tests/unit/core/agent/test_run_resume_service.py`
**Step 1: 写失败测试(RED**
新增单测覆盖以下行为:
1. `_load_agent_model_selection()` 返回三元组:`(model_code, provider_name, llm_config)`
2. 当 DB `config``{}` 时,`llm_config.temperature/max_tokens``None`
3. 当 DB `config` 含非法值(如 `temperature=3`)时抛 `ValueError`
**Step 2: 运行测试确认失败**
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
Expected: 新增断言失败(返回值结构/异常行为不匹配)。
**Step 3: 最小实现(GREEN**
`run_service.py`
1. 查询 `SystemAgents.config`
2.`SystemAgentLLMConfig.model_validate(config or {})` 校验。
3.`_load_agent_model_selection()` 改为返回三元组。
4.`run()` 中把 `llm_config` 传递到 `create_runtime(...)`
**Step 4: 运行测试确认通过**
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
Expected: PASS。
---
### Task 2: Runtime 与 LiteLLM Client 支持可选参数透传
**Files:**
- Modify: `backend/src/core/agent/infrastructure/crewai/factory.py`
- Modify: `backend/src/core/agent/infrastructure/crewai/runtime.py`
- Modify: `backend/src/core/agent/infrastructure/litellm/client.py`
- Test: `backend/tests/unit/core/agent/test_crewai_runtime.py`
**Step 1: 写失败测试(RED**
`test_crewai_runtime.py` 增加用例:
1. 传入 `temperature/max_tokens` 时,`run_completion` 收到对应参数。
2. 参数为 `None` 时,不应被透传到 LiteLLM。
必要时新增 `backend/tests/unit/core/agent/test_litellm_client.py`,单测 `run_completion` 的 kwargs 组装逻辑。
**Step 2: 运行测试确认失败**
Run: `uv run pytest backend/tests/unit/core/agent/test_crewai_runtime.py -q`
Expected: 新增断言失败(参数未透传或未过滤 `None`)。
**Step 3: 最小实现(GREEN**
1. `create_runtime()` 增加 `llm_config` 参数并传给 `CrewAIRuntime`
2. `CrewAIRuntime` 保存 `llm_config`,执行时调用:
- `run_completion(..., temperature=llm_config.temperature, max_tokens=llm_config.max_tokens)`
3. `run_completion()` 改为支持可选 `temperature/max_tokens`,内部仅在非 `None` 时加入 kwargs 再调用 `completion()`
**Step 4: 运行测试确认通过**
Run: `uv run pytest backend/tests/unit/core/agent/test_crewai_runtime.py -q`
Expected: PASS。
---
### Task 3: 初始化数据补齐与回归验证
**Files:**
- Modify: `backend/src/core/config/static/database/system_agents.yaml`
- Modify: `backend/src/core/config/initial/init_data.py`(如需补充类型兜底)
- Test: `backend/tests/unit/core/agent/test_run_resume_service.py`
**Step 1: 写失败测试(RED**
补充断言:YAML 读取后 `config` 可为空或包含 `max_tokens: null`,初始化逻辑不会报错,且生成结构符合 `SystemAgentLLMConfig`
**Step 2: 运行测试确认失败**
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
Expected: 新增断言失败。
**Step 3: 最小实现(GREEN**
1.`system_agents.yaml` 为各 agent 配置显式补充 `max_tokens: null`
2. `init_data.py` 保持 `config: SystemAgentLLMConfig | None = None`,写库时统一序列化为 dict。
**Step 4: 运行测试确认通过**
Run: `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py -q`
Expected: PASS。
---
## 最终验证
1. `uv run pytest backend/tests/unit/core/agent/test_run_resume_service.py backend/tests/unit/core/agent/test_crewai_runtime.py -q`
2. `uv run pytest backend/tests/integration/core/agent/test_queue_run_resume.py -q`
3. `uv run ruff check backend/src backend/tests`
4. `uv run basedpyright`
预期:全部通过;若集成测试依赖本地 DB 状态导致跳过/失败,需记录原因并给出手工验证步骤。
## 完成标准
- `RunService` 从 DB 读取并校验 `config`
- runtime 到 LiteLLM 链路支持 `temperature/max_tokens` 可选透传。
- `None` 不透传。
- 单测与相关集成测试通过,并给出命令级证据。
+5 -4
View File
@@ -69,7 +69,7 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml exec -T db \
### 启动应用进程
```bash
bash infra/scripts/app-up.sh
bash infra/scripts/app.sh start
```
该脚本会在 tmux `social-dev` 会话中拉起:
@@ -172,6 +172,7 @@ curl -sS "${WEB_BASE_URL}/api/v1/profile/me" \
- 症状:队列堆积,任务长时间 pending。
- 定位:检查 `worker-*` tmux 窗口和对应日志文件。
- 修复:重启 tmux 会话,确认并发配置与队列名(critical/default/bulk)。
- 说明:Taskiq 路径当前仅消费 `SOCIAL_WORKER__GROUPS__*__CONCURRENCY`,旧 Celery 参数(prefetch/time_limit 等)已废弃。
### 2.1) Agent Runtime run/resume 事件不闭环
@@ -179,7 +180,7 @@ curl -sS "${WEB_BASE_URL}/api/v1/profile/me" \
- 定位步骤:
```bash
# 1) 检查 celery worker 是否消费 agent 任务
# 1) 检查 taskiq worker 是否消费 agent 任务
grep -E "tasks\.agent\.run_command|RUN_STARTED|RUN_FINISHED|RUN_ERROR" logs/worker-default.log
# 2) 检查 API SSE 事件读取(带 Last-Event-ID
@@ -192,7 +193,7 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml exec -T redis
```
- 修复建议:
- 若 worker 无消费:重启 `worker-default` 窗口并确认 `core.agent.infrastructure.queue.tasks` 已被 Celery include
- 若 worker 无消费:重启 `worker-default` 窗口并确认 `core.agent.infrastructure.queue.tasks` 已被 Taskiq worker 加载
- 若 worker 有事件但 API 无输出:排查 Redis stream 前缀配置与 session_id 是否一致。
- 若出现 `RUN_ERROR`:按 error_id 回查后端日志,不在 API/SSE 中暴露敏感上下文。
@@ -270,4 +271,4 @@ docker compose --env-file .env -f infra/docker/docker-compose.yml up -d --force-
| 2026-02-28 | 邀请码功能:新增 invite_codes 表、profiles.referred_by,注册时可选填邀请码并记录邀请关系 |
| 2026-03-02 | 文档整理:修正 auth 端点名称(/verifications)、补充 profile 路由文档、修复 L2/L3 验证命令 |
| 2026-03-02 | 修正 bootstrap 命令:init-job 需要使用 `uv run python -m core.runtime.cli bootstrap` |
| 2026-03-05 | 新增 Agent Runtime run/resume/events 运维排障流程(Celery + Redis + Last-Event-ID |
| 2026-03-05 | 新增 Agent Runtime run/resume/events 运维排障流程(Taskiq + Redis + Last-Event-ID |
+6 -6
View File
@@ -56,9 +56,9 @@ ${SOCIAL_WEB__GUNICORN__WORKER_CLASS:-uvicorn.workers.UvicornWorker} --timeout \
${SOCIAL_WEB__GUNICORN__TIMEOUT:-60} \
--log-level ${SOCIAL_RUNTIME__LOG_LEVEL:-info}"
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run celery -A core.celery.app worker --loglevel=info --queues=critical --concurrency=${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run celery -A core.celery.app worker --loglevel=info --queues=default --concurrency=${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run celery -A core.celery.app worker --loglevel=info --queues=bulk --concurrency=${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
WORKER_CRITICAL_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-critical uv run taskiq worker core.taskiq.app:critical_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__CRITICAL__CONCURRENCY:-2}"
WORKER_DEFAULT_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-default uv run taskiq worker core.taskiq.app:default_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__DEFAULT__CONCURRENCY:-2}"
WORKER_BULK_CMD="cd '$ROOT_DIR' && PYTHONPATH=backend/src SOCIAL_RUNTIME__SERVICE_NAME=worker-bulk uv run taskiq worker core.taskiq.app:bulk_broker core.agent.infrastructure.queue.tasks --workers ${SOCIAL_WORKER__GROUPS__BULK__CONCURRENCY:-1}"
tmux new-session -d -s "$SESSION_NAME" -n web "bash -lc \"$WEB_CMD; echo '[web] exited'; exec bash\""
tmux new-window -t "$SESSION_NAME" -n worker-critical "bash -lc \"$WORKER_CRITICAL_CMD; echo '[worker-critical] exited'; exec bash\""
@@ -92,9 +92,9 @@ stop() {
echo "Killing orphaned gunicorn processes..."
pkill -f "gunicorn.*app:app"
fi
if pgrep -f "celery.*worker" > /dev/null 2>&1; then
echo "Killing orphaned celery processes..."
pkill -f "celery.*worker"
if pgrep -f "taskiq.*worker" > /dev/null 2>&1; then
echo "Killing orphaned taskiq processes..."
pkill -f "taskiq.*worker"
fi
echo "Session stopped and cleaned up."
+3 -1
View File
@@ -8,7 +8,6 @@ dependencies = [
"alembic>=1.18.3",
"asyncpg>=0.31.0",
"basedpyright>=1.37.2",
"celery>=5.6.2",
"crewai>=1.6.1",
"crewai-tools>=1.6.1",
"email-validator>=2.3.0",
@@ -22,8 +21,11 @@ dependencies = [
"redis>=7.1.0",
"sqlalchemy[asyncio]>=2.0.46",
"structlog>=24.4.0",
"taskiq>=0.11.0",
"taskiq-redis>=1.0.0",
"supabase>=2.27.2",
"uvicorn[standard]>=0.40.0",
"gunicorn>=25.1.0",
]
[project.optional-dependencies]