refactor: 统一 Redis 连接管理，改用 RedisService

- App 启动时初始化 RedisService，关闭时释放连接
- Celery worker 通过 worker_process_init 钩子初始化 Redis
- Agent 端点改用 RedisService 替代直接创建连接
- Celery task 改为 async def，使用统一连接
- 删除无用的 infra 模块和 core/http/models
- 日志脱敏，不记录 Redis 密码
- 初始化失败时 fail-fast
- 异常发布添加二级保护
This commit is contained in:
+29
-3
@@ -1,18 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from core.config.settings import config
|
||||
from core.http.models import HealthResponse
|
||||
from core.http.response import build_problem_details
|
||||
from core.logging import configure_logging, get_logger, log_service_banner
|
||||
from services.base.redis import redis_service
|
||||
from v1.router import router as mobile_router
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Response schema carrying a single ``status`` string."""

    status: str
|
||||
|
||||
|
||||
configure_logging(config)
|
||||
|
||||
log_service_banner(
|
||||
@@ -20,7 +28,26 @@ log_service_banner(
|
||||
environment=config.runtime.environment,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
logger = get_logger("api.app")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage the shared Redis service for the lifetime of the application.

    On startup ``redis_service`` is initialized and startup is aborted
    (fail-fast) when initialization reports failure. On shutdown the
    service is closed.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Raises:
        RuntimeError: If ``redis_service.initialize()`` returns a falsy value.
    """
    if not await redis_service.initialize():
        logger.error("Redis service failed to initialize, aborting startup")
        raise RuntimeError("Redis service initialization failed")

    # Only host and db are logged; credentials are deliberately left out.
    logger.info(
        "Redis service initialized",
        host=config.redis.host,
        db=config.redis.db,
    )
    try:
        yield
    finally:
        # Release the connection even when shutdown is triggered by an
        # exception or cancellation thrown into the generator at ``yield``.
        await redis_service.close()
        logger.info("Redis service closed")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.cors.allow_origins,
