refactor: 统一 Redis 连接管理,改用 RedisService

- App 启动时初始化 RedisService,关闭时释放连接
- Celery worker 通过 worker_process_init 钩子初始化 Redis
- Agent 端点改用 RedisService 替代直接创建连接
- Celery task 改为 async def,使用统一连接
- 删除无用的 infra 模块和 core/http/models
- 日志脱敏,不记录 Redis 密码
- 初始化失败时 fail-fast
- 异常发布添加二级保护
This commit is contained in:
qzl
2026-03-06 16:09:15 +08:00
parent c5ccfc4b88
commit 2c59fe5ee2
14 changed files with 101 additions and 150 deletions
+29 -3
View File
@@ -1,18 +1,26 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException
from core.config.settings import config
from core.http.models import HealthResponse
from core.http.response import build_problem_details
from core.logging import configure_logging, get_logger, log_service_banner
from services.base.redis import redis_service
from v1.router import router as mobile_router
class HealthResponse(BaseModel):
    """Response schema carrying a single health status string."""

    # Status value returned to the caller; the tests in this change expect "ok".
    status: str
configure_logging(config)
log_service_banner(
@@ -20,7 +28,26 @@ log_service_banner(
environment=config.runtime.environment,
)
app = FastAPI()
logger = get_logger("api.app")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Own the shared Redis service for the lifetime of the application.

    Startup initializes ``redis_service`` and aborts (fail-fast) when the
    initialization reports failure. Shutdown closes the service again.

    Args:
        app: The application instance this lifespan is attached to. It is not
            used by the body.

    Raises:
        RuntimeError: If ``redis_service.initialize()`` returns a falsy value.
    """
    if not await redis_service.initialize():
        logger.error("Redis service failed to initialize, aborting startup")
        raise RuntimeError("Redis service initialization failed")

    # Only host and db are logged so that credentials never reach the logs.
    logger.info(
        "Redis service initialized",
        host=config.redis.host,
        db=config.redis.db,
    )
    try:
        yield
    finally:
        # An error or cancellation can be thrown into the generator at the
        # yield point; ``finally`` guarantees the connection is released on
        # that path as well as on a normal shutdown.
        await redis_service.close()
        logger.info("Redis service closed")
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=config.cors.allow_origins,
@@ -29,7 +56,6 @@ app.add_middleware(
allow_headers=config.cors.allow_headers,
)
app.include_router(mobile_router)
logger = get_logger("api.app")
logger.info(
"Web application initialized",
@@ -31,6 +31,14 @@ class RedisStreamEventStore:
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
return str(self._client.xadd(stream, {"event": payload}))
async def append_event(self, *, session_id: UUID, event: dict[str, Any]) -> str:
    """Append one event to the stream of ``session_id`` and return its id.

    The event is serialized as compact ASCII-only JSON and stored under the
    ``"event"`` field. The client's ``xadd`` result is awaited when it is
    awaitable and used directly otherwise.

    Args:
        session_id: Session whose stream receives the event.
        event: JSON-serializable event payload.

    Returns:
        The value returned by ``xadd``, converted to ``str``.
    """
    stream_key = self._stream_name(session_id)
    fields = {"event": json.dumps(event, ensure_ascii=True, separators=(",", ":"))}
    entry_id = self._client.xadd(stream_key, fields)
    if inspect.isawaitable(entry_id):
        entry_id = await entry_id
    return str(entry_id)
async def read_events(
self,
*,
@@ -1,28 +1,21 @@
from __future__ import annotations
import asyncio
import threading
from typing import Any, Callable, Protocol, cast
from typing import Any, Protocol, cast
from uuid import UUID
import redis
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.celery.app import celery_app
from core.config.settings import config
from core.logging import get_logger
from services.base.redis import redis_service
logger = get_logger("core.agent.infrastructure.queue.tasks")
_background_loop: asyncio.AbstractEventLoop | None = None
_background_thread: threading.Thread | None = None
_background_ready = threading.Event()
class PublishEvent(Protocol):
def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
async def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
class RunServiceLike(Protocol):
@@ -35,36 +28,9 @@ class ResumeServiceLike(Protocol):
) -> dict[str, object]: ...
def _run_async(task: Callable[[], Any]) -> Any:
loop = _ensure_background_loop()
future = asyncio.run_coroutine_threadsafe(task(), loop)
return future.result()
def _ensure_background_loop() -> asyncio.AbstractEventLoop:
global _background_loop, _background_thread
if _background_loop is not None:
return _background_loop
def _loop_worker() -> None:
global _background_loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
_background_loop = loop
_background_ready.set()
loop.run_forever()
_background_thread = threading.Thread(target=_loop_worker, daemon=True)
_background_thread.start()
_background_ready.wait(timeout=5)
if _background_loop is None:
raise RuntimeError("failed to initialize background event loop")
return _background_loop
def _build_redis_publisher() -> PublishEvent:
async def _build_redis_publisher() -> PublishEvent:
settings = cast(Any, config)
client = redis.from_url(settings.redis.url, decode_responses=True)
client = redis_service.get_client()
event_store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
@@ -72,11 +38,11 @@ def _build_redis_publisher() -> PublishEvent:
block_ms=settings.agent_runtime.redis_stream_block_ms,
)
def _publish(event_type: str, payload: dict[str, object]) -> None:
async def _publish(event_type: str, payload: dict[str, object]) -> None:
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
raise ValueError("session_id is required in event payload")
event_store.append_event_sync(
await event_store.append_event(
session_id=UUID(session_id),
event={"type": event_type, "data": payload},
)
@@ -84,14 +50,14 @@ def _build_redis_publisher() -> PublishEvent:
return _publish
def run_agent_task(
async def run_agent_task(
command: dict[str, Any],
*,
publish_event: PublishEvent | None = None,
run_service: RunServiceLike | None = None,
resume_service: ResumeServiceLike | None = None,
) -> dict[str, object]:
publisher = publish_event or _build_redis_publisher()
publisher = publish_event or await _build_redis_publisher()
service_run = run_service or RunService()
service_resume = resume_service or ResumeService()
@@ -105,31 +71,27 @@ def run_agent_task(
UUID(session_id)
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
publisher(start_event, {"session_id": session_id})
await publisher(start_event, {"session_id": session_id})
try:
if command_type == "resume":
tool_call_id = str(command.get("tool_call_id", ""))
if not tool_call_id:
raise ValueError("tool_call_id is required")
result = _run_async(
lambda: service_resume.resume(
session_id=session_id,
tool_call_id=tool_call_id,
)
result = await service_resume.resume(
session_id=session_id,
tool_call_id=tool_call_id,
)
else:
user_input = str(command.get("user_input", ""))
if not user_input:
raise ValueError("user_input is required")
result = _run_async(
lambda: service_run.run(
session_id=session_id,
user_input=user_input,
)
result = await service_run.run(
session_id=session_id,
user_input=user_input,
)
publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
await publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
extra_events = result.get("events") if isinstance(result, dict) else None
if isinstance(extra_events, list):
for event in extra_events:
@@ -140,8 +102,8 @@ def run_agent_task(
if not isinstance(event_type, str) or not isinstance(event_data, dict):
continue
payload = {"session_id": session_id, **event_data}
publisher(event_type, payload)
publisher("RUN_FINISHED", {"session_id": session_id})
await publisher(event_type, payload)
await publisher("RUN_FINISHED", {"session_id": session_id})
return result
except Exception: # noqa: BLE001
error_id = "agent_runtime_failed"
@@ -150,10 +112,19 @@ def run_agent_task(
session_id=session_id,
error_id=error_id,
)
publisher("RUN_ERROR", {"session_id": session_id, "error_id": error_id})
try:
await publisher(
"RUN_ERROR", {"session_id": session_id, "error_id": error_id}
)
except Exception as publish_exc: # noqa: BLE001
logger.warning(
"Failed to publish RUN_ERROR event",
session_id=session_id,
error=str(publish_exc),
)
raise
@celery_app.task(name="tasks.agent.run_command")
def run_command_task(command: dict[str, Any]) -> dict[str, object]:
    """Celery entry point that executes one agent command.

    Celery invokes task functions synchronously. Declaring the task as
    ``async def`` would make each invocation return an un-awaited coroutine,
    so the command would never run. The task therefore stays a plain function
    and drives the ``run_agent_task`` coroutine to completion itself.

    Args:
        command: Command payload, forwarded unchanged to ``run_agent_task``.

    Returns:
        The result dictionary produced by ``run_agent_task``.
    """
    # Imported locally so the block does not depend on a module-level import.
    import asyncio

    # NOTE(review): every invocation runs on a fresh event loop. Confirm that
    # the client handed out by redis_service is not bound to the loop that was
    # used (and closed) during worker initialization.
    return asyncio.run(run_agent_task(command))
+21
View File
@@ -1,10 +1,29 @@
from __future__ import annotations
from celery import Celery
from celery import signals as celery_signals
from kombu import Queue
from core.config.settings import config
from core.logging import get_logger
from core.logging.celery import configure_celery_app
from services.base.redis import redis_service
logger = get_logger("core.celery")
def _init_redis_on_worker_startup(**_: object) -> None:
    """Initialize the shared Redis service inside a Celery worker process.

    Runs ``redis_service.initialize()`` to completion with ``asyncio.run``.
    The outcome is only logged: a truthy result as info, a falsy result as a
    warning, and any exception as an error. Nothing is propagated.

    Args:
        **_: Keyword arguments supplied by the signal dispatcher; ignored.
    """
    from asyncio import run as run_to_completion

    logger.info("Initializing Redis service for Celery worker")
    try:
        # NOTE(review): asyncio.run closes its event loop on return. Confirm
        # the initialized client stays usable from the loop that runs tasks.
        if run_to_completion(redis_service.initialize()):
            logger.info("Redis service initialized for Celery worker")
        else:
            logger.warning("Redis service initialization returned False")
    except Exception as init_error:  # noqa: BLE001
        logger.error(
            "Failed to initialize Redis for Celery worker",
            error=str(init_error),
        )
def create_celery_app() -> Celery:
@@ -50,3 +69,5 @@ def create_celery_app() -> Celery:
celery_app = create_celery_app()
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
-7
View File
@@ -1,7 +0,0 @@
from __future__ import annotations
from pydantic import BaseModel
class HealthResponse(BaseModel):
status: str
+3 -4
View File
@@ -4,21 +4,20 @@ from typing import Any, cast
from uuid import UUID
from fastapi import Depends
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.config.settings import config
from core.db import get_db
from services.base.redis import redis_service
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
class CeleryQueueClient:
def __init__(self) -> None:
settings = cast(Any, config)
self._redis = redis.from_url(settings.redis.url, decode_responses=True)
self._redis = redis_service.get_client()
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
@@ -46,7 +45,7 @@ class CeleryQueueClient:
class RedisEventStream:
def __init__(self) -> None:
settings = cast(Any, config)
client = redis.from_url(settings.redis.url, decode_responses=True)
client = redis_service.get_client()
self._store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
-1
View File
@@ -1 +0,0 @@
from __future__ import annotations
-7
View File
@@ -1,7 +0,0 @@
from __future__ import annotations
from services.base.redis import RedisService, redis_service
def get_redis_service() -> RedisService:
return redis_service
-28
View File
@@ -1,28 +0,0 @@
from __future__ import annotations
from fastapi import APIRouter, Depends
from services.base.redis import RedisService
from v1.infra.dependencies import get_redis_service
from v1.infra.schemas import InfraHealthResponse, ServiceHealth
router = APIRouter(prefix="/infra", tags=["infra"])
@router.get("/health", response_model=InfraHealthResponse)
async def infra_health(
redis_service: RedisService = Depends(get_redis_service),
) -> InfraHealthResponse:
if not redis_service.is_initialized:
await redis_service.initialize()
redis_health = await redis_service.health_check()
status = "healthy" if redis_health["status"] == "healthy" else "unhealthy"
return InfraHealthResponse(
status=status,
services={
"redis": ServiceHealth(**redis_health),
},
)
-15
View File
@@ -1,15 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, Literal
from pydantic import BaseModel
class ServiceHealth(BaseModel):
status: Literal["healthy", "unhealthy"]
details: Dict[str, Any]
class InfraHealthResponse(BaseModel):
status: Literal["healthy", "unhealthy"]
services: Dict[str, ServiceHealth]
-8
View File
@@ -2,12 +2,10 @@ from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.agent.router import router as agent_router
from v1.auth.router import router as auth_router
from v1.friendships.router import router as friendships_router
from v1.inbox_messages.router import router as inbox_messages_router
from v1.infra.router import router as infra_router
from v1.schedule_items.router import router as schedule_items_router
from v1.users.router import router as users_router
@@ -16,12 +14,6 @@ router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(agent_router)
router.include_router(friendships_router)
router.include_router(infra_router)
router.include_router(users_router)
router.include_router(schedule_items_router)
router.include_router(inbox_messages_router)
@router.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(status="ok")
+1 -1
View File
@@ -46,7 +46,7 @@ def test_mobile_health_e2e() -> None:
base_url=f"http://{host}:{port}"
)
try:
response = request_context.get("/api/v1/health")
response = request_context.get("/health")
assert response.status == 200
body = response.json()
assert body["status"] == "ok"
@@ -15,16 +15,6 @@ def test_app_health_returns_envelope() -> None:
assert body["status"] == "ok"
def test_mobile_router_health_returns_envelope() -> None:
client = TestClient(app)
response = client.get("/api/v1/health")
assert response.status_code == 200
body = response.json()
assert body["status"] == "ok"
def test_not_found_returns_error_envelope() -> None:
client = TestClient(app)
@@ -15,15 +15,16 @@ class _FakeResumeService:
return {"session_id": session_id, "tool_call_id": tool_call_id}
def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
@pytest.mark.asyncio
async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
session_id = "00000000-0000-0000-0000-000000000001"
events: list[str] = []
def _publish(event_type: str, payload: dict[str, object]) -> None:
async def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
events.append(event_type)
result = run_agent_task(
result = await run_agent_task(
{
"command": "run",
"session_id": session_id,
@@ -38,7 +39,8 @@ def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
def test_run_agent_task_emits_error_event_on_exception() -> None:
@pytest.mark.asyncio
async def test_run_agent_task_emits_error_event_on_exception() -> None:
session_id = "00000000-0000-0000-0000-000000000001"
class _BrokenRunService(_FakeRunService):
@@ -48,12 +50,12 @@ def test_run_agent_task_emits_error_event_on_exception() -> None:
events: list[str] = []
def _publish(event_type: str, payload: dict[str, object]) -> None:
async def _publish(event_type: str, payload: dict[str, object]) -> None:
del payload
events.append(event_type)
with pytest.raises(RuntimeError):
run_agent_task(
await run_agent_task(
{
"command": "run",
"session_id": session_id,