feat(agent): complete closed-loop runtime and pricing fallback

This commit is contained in:
qzl
2026-03-05 15:34:37 +08:00
parent b02a322bf3
commit b486e78ff3
67 changed files with 3832 additions and 7 deletions
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+75
View File
@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Any, cast
from uuid import UUID
from fastapi import Depends
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.queue.tasks import run_command_task
from core.config.settings import config
from core.db import get_db
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
class CeleryQueueClient:
def __init__(self) -> None:
settings = cast(Any, config)
self._redis = redis.from_url(settings.redis.url, decode_responses=True)
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str:
redis_key = None
if dedup_key:
redis_key = f"agent:dedup:{dedup_key}"
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
if not locked:
existing = await self._redis.get(redis_key)
if existing and existing != "__inflight__":
return existing
payload = dict(command)
if dedup_key:
payload["dedup_key"] = dedup_key
delay = getattr(run_command_task, "delay")
result = delay(payload)
task_id = str(result.id)
if redis_key is not None:
await self._redis.set(redis_key, task_id, ex=300)
return task_id
class RedisEventStream:
def __init__(self) -> None:
settings = cast(Any, config)
client = redis.from_url(settings.redis.url, decode_responses=True)
self._store = RedisStreamEventStore(
client=client,
stream_prefix=settings.agent_runtime.redis_stream_prefix,
read_count=settings.agent_runtime.redis_stream_read_count,
block_ms=settings.agent_runtime.redis_stream_block_ms,
)
async def read(
self,
*,
session_id: str,
last_event_id: str | None,
) -> list[dict[str, Any]]:
rows = await self._store.read_events(
session_id=UUID(session_id),
last_event_id=last_event_id,
)
return [{**row, "cursor": last_event_id} for row in rows]
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
return AgentService(
repository=AgentRepository(session),
queue=CeleryQueueClient(),
stream=RedisEventStream(),
)
+56
View File
@@ -0,0 +1,56 @@
from __future__ import annotations
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_session import AgentChatSession
class AgentRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_session_owner(self, *, session_id: str) -> str:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
stmt = select(AgentChatSession.user_id).where(
AgentChatSession.id == session_uuid
)
owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
if owner_id is None:
raise HTTPException(status_code=404, detail="Session not found")
return str(owner_id)
async def create_session_for_user(self, *, user_id: str) -> str:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
session = AgentChatSession(user_id=user_uuid)
self._session.add(session)
await self._session.flush()
await self._session.refresh(session)
return str(session.id)
async def commit(self) -> None:
await self._session.commit()
async def rollback(self) -> None:
await self._session.rollback()
async def delete_session(self, *, session_id: str) -> None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
session = await self._session.get(AgentChatSession, session_uuid)
if session is not None:
await self._session.delete(session)
await self._session.flush()
+104
View File
@@ -0,0 +1,104 @@
from __future__ import annotations
from collections.abc import AsyncIterator
import asyncio
from typing import Annotated
from fastapi import APIRouter, Depends, Header, Query, Request, status
from fastapi.responses import StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
@router.post(
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
)
async def enqueue_run(
request: RunRequest,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
task = await service.enqueue_run(
session_id=request.session_id,
prompt=request.prompt,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
created=task.created,
)
@router.post(
"/runs/{session_id}/resume",
response_model=TaskAcceptedResponse,
status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
session_id: str,
request: ResumeRequest,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
task = await service.enqueue_resume(
session_id=session_id,
tool_call_id=request.tool_call_id,
current_user=current_user,
)
return TaskAcceptedResponse(
task_id=task.task_id,
session_id=task.session_id,
created=task.created,
)
@router.get("/runs/{session_id}/events")
async def stream_events(
request: Request,
session_id: str,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
async def _event_iter() -> AsyncIterator[str]:
cursor = last_event_id
idle_polls = 0
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(
session_id=session_id,
last_event_id=cursor,
current_user=current_user,
)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
idle_polls = 0
for row in rows:
row_id = str(row.get("id", ""))
event = row.get("event")
if not row_id or not isinstance(event, dict):
continue
cursor = row_id
yield to_sse_event(row_id, event)
return StreamingResponse(
_event_iter(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
+18
View File
@@ -0,0 +1,18 @@
from __future__ import annotations
from pydantic import BaseModel, Field
class RunRequest(BaseModel):
session_id: str | None = Field(default=None, min_length=1, max_length=100)
prompt: str = Field(min_length=1, max_length=5000)
class ResumeRequest(BaseModel):
tool_call_id: str = Field(min_length=1, max_length=200)
class TaskAcceptedResponse(BaseModel):
task_id: str
session_id: str
created: bool
+132
View File
@@ -0,0 +1,132 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from fastapi import HTTPException
from core.auth.models import CurrentUser
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
session_id: str
created: bool
class AgentRepositoryLike(Protocol):
async def get_session_owner(self, *, session_id: str) -> str: ...
async def create_session_for_user(self, *, user_id: str) -> str: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
class QueueClientLike(Protocol):
async def enqueue(
self, *, command: dict[str, object], dedup_key: str | None
) -> str: ...
class EventStreamLike(Protocol):
async def read(
self,
*,
session_id: str,
last_event_id: str | None,
) -> list[dict[str, object]]: ...
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
if owner_id != str(current_user.id):
raise HTTPException(status_code=403, detail="Forbidden")
class AgentService:
def __init__(
self,
*,
repository: AgentRepositoryLike,
queue: QueueClientLike,
stream: EventStreamLike,
) -> None:
self._repository = repository
self._queue = queue
self._stream = stream
async def enqueue_run(
self,
*,
session_id: str | None,
prompt: str,
current_user: CurrentUser,
) -> TaskAccepted:
created = False
target_session_id = session_id
if target_session_id is None:
target_session_id = await self._repository.create_session_for_user(
user_id=str(current_user.id)
)
created = True
else:
owner = await self._repository.get_session_owner(
session_id=target_session_id
)
ensure_session_owner(owner_id=owner, current_user=current_user)
if created:
await self._repository.commit()
try:
task_id = await self._queue.enqueue(
command={
"command": "run",
"session_id": target_session_id,
"user_input": prompt,
},
dedup_key=None,
)
except Exception: # noqa: BLE001
raise
return TaskAccepted(
task_id=task_id, session_id=target_session_id, created=created
)
async def enqueue_resume(
self,
*,
session_id: str,
tool_call_id: str,
current_user: CurrentUser,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=session_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
dedup_key = f"resume:{session_id}:{tool_call_id}"
task_id = await self._queue.enqueue(
command={
"command": "resume",
"session_id": session_id,
"tool_call_id": tool_call_id,
},
dedup_key=dedup_key,
)
return TaskAccepted(task_id=task_id, session_id=session_id, created=False)
async def stream_events(
self,
*,
session_id: str,
last_event_id: str | None,
current_user: CurrentUser,
) -> list[dict[str, object]]:
owner = await self._repository.get_session_owner(session_id=session_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
return await self._stream.read(
session_id=session_id,
last_event_id=last_event_id,
)
+2
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.agent.router import router as agent_router
from v1.auth.router import router as auth_router
from v1.friendships.router import router as friendships_router
from v1.inbox_messages.router import router as inbox_messages_router
@@ -13,6 +14,7 @@ from v1.users.router import router as users_router
router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(agent_router)
router.include_router(friendships_router)
router.include_router(infra_router)
router.include_router(users_router)