|
||||
@@ -29,7 +56,6 @@ app.add_middleware(
|
||||
allow_headers=config.cors.allow_headers,
|
||||
)
|
||||
app.include_router(mobile_router)
|
||||
logger = get_logger("api.app")
|
||||
|
||||
logger.info(
|
||||
"Web application initialized",
|
||||
|
||||
@@ -31,6 +31,14 @@ class RedisStreamEventStore:
|
||||
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
|
||||
return str(self._client.xadd(stream, {"event": payload}))
|
||||
|
||||
async def append_event(self, *, session_id: UUID, event: dict[str, Any]) -> str:
    """Append ``event`` to the session's stream and return ``xadd``'s result as text.

    The event is serialized as compact, ASCII-only JSON and stored under the
    ``"event"`` field. The client's ``xadd`` may return either a plain value
    or an awaitable; an awaitable is awaited before conversion to ``str``.
    """
    stream_key = self._stream_name(session_id)
    serialized = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
    pending = self._client.xadd(stream_key, {"event": serialized})
    outcome = await pending if inspect.isawaitable(pending) else pending
    return str(outcome)
|
||||
|
||||
async def read_events(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -1,28 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Any, Callable, Protocol, cast
|
||||
from typing import Any, Protocol, cast
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.celery.app import celery_app
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import redis_service
|
||||
|
||||
logger = get_logger("core.agent.infrastructure.queue.tasks")
|
||||
|
||||
_background_loop: asyncio.AbstractEventLoop | None = None
|
||||
_background_thread: threading.Thread | None = None
|
||||
_background_ready = threading.Event()
|
||||
|
||||
|
||||
class PublishEvent(Protocol):
|
||||
def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
|
||||
async def __call__(self, event_type: str, payload: dict[str, object]) -> None: ...
|
||||
|
||||
|
||||
class RunServiceLike(Protocol):
|
||||
@@ -35,36 +28,9 @@ class ResumeServiceLike(Protocol):
|
||||
) -> dict[str, object]: ...
|
||||
|
||||
|
||||
def _run_async(task: Callable[[], Any]) -> Any:
    """Run the coroutine produced by ``task`` on the background loop.

    The calling thread blocks until the coroutine finishes; its result is
    returned and any exception it raised propagates to the caller.
    """
    background_loop = _ensure_background_loop()
    return asyncio.run_coroutine_threadsafe(task(), background_loop).result()
|
||||
|
||||
|
||||
def _ensure_background_loop() -> asyncio.AbstractEventLoop:
    """Return the module's background event loop, starting it on first use.

    The loop runs forever on a daemon thread. After starting the thread the
    caller waits up to five seconds for the loop to be published.

    Raises:
        RuntimeError: If no loop has been published once the wait ends.
    """
    global _background_loop, _background_thread

    existing = _background_loop
    if existing is not None:
        return existing

    def _loop_worker() -> None:
        # Runs on the daemon thread: create the loop, publish it, then serve.
        global _background_loop
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        _background_loop = fresh_loop
        _background_ready.set()
        fresh_loop.run_forever()

    worker = threading.Thread(target=_loop_worker, daemon=True)
    _background_thread = worker
    worker.start()
    _background_ready.wait(timeout=5)

    published = _background_loop
    if published is None:
        raise RuntimeError("failed to initialize background event loop")
    return published
|
||||
|
||||
|
||||
def _build_redis_publisher() -> PublishEvent:
|
||||
async def _build_redis_publisher() -> PublishEvent:
|
||||
settings = cast(Any, config)
|
||||
client = redis.from_url(settings.redis.url, decode_responses=True)
|
||||
client = redis_service.get_client()
|
||||
event_store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
@@ -72,11 +38,11 @@ def _build_redis_publisher() -> PublishEvent:
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
session_id = str(payload.get("session_id", "")).strip()
|
||||
if not session_id:
|
||||
raise ValueError("session_id is required in event payload")
|
||||
event_store.append_event_sync(
|
||||
await event_store.append_event(
|
||||
session_id=UUID(session_id),
|
||||
event={"type": event_type, "data": payload},
|
||||
)
|
||||
@@ -84,14 +50,14 @@ def _build_redis_publisher() -> PublishEvent:
|
||||
return _publish
|
||||
|
||||
|
||||
def run_agent_task(
|
||||
async def run_agent_task(
|
||||
command: dict[str, Any],
|
||||
*,
|
||||
publish_event: PublishEvent | None = None,
|
||||
run_service: RunServiceLike | None = None,
|
||||
resume_service: ResumeServiceLike | None = None,
|
||||
) -> dict[str, object]:
|
||||
publisher = publish_event or _build_redis_publisher()
|
||||
publisher = publish_event or await _build_redis_publisher()
|
||||
service_run = run_service or RunService()
|
||||
service_resume = resume_service or ResumeService()
|
||||
|
||||
@@ -105,31 +71,27 @@ def run_agent_task(
|
||||
UUID(session_id)
|
||||
|
||||
start_event = "RUN_RESUMED" if command_type == "resume" else "RUN_STARTED"
|
||||
publisher(start_event, {"session_id": session_id})
|
||||
await publisher(start_event, {"session_id": session_id})
|
||||
|
||||
try:
|
||||
if command_type == "resume":
|
||||
tool_call_id = str(command.get("tool_call_id", ""))
|
||||
if not tool_call_id:
|
||||
raise ValueError("tool_call_id is required")
|
||||
result = _run_async(
|
||||
lambda: service_resume.resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
result = await service_resume.resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
else:
|
||||
user_input = str(command.get("user_input", ""))
|
||||
if not user_input:
|
||||
raise ValueError("user_input is required")
|
||||
result = _run_async(
|
||||
lambda: service_run.run(
|
||||
session_id=session_id,
|
||||
user_input=user_input,
|
||||
)
|
||||
result = await service_run.run(
|
||||
session_id=session_id,
|
||||
user_input=user_input,
|
||||
)
|
||||
|
||||
publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
|
||||
await publisher("RUNTIME_EVENT", {"session_id": session_id, "result": result})
|
||||
extra_events = result.get("events") if isinstance(result, dict) else None
|
||||
if isinstance(extra_events, list):
|
||||
for event in extra_events:
|
||||
@@ -140,8 +102,8 @@ def run_agent_task(
|
||||
if not isinstance(event_type, str) or not isinstance(event_data, dict):
|
||||
continue
|
||||
payload = {"session_id": session_id, **event_data}
|
||||
publisher(event_type, payload)
|
||||
publisher("RUN_FINISHED", {"session_id": session_id})
|
||||
await publisher(event_type, payload)
|
||||
await publisher("RUN_FINISHED", {"session_id": session_id})
|
||||
return result
|
||||
except Exception: # noqa: BLE001
|
||||
error_id = "agent_runtime_failed"
|
||||
@@ -150,10 +112,19 @@ def run_agent_task(
|
||||
session_id=session_id,
|
||||
error_id=error_id,
|
||||
)
|
||||
publisher("RUN_ERROR", {"session_id": session_id, "error_id": error_id})
|
||||
try:
|
||||
await publisher(
|
||||
"RUN_ERROR", {"session_id": session_id, "error_id": error_id}
|
||||
)
|
||||
except Exception as publish_exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Failed to publish RUN_ERROR event",
|
||||
session_id=session_id,
|
||||
error=str(publish_exc),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.agent.run_command")
|
||||
def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return run_agent_task(command)
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agent_task(command)
|
||||
|
||||
@@ -1,10 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals as celery_signals
|
||||
from kombu import Queue
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from core.logging.celery import configure_celery_app
|
||||
from services.base.redis import redis_service
|
||||
|
||||
logger = get_logger("core.celery")
|
||||
|
||||
|
||||
def _init_redis_on_worker_startup(**_: object) -> None:
    """Initialize the shared Redis service inside a Celery worker process.

    Connected to Celery's ``worker_process_init`` signal; the signal's keyword
    arguments are accepted and ignored. Failures are logged, never raised.
    """
    # NOTE(review): ``asyncio.run`` closes its event loop on return. If the
    # Redis client keeps loop-bound connections, using it later from another
    # loop may fail -- confirm against ``RedisService``.
    import asyncio

    logger.info("Initializing Redis service for Celery worker")
    try:
        if asyncio.run(redis_service.initialize()):
            logger.info("Redis service initialized for Celery worker")
            return
        logger.warning("Redis service initialization returned False")
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to initialize Redis for Celery worker", error=str(exc))
|
||||
|
||||
|
||||
def create_celery_app() -> Celery:
|
||||
@@ -50,3 +69,5 @@ def create_celery_app() -> Celery:
|
||||
|
||||
|
||||
celery_app = create_celery_app()
|
||||
|
||||
celery_signals.worker_process_init.connect(_init_redis_on_worker_startup)
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Response schema carrying a single ``status`` string."""

    status: str
|
||||
@@ -4,21 +4,20 @@ from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
import redis.asyncio as redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.agent.infrastructure.queue.tasks import run_command_task
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from services.base.redis import redis_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
|
||||
class CeleryQueueClient:
|
||||
def __init__(self) -> None:
|
||||
settings = cast(Any, config)
|
||||
self._redis = redis.from_url(settings.redis.url, decode_responses=True)
|
||||
self._redis = redis_service.get_client()
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
@@ -46,7 +45,7 @@ class CeleryQueueClient:
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
settings = cast(Any, config)
|
||||
client = redis.from_url(settings.redis.url, decode_responses=True)
|
||||
client = redis_service.get_client()
|
||||
self._store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -1,7 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from services.base.redis import RedisService, redis_service
|
||||
|
||||
|
||||
def get_redis_service() -> RedisService:
    """Return the module-level ``redis_service`` instance (dependency provider)."""
    return redis_service
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from services.base.redis import RedisService
|
||||
from v1.infra.dependencies import get_redis_service
|
||||
from v1.infra.schemas import InfraHealthResponse, ServiceHealth
|
||||
|
||||
|
||||
router = APIRouter(prefix="/infra", tags=["infra"])
|
||||
|
||||
|
||||
@router.get("/health", response_model=InfraHealthResponse)
async def infra_health(
    redis_service: RedisService = Depends(get_redis_service),
) -> InfraHealthResponse:
    """Report infrastructure health based on the Redis health check.

    The Redis service is initialized on demand when it has not been yet.
    The overall status is ``"healthy"`` only when the Redis check reports
    ``"healthy"``; every other value maps to ``"unhealthy"``.
    """
    if not redis_service.is_initialized:
        await redis_service.initialize()

    redis_report = await redis_service.health_check()
    if redis_report["status"] == "healthy":
        overall = "healthy"
    else:
        overall = "unhealthy"

    redis_entry = ServiceHealth(**redis_report)
    return InfraHealthResponse(status=overall, services={"redis": redis_entry})
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ServiceHealth(BaseModel):
    """Health report for one service: a status flag plus free-form details."""

    # Either "healthy" or "unhealthy".
    status: Literal["healthy", "unhealthy"]
    # Arbitrary key/value details supplied by the service's health check.
    details: Dict[str, Any]
|
||||
|
||||
|
||||
class InfraHealthResponse(BaseModel):
    """Aggregate health response: an overall status and per-service reports."""

    # Either "healthy" or "unhealthy".
    status: Literal["healthy", "unhealthy"]
    # Per-service reports keyed by service name.
    services: Dict[str, ServiceHealth]
|
||||
@@ -2,12 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from core.http.models import HealthResponse
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.friendships.router import router as friendships_router
|
||||
from v1.inbox_messages.router import router as inbox_messages_router
|
||||
from v1.infra.router import router as infra_router
|
||||
from v1.schedule_items.router import router as schedule_items_router
|
||||
from v1.users.router import router as users_router
|
||||
|
||||
@@ -16,12 +14,6 @@ router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(friendships_router)
|
||||
router.include_router(infra_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(schedule_items_router)
|
||||
router.include_router(inbox_messages_router)
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
    """Return a ``HealthResponse`` whose status is always ``"ok"``."""
    return HealthResponse(status="ok")
|
||||
|
||||
@@ -46,7 +46,7 @@ def test_mobile_health_e2e() -> None:
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
response = request_context.get("/api/v1/health")
|
||||
response = request_context.get("/health")
|
||||
assert response.status == 200
|
||||
body = response.json()
|
||||
assert body["status"] == "ok"
|
||||
|
||||
@@ -15,16 +15,6 @@ def test_app_health_returns_envelope() -> None:
|
||||
assert body["status"] == "ok"
|
||||
|
||||
|
||||
def test_mobile_router_health_returns_envelope() -> None:
    """GET /api/v1/health answers 200 with ``status == "ok"`` in the body."""
    reply = TestClient(app).get("/api/v1/health")

    assert reply.status_code == 200
    assert reply.json()["status"] == "ok"
|
||||
|
||||
|
||||
def test_not_found_returns_error_envelope() -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@@ -15,15 +15,16 @@ class _FakeResumeService:
|
||||
return {"session_id": session_id, "tool_call_id": tool_call_id}
|
||||
|
||||
|
||||
def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
events: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
|
||||
result = run_agent_task(
|
||||
result = await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
@@ -38,7 +39,8 @@ def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
|
||||
|
||||
|
||||
def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
class _BrokenRunService(_FakeRunService):
|
||||
@@ -48,12 +50,12 @@ def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
async def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
run_agent_task(
|
||||
await run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
|
||||
Reference in New Issue
Block a user