1c02503d1d
- 移除冗余的 user_token 参数传递 - 重构 tool.result 事件使用 ToolAgentOutput 模型 - 重构 text.end 事件使用 WorkerAgentOutput 模型 - 简化 store 模块的 tool result 处理逻辑 - 更新 router/service 适配新事件结构 - 清理废弃的测试文件与设计文档 - 新增 AgentRuns 多模态存储设计文档
393 lines
13 KiB
Python
393 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import re
|
|
import tempfile
|
|
import time
|
|
from collections.abc import AsyncIterator
|
|
from datetime import date
|
|
from typing import Annotated, Union
|
|
|
|
from ag_ui.core import RunAgentInput
|
|
from core.agentscope.events import to_sse_event
|
|
from core.auth.models import CurrentUser
|
|
from core.logging import get_logger
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
File,
|
|
Form,
|
|
Header,
|
|
HTTPException,
|
|
Query,
|
|
Request,
|
|
UploadFile,
|
|
status,
|
|
)
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from core.agentscope.schemas.agui_input import (
|
|
extract_latest_tool_result,
|
|
parse_run_input,
|
|
validate_run_request_messages_contract,
|
|
)
|
|
from services.base.redis import get_or_init_redis_client
|
|
from v1.agent.dependencies import get_agent_service
|
|
from v1.agent.schemas import (
|
|
AsrTranscribeResponse,
|
|
AttachmentReference,
|
|
AttachmentSignedUrlResponse,
|
|
AttachmentUploadResponse,
|
|
TaskAcceptedResponse,
|
|
)
|
|
from v1.agent.service import AgentService, asr_service
|
|
from v1.users.dependencies import get_current_user
|
|
|
|
router = APIRouter(prefix="/agent", tags=["agent"])
logger = get_logger("v1.agent.router")

# --- SSE streaming -----------------------------------------------------------
# Accepted shape of the ``Last-Event-ID`` header: two decimal groups joined by
# a dash. NOTE(review): this looks like a Redis stream entry id ("<ms>-<seq>");
# confirm against the event store.
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
# Upper bound on the per-user counter maintained by ``_acquire_sse_slot``.
_MAX_SSE_CONNECTIONS_PER_USER = 3
# TTL applied to the per-user SSE counter key when it is first created.
_SSE_SLOT_TTL_SECONDS = 15 * 60

# --- Per-user, per-minute request quotas --------------------------------------
_RUNS_PER_MINUTE = 30
_TRANSCRIBES_PER_MINUTE = 20

# --- Upload size limits --------------------------------------------------------
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
# Size of each read from the uploaded audio while it is copied to disk.
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
# Slack added on top of the audio cap when checking the request's declared
# Content-Length, since that header covers the whole multipart body.
_MULTIPART_OVERHEAD_BYTES = 64 * 1024

# --- WAV validation ------------------------------------------------------------
# Bytes needed to see both signature tags checked by ``_looks_like_wav_header``
# ("RIFF" at offset 0 and "WAVE" at offset 8).
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {"audio/wav", "audio/x-wav", "audio/wave"}
|
|
|
|
|
|
def _looks_like_wav_header(header: bytes) -> bool:
    """Report whether ``header`` begins with a RIFF/WAVE signature.

    Args:
        header: The leading bytes of an uploaded file.

    Returns:
        ``True`` when at least ``_WAV_HEADER_MIN_BYTES`` bytes are present,
        bytes 0-3 are ``RIFF`` and bytes 8-11 are ``WAVE``; ``False``
        otherwise. Bytes 4-7 are not inspected.
    """
    if len(header) >= _WAV_HEADER_MIN_BYTES:
        container_tag = header[:4]
        format_tag = header[8:12]
        return container_tag == b"RIFF" and format_tag == b"WAVE"
    return False
|
|
|
|
|
|
# Lifetime of a per-minute counter key; slightly longer than the 60 s bucket so
# the key is still present for the whole minute it counts.
_RATE_BUCKET_TTL_SECONDS = 70


async def _allow_within_minute_quota(*, scope: str, user_id: str, limit: int) -> bool:
    """Count one request against a per-user, per-minute Redis counter.

    The counter key is ``agent:{scope}:{user_id}:{minute_bucket}`` where
    ``minute_bucket`` is the current Unix time divided by 60. The key gets a
    TTL when it is first created (counter value 1).

    Args:
        scope: Key segment identifying the quota (for example ``"run-rate"``).
        user_id: Identifier of the caller being limited.
        limit: Maximum number of requests allowed in one minute bucket.

    Returns:
        ``True`` while the counter is at or below ``limit``. ``False`` when the
        limit is exceeded or when Redis cannot be reached (fail closed); the
        latter case is logged so it can be told apart from genuine throttling.
    """
    try:
        redis = await get_or_init_redis_client()
        minute_bucket = int(time.time() // 60)
        key = f"agent:{scope}:{user_id}:{minute_bucket}"
        count = await redis.incr(key)
        if count == 1:
            await redis.expire(key, _RATE_BUCKET_TTL_SECONDS)
        return int(count) <= limit
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "Rate limit check failed",
            scope=scope,
            user_id=user_id,
            reason=str(exc),
        )
        return False


