feat(agent): complete closed-loop runtime and pricing fallback
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
import redis.asyncio as redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.agent.infrastructure.queue.tasks import run_command_task
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
|
||||
class CeleryQueueClient:
|
||||
def __init__(self) -> None:
|
||||
settings = cast(Any, config)
|
||||
self._redis = redis.from_url(settings.redis.url, decode_responses=True)
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
redis_key = None
|
||||
if dedup_key:
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
locked = await self._redis.set(redis_key, "__inflight__", nx=True, ex=300)
|
||||
if not locked:
|
||||
existing = await self._redis.get(redis_key)
|
||||
if existing and existing != "__inflight__":
|
||||
return existing
|
||||
|
||||
payload = dict(command)
|
||||
if dedup_key:
|
||||
payload["dedup_key"] = dedup_key
|
||||
delay = getattr(run_command_task, "delay")
|
||||
result = delay(payload)
|
||||
task_id = str(result.id)
|
||||
if redis_key is not None:
|
||||
await self._redis.set(redis_key, task_id, ex=300)
|
||||
return task_id
|
||||
|
||||
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
settings = cast(Any, config)
|
||||
client = redis.from_url(settings.redis.url, decode_responses=True)
|
||||
self._store = RedisStreamEventStore(
|
||||
client=client,
|
||||
stream_prefix=settings.agent_runtime.redis_stream_prefix,
|
||||
read_count=settings.agent_runtime.redis_stream_read_count,
|
||||
block_ms=settings.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
rows = await self._store.read_events(
|
||||
session_id=UUID(session_id),
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
return [{**row, "cursor": last_event_id} for row in rows]
|
||||
|
||||
|
||||
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
|
||||
return AgentService(
|
||||
repository=AgentRepository(session),
|
||||
queue=CeleryQueueClient(),
|
||||
stream=RedisEventStream(),
|
||||
)
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
|
||||
class AgentRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
|
||||
stmt = select(AgentChatSession.user_id).where(
|
||||
AgentChatSession.id == session_uuid
|
||||
)
|
||||
owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return str(owner_id)
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid user_id") from exc
|
||||
|
||||
session = AgentChatSession(user_id=user_uuid)
|
||||
self._session.add(session)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(session)
|
||||
return str(session.id)
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self._session.rollback()
|
||||
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
|
||||
session = await self._session.get(AgentChatSession, session_uuid)
|
||||
if session is not None:
|
||||
await self._session.delete(session)
|
||||
await self._session.flush()
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from core.agent.infrastructure.agui.stream import to_sse_event
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import ResumeRequest, RunRequest, TaskAcceptedResponse
|
||||
from v1.agent.service import AgentService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
async def enqueue_run(
|
||||
request: RunRequest,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
task = await service.enqueue_run(
|
||||
session_id=request.session_id,
|
||||
prompt=request.prompt,
|
||||
current_user=current_user,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
task_id=task.task_id,
|
||||
session_id=task.session_id,
|
||||
created=task.created,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs/{session_id}/resume",
|
||||
response_model=TaskAcceptedResponse,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
async def enqueue_resume(
|
||||
session_id: str,
|
||||
request: ResumeRequest,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
task = await service.enqueue_resume(
|
||||
session_id=session_id,
|
||||
tool_call_id=request.tool_call_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
task_id=task.task_id,
|
||||
session_id=task.session_id,
|
||||
created=task.created,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{session_id}/events")
|
||||
async def stream_events(
|
||||
request: Request,
|
||||
session_id: str,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
|
||||
idle_limit: int = Query(default=300, ge=1, le=3600),
|
||||
) -> StreamingResponse:
|
||||
async def _event_iter() -> AsyncIterator[str]:
|
||||
cursor = last_event_id
|
||||
idle_polls = 0
|
||||
while not await request.is_disconnected() and idle_polls < idle_limit:
|
||||
rows = await service.stream_events(
|
||||
session_id=session_id,
|
||||
last_event_id=cursor,
|
||||
current_user=current_user,
|
||||
)
|
||||
if not rows:
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
|
||||
idle_polls = 0
|
||||
for row in rows:
|
||||
row_id = str(row.get("id", ""))
|
||||
event = row.get("event")
|
||||
if not row_id or not isinstance(event, dict):
|
||||
continue
|
||||
cursor = row_id
|
||||
yield to_sse_event(row_id, event)
|
||||
|
||||
return StreamingResponse(
|
||||
_event_iter(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
session_id: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
prompt: str = Field(min_length=1, max_length=5000)
|
||||
|
||||
|
||||
class ResumeRequest(BaseModel):
|
||||
tool_call_id: str = Field(min_length=1, max_length=200)
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
task_id: str
|
||||
session_id: str
|
||||
created: bool
|
||||
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
session_id: str
|
||||
created: bool
|
||||
|
||||
|
||||
class AgentRepositoryLike(Protocol):
|
||||
async def get_session_owner(self, *, session_id: str) -> str: ...
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str: ...
|
||||
|
||||
|
||||
class EventStreamLike(Protocol):
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
|
||||
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
|
||||
if owner_id != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
class AgentService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: AgentRepositoryLike,
|
||||
queue: QueueClientLike,
|
||||
stream: EventStreamLike,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._queue = queue
|
||||
self._stream = stream
|
||||
|
||||
async def enqueue_run(
|
||||
self,
|
||||
*,
|
||||
session_id: str | None,
|
||||
prompt: str,
|
||||
current_user: CurrentUser,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
target_session_id = session_id
|
||||
if target_session_id is None:
|
||||
target_session_id = await self._repository.create_session_for_user(
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
created = True
|
||||
else:
|
||||
owner = await self._repository.get_session_owner(
|
||||
session_id=target_session_id
|
||||
)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
if created:
|
||||
await self._repository.commit()
|
||||
|
||||
try:
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"session_id": target_session_id,
|
||||
"user_input": prompt,
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
raise
|
||||
return TaskAccepted(
|
||||
task_id=task_id, session_id=target_session_id, created=created
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=session_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
dedup_key = f"resume:{session_id}:{tool_call_id}"
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"session_id": session_id,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
|
||||
return TaskAccepted(task_id=task_id, session_id=session_id, created=False)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
owner = await self._repository.get_session_owner(session_id=session_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
return await self._stream.read(
|
||||
session_id=session_id,
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from fastapi import APIRouter
|
||||
|
||||
from core.http.models import HealthResponse
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.friendships.router import router as friendships_router
|
||||
from v1.inbox_messages.router import router as inbox_messages_router
|
||||
@@ -13,6 +14,7 @@ from v1.users.router import router as users_router
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(friendships_router)
|
||||
router.include_router(infra_router)
|
||||
router.include_router(users_router)
|
||||
|
||||
Reference in New Issue
Block a user