feat: 添加 Agent 步骤事件与图片附件功能
- 新增 stepStarted/stepFinished 事件类型支持 - 前端实现图片附件上传和预览功能 - 后端增强工具结果存储和事件处理 - 完善相关单元测试和集成测试
This commit is contained in:
@@ -10,7 +10,17 @@ import time
|
||||
from typing import Annotated, Union
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFile
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
Query,
|
||||
Request,
|
||||
status,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
@@ -20,11 +30,18 @@ from core.agentscope.schemas.agui_input import (
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import AsrTranscribeResponse, TaskAcceptedResponse
|
||||
from v1.agent.schemas import (
|
||||
AsrTranscribeResponse,
|
||||
AttachmentReference,
|
||||
AttachmentUploadResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
from v1.agent.service import AgentService, asr_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
@@ -38,6 +55,7 @@ _SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
|
||||
_WAV_HEADER_MIN_BYTES = 12
|
||||
_ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
"audio/wav",
|
||||
@@ -46,6 +64,42 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def _verified_access_token_for_user(
|
||||
*,
|
||||
authorization: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> str | None:
|
||||
if not isinstance(authorization, str):
|
||||
return None
|
||||
normalized = authorization.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if not normalized.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
token = normalized[7:].strip()
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
jwt_secret = config.supabase.jwt_secret
|
||||
if jwt_secret is None:
|
||||
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
|
||||
|
||||
verifier = JwtVerifier(
|
||||
issuer=str(config.supabase.jwt_issuer),
|
||||
jwt_secret=jwt_secret.get_secret_value(),
|
||||
jwt_algorithm=config.supabase.jwt_algorithm,
|
||||
)
|
||||
try:
|
||||
payload = verifier.verify(token)
|
||||
except TokenValidationError as exc:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
subject = payload.get("sub")
|
||||
if not isinstance(subject, str) or subject != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return token
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
@@ -111,6 +165,7 @@ async def enqueue_run(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
@@ -120,10 +175,15 @@ async def enqueue_run(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
task = await service.enqueue_run(
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -143,6 +203,7 @@ async def enqueue_resume(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
if request.thread_id != thread_id:
|
||||
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
|
||||
@@ -154,10 +215,15 @@ async def enqueue_resume(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
task = await service.enqueue_resume(
|
||||
thread_id=thread_id,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -253,6 +319,31 @@ async def get_history_snapshot(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/attachments/{message_id}/{attachment_index}")
|
||||
async def get_attachment_preview(
|
||||
thread_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> StreamingResponse:
|
||||
if attachment_index < 0:
|
||||
raise HTTPException(status_code=422, detail="Invalid attachment index")
|
||||
payload, mime_type = await service.get_attachment_preview(
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
attachment_index=attachment_index,
|
||||
current_user=current_user,
|
||||
)
|
||||
return StreamingResponse(
|
||||
iter([payload]),
|
||||
media_type=mime_type,
|
||||
headers={
|
||||
"Cache-Control": "private, max-age=300",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
@@ -267,6 +358,34 @@ async def get_user_history_snapshot(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/attachments",
|
||||
response_model=AttachmentUploadResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def upload_attachment(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
thread_id: str = Form(alias="threadId"),
|
||||
file: UploadFile = File(),
|
||||
) -> AttachmentUploadResponse:
|
||||
payload = await file.read()
|
||||
if not payload:
|
||||
raise HTTPException(status_code=422, detail="Empty attachment")
|
||||
if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
|
||||
raise HTTPException(status_code=413, detail="Attachment too large")
|
||||
attachment = await service.upload_attachment(
|
||||
thread_id=thread_id,
|
||||
filename=file.filename,
|
||||
content_type=file.content_type,
|
||||
payload=payload,
|
||||
current_user=current_user,
|
||||
)
|
||||
return AttachmentUploadResponse(
|
||||
attachment=AttachmentReference.model_validate(attachment),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcribe",
|
||||
response_model=AsrTranscribeResponse,
|
||||
|
||||
Reference in New Issue
Block a user