async def _allow_run_request(*, user_id: str) -> bool:
    """Return whether ``user_id`` may enqueue another run this minute.

    Allows up to ``_RUNS_PER_MINUTE`` requests per minute bucket and returns
    ``False`` if the Redis check itself fails.
    """
    return await _allow_within_minute_quota(
        scope="run-rate",
        user_id=user_id,
        limit=_RUNS_PER_MINUTE,
    )


async def _allow_transcribe_request(*, user_id: str) -> bool:
    """Return whether ``user_id`` may start another transcription this minute.

    Allows up to ``_TRANSCRIBES_PER_MINUTE`` requests per minute bucket and
    returns ``False`` if the Redis check itself fails.
    """
    return await _allow_within_minute_quota(
        scope="transcribe-rate",
        user_id=user_id,
        limit=_TRANSCRIBES_PER_MINUTE,
    )
|
|
|
|
|
|
async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one concurrent SSE connection slot for ``user_id``.

    The per-user counter ``agent:sse-active:{user_id}`` is incremented; the
    slot is granted when the new value does not exceed
    ``_MAX_SSE_CONNECTIONS_PER_USER``. The key receives a TTL of
    ``_SSE_SLOT_TTL_SECONDS`` only when the counter is created (value 1).

    Args:
        user_id: Identifier of the caller opening the stream.

    Returns:
        ``True`` if a slot was reserved; the caller must later call
        ``_release_sse_slot``. ``False`` if the user is at the limit or Redis
        failed (fail closed). In every ``False`` case the increment is undone
        on a best-effort basis so no slot is held by a caller that was refused.
    """
    key = f"agent:sse-active:{user_id}"
    try:
        redis = await get_or_init_redis_client()
        count = await redis.incr(key)
    except Exception as exc:  # noqa: BLE001
        # Nothing was incremented (or we cannot know), so there is nothing to undo.
        logger.warning("SSE slot acquire failed", user_id=user_id, reason=str(exc))
        return False

    # From here on the counter already includes this caller; any path that
    # ends in a refusal has to give the increment back.
    try:
        active = int(count)
        if active == 1:
            await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
        if active <= _MAX_SSE_CONNECTIONS_PER_USER:
            return True
    except Exception as exc:  # noqa: BLE001
        logger.warning("SSE slot acquire failed", user_id=user_id, reason=str(exc))

    try:
        await redis.decr(key)
    except Exception as exc:  # noqa: BLE001
        logger.warning("SSE slot rollback failed", user_id=user_id, reason=str(exc))
    return False
|
|
|
|
|
|
async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one SSE connection slot previously reserved for ``user_id``.

    Decrements ``agent:sse-active:{user_id}`` and deletes the key once the
    counter reaches zero or below, so a stale or negative counter does not
    linger. The operation is best-effort: Redis errors never propagate to the
    caller, but they are logged because an unreleased slot counts against the
    user's connection limit.

    Args:
        user_id: Identifier of the caller whose stream has ended.
    """
    key = f"agent:sse-active:{user_id}"
    try:
        redis = await get_or_init_redis_client()
        remaining = await redis.decr(key)
        if int(remaining) <= 0:
            await redis.delete(key)
    except Exception as exc:  # noqa: BLE001
        logger.warning("SSE slot release failed", user_id=user_id, reason=str(exc))
