From 2c59fe5ee2fad38c7f9519874540f677bddcc34e Mon Sep 17 00:00:00 2001 From: qzl Date: Fri, 6 Mar 2026 16:09:15 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=BB=9F=E4=B8=80=20Redis=20?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E7=AE=A1=E7=90=86=EF=BC=8C=E6=94=B9=E7=94=A8?= =?UTF-8?q?=20RedisService?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - App 启动时初始化 RedisService,关闭时释放连接 - Celery worker 通过 worker_process_init 钩子初始化 Redis - Agent 端点改用 RedisService 替代直接创建连接 - Celery task 改为 async def,使用统一连接 - 删除无用的 infra 模块和 core/http/models - 日志脱敏,不记录 Redis 密码 - 初始化失败时 fail-fast - 异常发布添加二级保护 --- backend/src/app.py | 32 ++++++- .../infrastructure/events/redis_stream.py | 8 ++ .../core/agent/infrastructure/queue/tasks.py | 91 +++++++------------ backend/src/core/celery/app.py | 21 +++++ backend/src/core/http/models.py | 7 -- backend/src/v1/agent/dependencies.py | 7 +- backend/src/v1/infra/__init__.py | 1 - backend/src/v1/infra/dependencies.py | 7 -- backend/src/v1/infra/router.py | 28 ------ backend/src/v1/infra/schemas.py | 15 --- backend/src/v1/router.py | 8 -- backend/tests/e2e/test_mobile_health_e2e.py | 2 +- .../integration/test_mobile_app_skeleton.py | 10 -- .../tests/unit/core/agent/test_queue_tasks.py | 14 +-- 14 files changed, 101 insertions(+), 150 deletions(-) delete mode 100644 backend/src/core/http/models.py delete mode 100644 backend/src/v1/infra/__init__.py delete mode 100644 backend/src/v1/infra/dependencies.py delete mode 100644 backend/src/v1/infra/router.py delete mode 100644 backend/src/v1/infra/schemas.py diff --git a/backend/src/app.py b/backend/src/app.py index ffdf34d..ea5a66e 100644 --- a/backend/src/app.py +++ b/backend/src/app.py @@ -1,18 +1,26 @@ from __future__ import annotations +from contextlib import asynccontextmanager +from typing import AsyncGenerator + from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from pydantic import BaseModel from starlette.exceptions import HTTPException as StarletteHTTPException from core.config.settings import config -from core.http.models import HealthResponse from core.http.response import build_problem_details from core.logging import configure_logging, get_logger, log_service_banner +from services.base.redis import redis_service from v1.router import router as mobile_router +class HealthResponse(BaseModel): + status: str + + configure_logging(config) log_service_banner( @@ -20,7 +28,26 @@ log_service_banner( environment=config.runtime.environment, ) -app = FastAPI() +logger = get_logger("api.app") + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + initialized = await redis_service.initialize() + if not initialized: + logger.error("Redis service failed to initialize, aborting startup") + raise RuntimeError("Redis service initialization failed") + logger.info( + "Redis service initialized", + host=config.redis.host, + db=config.redis.db, + ) + yield + await redis_service.close() + logger.info("Redis service closed") + + +app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=config.cors.allow_origins, @@ -29,7 +56,6 @@ app.add_middleware( allow_headers=config.cors.allow_headers, ) app.include_router(mobile_router) -logger = get_logger("api.app") logger.info( "Web application initialized", diff --git a/backend/src/core/agent/infrastructure/events/redis_stream.py b/backend/src/core/agent/infrastructure/events/redis_stream.py index 7b3619f..301a63f 100644 --- a/backend/src/core/agent/infrastructure/events/redis_stream.py +++ b/backend/src/core/agent/infrastructure/events/redis_stream.py @@ -31,6 +31,14 @@ class RedisStreamEventStore: payload = json.dumps(event, ensure_ascii=True, separators=(",", ":")) return str(self._client.xadd(stream, {"event": payload})) + async def append_event(self, *, session_id: UUID, event: dict[str, Any]) -> str: + stream = self._stream_name(session_id) + payload = json.dumps(event, ensure_ascii=True, separators=(",", ":")) + result = self._client.xadd(stream, {"event": payload}) + if inspect.isawaitable(result): + return str(await result) + return str(result) + async def read_events( self, *, diff --git a/backend/src/core/agent/infrastructure/queue/tasks.py b/backend/src/core/agent/infrastructure/queue/tasks.py index 2dca983..962d8d8 100644 --- a/backend/src/core/agent/infrastructure/queue/tasks.py +++ b/backend/src/core/agent/infrastructure/queue/tasks.py @@ -1,28 +1,21 @@ from __future__ import annotations -import asyncio -import threading -from typing import Any, Callable, Protocol, cast +from typing import Any, Protocol, cast from uuid import UUID -import redis - from core.agent.application.resume_service import ResumeService from core.agent.application.run_service import RunService from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore from core.celery.app import celery_app from core.config.settings import config from core.logging import get_logger +from services.base.redis import redis_service logger = get_logger("core.agent.infrastructure.queue.tasks") -_background_loop: asyncio.AbstractEventLoop | None = None -_background_thread: threading.Thread | None = None -_background_ready = threading.Event() - class PublishEvent(Protocol): - def __call__(self, event_type: str, payload: dict[str, object]) -> None: ... + async def __call__(self, event_type: str, payload: dict[str, object]) -> None: ... class RunServiceLike(Protocol): @@ -35,36 +28,9 @@ class ResumeServiceLike(Protocol): ) -> dict[str, object]: ... -def _run_async(task: Callable[[], Any]) -> Any: - loop = _ensure_background_loop() - future = asyncio.run_coroutine_threadsafe(task(), loop) - return future.result() - - -def _ensure_background_loop() -> asyncio.AbstractEventLoop: - global _background_loop, _background_thread - if _background_loop is not None: - return _background_loop - - def _loop_worker() -> None: - global _background_loop - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - _background_loop = loop - _background_ready.set() - loop.run_forever() - - _background_thread = threading.Thread(target=_loop_worker, daemon=True) - _background_thread.start() - _background_ready.wait(timeout=5) - if _background_loop is None: - raise RuntimeError("failed to initialize background event loop") - return _background_loop - - -def _build_redis_publisher() -> PublishEvent: +async def _build_redis_publisher() -> PublishEvent: settings = cast(Any, config) - client = redis.from_url(settings.redis.url, decode_responses=True) + client = redis_service.get_client() event_store = RedisStreamEventStore( client=client, stream_prefix=settings.agent_runtime.redis_stream_prefix, @@ -72,11 +38,11 @@ def _build_redis_publisher() -> PublishEvent: block_ms=settings.agent_runtime.redis_stream_block_ms, ) - def _publish(event_type: str, payload: dict[str, object]) -> None: + async def _publish(event_type: str, payload: dict[str, object]) -> None: session_id = str(payload.get("session_id", "")).strip() if not session_id: raise ValueError("session_id is required in event payload") - event_store.append_event_sync( + await event_store.append_event( session_id=UUID(session_id), event={"type": event_type, "data": payload}, ) @@ -84,14 +50,14 @@ def _build_redis_publisher() -> PublishEvent: return _publish -def run_agent_task( +async def run_agent_task( command: dict[str, Any], *, publish_event: PublishEvent | None = None, run_service: RunServiceLike | None = None, resume_service: ResumeServiceLike | None = None, ) -> dict[str, object]: - publisher = publish_event or _build_redis_publisher() + publisher = publish_event or await _build_redis_publisher() service_run = run_service or RunService() service_resume = resume_service or ResumeService() @@ -105,31 +71,27 @@ def run_agent_task( UUID(session_id) start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED" - publisher(start_event, {"session_id": session_id}) + await publisher(start_event, {"session_id": session_id}) try: if command_type == "resume": tool_call_id = str(command.get("tool_call_id", "")) if not tool_call_id: raise ValueError("tool_call_id is required") - result = _run_async( - lambda: service_resume.resume( - session_id=session_id, - tool_call_id=tool_call_id, - ) + result = await service_resume.resume( + session_id=session_id, + tool_call_id=tool_call_id, ) else: user_input = str(command.get("user_input", "")) if not user_input: raise ValueError("user_input is required") - result = _run_async( - lambda: service_run.run( - session_id=session_id, - user_input=user_input, - ) + result = await service_run.run( + session_id=session_id, + user_input=user_input, ) - publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result}) + await publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result}) extra_events = result.get("events") if isinstance(result, dict) else None if isinstance(extra_events, list): for event in extra_events: @@ -140,8 +102,8 @@ def run_agent_task( if not isinstance(event_type, str) or not isinstance(event_data, dict): continue payload = {"session_id": session_id, **event_data} - publisher(event_type, payload) - publisher("RUN_FINISHED", {"session_id": session_id}) + await publisher(event_type, payload) + await publisher("RUN_FINISHED", {"session_id": session_id}) return result except Exception: # noqa: BLE001 error_id = "agent_runtime_failed" @@ -150,10 +112,19 @@ def run_agent_task( session_id=session_id, error_id=error_id, ) - publisher("RUN_ERROR", {"session_id": session_id, "error_id": error_id}) + try: + await publisher( + "RUN_ERROR", {"session_id": session_id, "error_id": error_id} + ) + except Exception as publish_exc: # noqa: BLE001 + logger.warning( + "Failed to publish RUN_ERROR event", + session_id=session_id, + error=str(publish_exc), + ) raise @celery_app.task(name="tasks.agent.run_command") -def run_command_task(command: dict[str, Any]) -> dict[str, object]: - return run_agent_task(command) +async def run_command_task(command: dict[str, Any]) -> dict[str, object]: + return await run_agent_task(command) diff --git a/backend/src/core/celery/app.py b/backend/src/core/celery/app.py index 51517f4..b2674cc 100644 --- a/backend/src/core/celery/app.py +++ b/backend/src/core/celery/app.py @@ -1,10 +1,29 @@ from __future__ import annotations from celery import Celery +from celery import signals as celery_signals from kombu import Queue from core.config.settings import config +from core.logging import get_logger from core.logging.celery import configure_celery_app +from services.base.redis import redis_service + +logger = get_logger("core.celery") + + +def _init_redis_on_worker_startup(**_: object) -> None: + import asyncio + + logger.info("Initializing Redis service for Celery worker") + try: + result = asyncio.run(redis_service.initialize()) + if result: + logger.info("Redis service initialized for Celery worker") + else: + logger.warning("Redis service initialization returned False") + except Exception as exc: # noqa: BLE001 + logger.error("Failed to initialize Redis for Celery worker", error=str(exc)) def create_celery_app() -> Celery: @@ -50,3 +69,5 @@ def create_celery_app() -> Celery: celery_app = create_celery_app() + +celery_signals.worker_process_init.connect(_init_redis_on_worker_startup) diff --git a/backend/src/core/http/models.py b/backend/src/core/http/models.py deleted file mode 100644 index a31ae83..0000000 --- a/backend/src/core/http/models.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import annotations - -from pydantic import BaseModel - - -class HealthResponse(BaseModel): - status: str diff --git a/backend/src/v1/agent/dependencies.py b/backend/src/v1/agent/dependencies.py index b08eabd..eb45b8a 100644 --- a/backend/src/v1/agent/dependencies.py +++ b/backend/src/v1/agent/dependencies.py @@ -4,21 +4,20 @@ from typing import Any, cast from uuid import UUID from fastapi import Depends -import redis.asyncio as redis from sqlalchemy.ext.asyncio import AsyncSession from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore from core.agent.infrastructure.queue.tasks import run_command_task from core.config.settings import config from core.db import get_db +from services.base.redis import redis_service from v1.agent.repository import AgentRepository from v1.agent.service import AgentService class CeleryQueueClient: def __init__(self) -> None: - settings = cast(Any, config) - self._redis = redis.from_url(settings.redis.url, decode_responses=True) + self._redis = redis_service.get_client() async def enqueue( self, *, command: dict[str, object], dedup_key: str | None @@ -46,7 +45,7 @@ class CeleryQueueClient: class RedisEventStream: def __init__(self) -> None: settings = cast(Any, config) - client = redis.from_url(settings.redis.url, decode_responses=True) + client = redis_service.get_client() self._store = RedisStreamEventStore( client=client, stream_prefix=settings.agent_runtime.redis_stream_prefix, diff --git a/backend/src/v1/infra/__init__.py b/backend/src/v1/infra/__init__.py deleted file mode 100644 index 9d48db4..0000000 --- a/backend/src/v1/infra/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from __future__ import annotations diff --git a/backend/src/v1/infra/dependencies.py b/backend/src/v1/infra/dependencies.py deleted file mode 100644 index b37b466..0000000 --- a/backend/src/v1/infra/dependencies.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import annotations - -from services.base.redis import RedisService, redis_service - - -def get_redis_service() -> RedisService: - return redis_service diff --git a/backend/src/v1/infra/router.py b/backend/src/v1/infra/router.py deleted file mode 100644 index 318a116..0000000 --- a/backend/src/v1/infra/router.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -from fastapi import APIRouter, Depends - -from services.base.redis import RedisService -from v1.infra.dependencies import get_redis_service -from v1.infra.schemas import InfraHealthResponse, ServiceHealth - - -router = APIRouter(prefix="/infra", tags=["infra"]) - - -@router.get("/health", response_model=InfraHealthResponse) -async def infra_health( - redis_service: RedisService = Depends(get_redis_service), -) -> InfraHealthResponse: - if not redis_service.is_initialized: - await redis_service.initialize() - - redis_health = await redis_service.health_check() - status = "healthy" if redis_health["status"] == "healthy" else "unhealthy" - - return InfraHealthResponse( - status=status, - services={ - "redis": ServiceHealth(**redis_health), - }, - ) diff --git a/backend/src/v1/infra/schemas.py b/backend/src/v1/infra/schemas.py deleted file mode 100644 index a051ca5..0000000 --- a/backend/src/v1/infra/schemas.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, Literal - -from pydantic import BaseModel - - -class ServiceHealth(BaseModel): - status: Literal["healthy", "unhealthy"] - details: Dict[str, Any] - - -class InfraHealthResponse(BaseModel): - status: Literal["healthy", "unhealthy"] - services: Dict[str, ServiceHealth] diff --git a/backend/src/v1/router.py b/backend/src/v1/router.py index 8e880b9..54ed53d 100644 --- a/backend/src/v1/router.py +++ b/backend/src/v1/router.py @@ -2,12 +2,10 @@ from __future__ import annotations from fastapi import APIRouter -from core.http.models import HealthResponse from v1.agent.router import router as agent_router from v1.auth.router import router as auth_router from v1.friendships.router import router as friendships_router from v1.inbox_messages.router import router as inbox_messages_router -from v1.infra.router import router as infra_router from v1.schedule_items.router import router as schedule_items_router from v1.users.router import router as users_router @@ -16,12 +14,6 @@ router = APIRouter(prefix="/api/v1") router.include_router(auth_router) router.include_router(agent_router) router.include_router(friendships_router) -router.include_router(infra_router) router.include_router(users_router) router.include_router(schedule_items_router) router.include_router(inbox_messages_router) - - -@router.get("/health", response_model=HealthResponse) -async def health() -> HealthResponse: - return HealthResponse(status="ok") diff --git a/backend/tests/e2e/test_mobile_health_e2e.py b/backend/tests/e2e/test_mobile_health_e2e.py index dfc1b18..0591cb5 100644 --- a/backend/tests/e2e/test_mobile_health_e2e.py +++ b/backend/tests/e2e/test_mobile_health_e2e.py @@ -46,7 +46,7 @@ def test_mobile_health_e2e() -> None: base_url=f"http://{host}:{port}" ) try: - response = request_context.get("/api/v1/health") + response = request_context.get("/health") assert response.status == 200 body = response.json() assert body["status"] == "ok" diff --git a/backend/tests/integration/test_mobile_app_skeleton.py b/backend/tests/integration/test_mobile_app_skeleton.py index 8a55537..241a6e2 100644 --- a/backend/tests/integration/test_mobile_app_skeleton.py +++ b/backend/tests/integration/test_mobile_app_skeleton.py @@ -15,16 +15,6 @@ def test_app_health_returns_envelope() -> None: assert body["status"] == "ok" -def test_mobile_router_health_returns_envelope() -> None: - client = TestClient(app) - - response = client.get("/api/v1/health") - - assert response.status_code == 200 - body = response.json() - assert body["status"] == "ok" - - def test_not_found_returns_error_envelope() -> None: client = TestClient(app) diff --git a/backend/tests/unit/core/agent/test_queue_tasks.py b/backend/tests/unit/core/agent/test_queue_tasks.py index d8a63a0..c55813d 100644 --- a/backend/tests/unit/core/agent/test_queue_tasks.py +++ b/backend/tests/unit/core/agent/test_queue_tasks.py @@ -15,15 +15,16 @@ class _FakeResumeService: return {"session_id": session_id, "tool_call_id": tool_call_id} -def test_run_agent_task_emits_started_runtime_and_finished_events() -> None: +@pytest.mark.asyncio +async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None: session_id = "00000000-0000-0000-0000-000000000001" events: list[str] = [] - def _publish(event_type: str, payload: dict[str, object]) -> None: + async def _publish(event_type: str, payload: dict[str, object]) -> None: del payload events.append(event_type) - result = run_agent_task( + result = await run_agent_task( { "command": "run", "session_id": session_id, @@ -38,7 +39,8 @@ def test_run_agent_task_emits_started_runtime_and_finished_events() -> None: assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"] -def test_run_agent_task_emits_error_event_on_exception() -> None: +@pytest.mark.asyncio +async def test_run_agent_task_emits_error_event_on_exception() -> None: session_id = "00000000-0000-0000-0000-000000000001" class _BrokenRunService(_FakeRunService): @@ -48,12 +50,12 @@ def test_run_agent_task_emits_error_event_on_exception() -> None: events: list[str] = [] - def _publish(event_type: str, payload: dict[str, object]) -> None: + async def _publish(event_type: str, payload: dict[str, object]) -> None: del payload events.append(event_type) with pytest.raises(RuntimeError): - run_agent_task( + await run_agent_task( { "command": "run", "session_id": session_id,