|
|
|
|
|
|
@router.post(
    "/runs",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_run(
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    # POST /agent/runs: validate the run input, apply the per-user run quota,
    # then hand the run to the service and report the accepted task.
    #
    # Responses: 202 with the task/thread/run identifiers; 422 when the input
    # fails parsing or the message contract; 429 when the quota check refuses.
    # (Kept as comments rather than a docstring so the generated OpenAPI
    # description of the endpoint is unchanged.)
    try:
        run_payload = request.model_dump(mode="json", by_alias=True)
        validate_run_request_messages_contract(parse_run_input(run_payload))
    except ValueError as invalid:
        raise HTTPException(status_code=422, detail=str(invalid)) from invalid

    # The quota is consumed only after validation succeeded.
    if not await _allow_run_request(user_id=str(current_user.id)):
        raise HTTPException(status_code=429, detail="Too many run requests")

    accepted = await service.enqueue_run(
        run_input=request,
        current_user=current_user,
    )
    return TaskAcceptedResponse(
        taskId=accepted.task_id,
        threadId=accepted.thread_id,
        runId=accepted.run_id,
        created=accepted.created,
    )
|
|
|
|
|
|
@router.post(
    "/runs/{thread_id}/resume",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
    thread_id: str,
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    # POST /agent/runs/{thread_id}/resume: resume an existing thread.
    #
    # Responses: 202 with the task/thread/run identifiers; 422 when the path
    # and body thread ids differ, or when parsing the input or extracting the
    # latest tool result raises ValueError; 429 when the run quota refuses.
    # (Kept as comments rather than a docstring so the generated OpenAPI
    # description of the endpoint is unchanged.)
    if thread_id != request.thread_id:
        raise HTTPException(status_code=422, detail="thread_id path/body mismatch")

    try:
        resume_payload = request.model_dump(mode="json", by_alias=True)
        # Called for validation only; the extracted value is not used here.
        extract_latest_tool_result(parse_run_input(resume_payload))
    except ValueError as invalid:
        raise HTTPException(status_code=422, detail=str(invalid)) from invalid

    # Resumes share the same per-minute quota as new runs.
    if not await _allow_run_request(user_id=str(current_user.id)):
        raise HTTPException(status_code=429, detail="Too many run requests")

    accepted = await service.enqueue_resume(
        thread_id=thread_id,
        run_input=request,
        current_user=current_user,
    )
    return TaskAcceptedResponse(
        taskId=accepted.task_id,
        threadId=accepted.thread_id,
        runId=accepted.run_id,
        created=accepted.created,
    )
|
|
|
|
|
|
@router.get("/runs/{thread_id}/events")
async def stream_events(
    request: Request,
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    # GET /agent/runs/{thread_id}/events: stream a thread's events as SSE.
    #
    # ``last_event_id`` (header ``Last-Event-ID``) is the resume cursor; it is
    # rejected with 422 when longer than 32 characters or not of the form
    # "<digits>-<digits>". ``idle_limit`` is the number of consecutive empty
    # polls tolerated before the stream ends (each empty poll is followed by a
    # 0.2 s pause), not a duration in seconds. A per-user connection slot is
    # reserved up front (429 if refused) and released when the generator ends.
    # (Kept as comments rather than a docstring so the generated OpenAPI
    # description of the endpoint is unchanged.)
    if last_event_id is not None:
        too_long = len(last_event_id) > 32
        if too_long or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None:
            raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")

    if not await _acquire_sse_slot(user_id=str(current_user.id)):
        raise HTTPException(status_code=429, detail="Too many SSE connections")

    async def _sse_frames() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        try:
            while True:
                if await request.is_disconnected() or idle_polls >= idle_limit:
                    break

                batch = None
                try:
                    batch = await service.stream_events(
                        thread_id=thread_id,
                        last_event_id=cursor,
                        current_user=current_user,
                    )
                except Exception as exc:  # noqa: BLE001
                    failure = str(exc)
                    logger.warning(
                        "SSE stream read failed",
                        thread_id=thread_id,
                        user_id=str(current_user.id),
                        reason=failure,
                    )
                    # Only read timeouts are treated as "no data yet"; any
                    # other failure ends the stream.
                    if "Timeout reading from" not in failure:
                        break

                if not batch:
                    # Idle poll (empty result or read timeout): emit an SSE
                    # comment as keep-alive and back off briefly.
                    idle_polls += 1
                    yield ": keep-alive\n\n"
                    await asyncio.sleep(0.2)
                    continue

                idle_polls = 0
                for row in batch:
                    row_id = str(row.get("id", ""))
                    event = row.get("event")
                    # Rows without an id or with a non-dict event are skipped
                    # and do not advance the cursor.
                    if row_id and isinstance(event, dict):
                        cursor = row_id
                        yield to_sse_event(row_id, event)
        finally:
            await _release_sse_slot(user_id=str(current_user.id))

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        _sse_frames(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
@router.get("/history")
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> dict[str, object]:
    # GET /agent/history: return the service's history snapshot for the
    # authenticated user. The optional ``threadId`` and ``before`` query
    # parameters are forwarded to the service unchanged.
    # (Kept as comments rather than a docstring so the generated OpenAPI
    # description of the endpoint is unchanged.)
    snapshot = await service.get_user_history_snapshot(
        current_user=current_user,
        thread_id=thread_id,
        before=before,
    )
    return snapshot
|
|
|
|
|
|
@router.post(
    "/attachments",
    response_model=AttachmentUploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_attachment(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str = Form(alias="threadId"),
    file: UploadFile = File(),
) -> AttachmentUploadResponse:
    # POST /agent/attachments: store one uploaded file for a thread.
    #
    # Responses: 200 with the stored attachment reference; 422 when the file
    # is empty; 413 when it exceeds ``_MAX_ATTACHMENT_UPLOAD_BYTES``.
    # (Kept as comments rather than a docstring so the generated OpenAPI
    # description of the endpoint is unchanged.)

    # Read at most one byte past the cap: that is enough to detect an
    # oversized upload without ever holding an arbitrarily large body in
    # memory.
    payload = await file.read(_MAX_ATTACHMENT_UPLOAD_BYTES + 1)
    if not payload:
        raise HTTPException(status_code=422, detail="Empty attachment")
    if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Attachment too large")

    attachment = await service.upload_attachment(
        thread_id=thread_id,
        filename=file.filename,
        content_type=file.content_type,
        payload=payload,
        current_user=current_user,
    )
    return AttachmentUploadResponse(
        attachment=AttachmentReference.model_validate(attachment),
    )
|
|
|
|
|
|
@router.get(
    "/attachments/signed-url",
    response_model=AttachmentSignedUrlResponse,
    status_code=status.HTTP_200_OK,
)
async def create_attachment_signed_url(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    bucket: str = Query(min_length=1, max_length=100),
    path: str = Query(min_length=1, max_length=500),
) -> AttachmentSignedUrlResponse:
    # GET /agent/attachments/signed-url: ask the service for a signed URL for
    # the object at ``bucket``/``path`` on behalf of the authenticated user.
    # The service's mapping result is expanded into the response model.
    # NOTE(review): access control for ``bucket``/``path`` is presumably
    # enforced inside the service; confirm.
    # (Kept as comments rather than a docstring so the generated OpenAPI
    # description of the endpoint is unchanged.)
    signed_fields = await service.create_attachment_signed_url(
        bucket=bucket,
        path=path,
        current_user=current_user,
    )
    return AttachmentSignedUrlResponse(**signed_fields)
|
|
|
|
|
|
@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    request: Request,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> Union[AsrTranscribeResponse, JSONResponse]:
    # POST /agent/transcribe: copy an uploaded WAV file to a temporary file
    # and pass its path to the ASR service.
    #
    # Responses: 200 with the transcript; 429 when the per-user transcribe
    # quota refuses; 400 for any ValueError (unsupported content type or
    # header, empty file, file over ``_MAX_TRANSCRIBE_AUDIO_BYTES``); 502 when
    # a RuntimeError is raised while handling the request.
    # (Kept as comments rather than a docstring so the generated OpenAPI
    # description of the endpoint is unchanged.)
    temp_path: str | None = None
    try:
        allowed = await _allow_transcribe_request(user_id=str(current_user.id))
        if not allowed:
            raise HTTPException(status_code=429, detail="Too many transcribe requests")

        if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
            raise ValueError("Unsupported audio format")

        # Cheap early rejection based on the declared request size. The header
        # covers the whole multipart body, hence the extra allowance; an
        # unparsable value is ignored and the streamed size check below still
        # applies.
        content_length = request.headers.get("content-length")
        if content_length is not None:
            try:
                declared_length = int(content_length)
            except ValueError:
                declared_length = None
            if (
                declared_length is not None
                and declared_length
                > _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
            ):
                raise ValueError("Audio file too large")

        total_bytes = 0
        header = bytearray()
        # delete=False: the file must outlive this ``with`` so the ASR service
        # can open it by path; it is removed in the ``finally`` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            temp_path = tmp_file.name
            while chunk := await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES):
                total_bytes += len(chunk)
                if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
                    raise ValueError("Audio file too large")
                if len(header) < _WAV_HEADER_MIN_BYTES:
                    # Collect just enough leading bytes for the signature check.
                    required = _WAV_HEADER_MIN_BYTES - len(header)
                    header.extend(chunk[:required])
                tmp_file.write(chunk)

        # The temporary file is closed (and so fully flushed) at this point.
        if total_bytes == 0:
            raise ValueError("Empty audio file")
        if not _looks_like_wav_header(bytes(header)):
            raise ValueError("Unsupported audio format")

        transcript = await asr_service.transcribe_file(
            temp_path, audio.filename or "unknown"
        )
        return AsrTranscribeResponse(transcript=transcript)

    except ValueError as exc:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(exc)},
        )
    except RuntimeError as exc:
        # Log the cause; the client only sees a generic message.
        logger.warning(
            "ASR transcription failed",
            user_id=str(current_user.id),
            reason=str(exc),
        )
        return JSONResponse(
            status_code=status.HTTP_502_BAD_GATEWAY,
            content={"detail": "ASR service unavailable"},
        )
    finally:
        # Nested so that the temporary file is removed even if closing the
        # upload raises.
        try:
            await audio.close()
        finally:
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except FileNotFoundError:
                    # Already gone; nothing left to clean up.
                    pass
|