refactor: 简化 AgentScope 运行时模块与事件处理
- 移除冗余的 user_token 参数传递 - 重构 tool.result 事件使用 ToolAgentOutput 模型 - 重构 text.end 事件使用 WorkerAgentOutput 模型 - 简化 store 模块的 tool result 处理逻辑 - 更新 router/service 适配新事件结构 - 清理废弃的测试文件与设计文档 - 新增 AgentRuns 多模态存储设计文档
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
import json
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
@@ -9,10 +8,9 @@ from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.config.settings import config
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from services.base.supabase import supabase_service
|
||||
from schemas.messages.chat_message import AgentChatMessage as AgentChatMessageSchema
|
||||
|
||||
|
||||
class ToolResultPayloadStorage(Protocol):
|
||||
@@ -210,132 +208,17 @@ class AgentRepository:
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
else str(message.role)
|
||||
)
|
||||
payload: dict[str, object] = {
|
||||
"id": str(message.id),
|
||||
"role": role,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
if role == AgentChatMessageRole.TOOL.value:
|
||||
metadata = message.metadata_json or {}
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
payload["toolCallId"] = tool_call_id
|
||||
|
||||
parsed_content: dict[str, object] | None = None
|
||||
try:
|
||||
decoded = json.loads(message.content)
|
||||
if isinstance(decoded, dict):
|
||||
parsed_content = decoded
|
||||
except (TypeError, ValueError):
|
||||
parsed_content = None
|
||||
|
||||
hydrated_content: dict[str, object] | None = None
|
||||
if self._tool_result_storage is not None:
|
||||
storage_bucket = metadata.get("storage_bucket")
|
||||
storage_path = metadata.get("storage_path")
|
||||
if isinstance(storage_bucket, str) and isinstance(storage_path, str):
|
||||
expected_bucket = config.storage.bucket
|
||||
message_session_id = getattr(message, "session_id", None)
|
||||
expected_prefix = (
|
||||
f"tool-results/{message_session_id}/"
|
||||
if message_session_id is not None
|
||||
else None
|
||||
)
|
||||
tool_call_id = metadata.get("tool_call_id")
|
||||
is_legacy_path = isinstance(
|
||||
tool_call_id, str
|
||||
) and storage_path.endswith(f"/{tool_call_id}.json")
|
||||
if (
|
||||
storage_bucket == expected_bucket
|
||||
and _is_safe_storage_path(storage_path)
|
||||
and (
|
||||
(
|
||||
expected_prefix is not None
|
||||
and storage_path.startswith(expected_prefix)
|
||||
)
|
||||
or (
|
||||
storage_path.startswith("tool-results/")
|
||||
and is_legacy_path
|
||||
)
|
||||
)
|
||||
):
|
||||
try:
|
||||
hydrated_content = (
|
||||
await self._tool_result_storage.read_json(
|
||||
bucket=storage_bucket,
|
||||
path=storage_path,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
hydrated_content = None
|
||||
|
||||
resolved_content = hydrated_content or parsed_content
|
||||
payload["content"] = message.content
|
||||
if resolved_content is not None:
|
||||
ui = resolved_content.get("ui")
|
||||
if not isinstance(ui, dict):
|
||||
ui = resolved_content.get("ui_schema")
|
||||
if isinstance(ui, dict):
|
||||
payload["ui"] = ui
|
||||
display_content = resolved_content.get("content")
|
||||
if not isinstance(display_content, str):
|
||||
nested_result = resolved_content.get("result")
|
||||
if isinstance(nested_result, dict):
|
||||
nested_content = nested_result.get("content")
|
||||
if isinstance(nested_content, str):
|
||||
display_content = nested_content
|
||||
if (
|
||||
isinstance(display_content, str)
|
||||
and display_content.strip()
|
||||
and (
|
||||
not payload["content"]
|
||||
or _looks_like_offloaded_placeholder(str(payload["content"]))
|
||||
)
|
||||
):
|
||||
payload["content"] = display_content
|
||||
else:
|
||||
payload["content"] = message.content
|
||||
|
||||
if role == AgentChatMessageRole.USER.value:
|
||||
metadata = message.metadata_json or {}
|
||||
user_attachments = metadata.get("user_message_attachments")
|
||||
if isinstance(user_attachments, dict):
|
||||
bucket = user_attachments.get("bucket")
|
||||
path = user_attachments.get("path")
|
||||
mime_type = user_attachments.get("mime_type")
|
||||
if (
|
||||
isinstance(bucket, str)
|
||||
and isinstance(path, str)
|
||||
and isinstance(mime_type, str)
|
||||
):
|
||||
try:
|
||||
signed_url = await supabase_service.create_signed_url(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
expires_in_seconds=3600,
|
||||
)
|
||||
attachment_block = {
|
||||
"type": "binary",
|
||||
"mimeType": mime_type,
|
||||
"url": signed_url,
|
||||
}
|
||||
existing_content = message.content
|
||||
if (
|
||||
isinstance(existing_content, str)
|
||||
and existing_content.strip()
|
||||
):
|
||||
content_blocks = [
|
||||
{"type": "text", "text": existing_content}
|
||||
]
|
||||
content_blocks.append(attachment_block)
|
||||
payload["content"] = content_blocks
|
||||
else:
|
||||
payload["content"] = [attachment_block]
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
return payload
|
||||
payload_model = AgentChatMessageSchema.model_validate(
|
||||
{
|
||||
"id": str(message.id),
|
||||
"seq": int(message.seq),
|
||||
"role": role,
|
||||
"content": message.content,
|
||||
"metadata": message.metadata_json,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
return payload_model.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
|
||||
def _has_title(title: object) -> bool:
|
||||
@@ -347,19 +230,3 @@ def _derive_session_title(content_text: str) -> str | None:
|
||||
if not normalized:
|
||||
return None
|
||||
return normalized[:80]
|
||||
|
||||
|
||||
def _is_safe_storage_path(path: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _looks_like_offloaded_placeholder(content: str) -> bool:
|
||||
normalized = content.strip().lower()
|
||||
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
|
||||
|
||||
@@ -11,9 +11,7 @@ from typing import Annotated, Union
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.agentscope.events import to_sse_event
|
||||
from core.auth.jwt_verifier import JwtVerifier, TokenValidationError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -38,6 +36,7 @@ from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
AsrTranscribeResponse,
|
||||
AttachmentReference,
|
||||
AttachmentSignedUrlResponse,
|
||||
AttachmentUploadResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
@@ -63,42 +62,6 @@ _ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def _verified_access_token_for_user(
|
||||
*,
|
||||
authorization: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> str:
|
||||
if not isinstance(authorization, str):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
normalized = authorization.strip()
|
||||
if not normalized:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
if not normalized.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
token = normalized[7:].strip()
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
jwt_secret = config.supabase.jwt_secret
|
||||
if jwt_secret is None:
|
||||
raise HTTPException(status_code=503, detail="Auth verifier unavailable")
|
||||
|
||||
verifier = JwtVerifier(
|
||||
issuer=str(config.supabase.jwt_issuer),
|
||||
jwt_secret=jwt_secret.get_secret_value(),
|
||||
jwt_algorithm=config.supabase.jwt_algorithm,
|
||||
)
|
||||
try:
|
||||
payload = verifier.verify(token)
|
||||
except TokenValidationError as exc:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized") from exc
|
||||
|
||||
subject = payload.get("sub")
|
||||
if not isinstance(subject, str) or subject != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
return token
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
@@ -164,7 +127,6 @@ async def enqueue_run(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
|
||||
@@ -174,15 +136,10 @@ async def enqueue_run(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
task = await service.enqueue_run(
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -202,7 +159,6 @@ async def enqueue_resume(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
authorization: str | None = Header(default=None, alias="Authorization"),
|
||||
) -> TaskAcceptedResponse:
|
||||
if request.thread_id != thread_id:
|
||||
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
|
||||
@@ -214,15 +170,10 @@ async def enqueue_resume(
|
||||
allowed = await _allow_run_request(user_id=str(current_user.id))
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="Too many run requests")
|
||||
user_token = _verified_access_token_for_user(
|
||||
authorization=authorization,
|
||||
current_user=current_user,
|
||||
)
|
||||
task = await service.enqueue_resume(
|
||||
thread_id=thread_id,
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
user_token=user_token,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
@@ -304,20 +255,6 @@ async def stream_events(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/history")
|
||||
async def get_history_snapshot(
|
||||
thread_id: str,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
before: date | None = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
return await service.get_history_snapshot(
|
||||
thread_id=thread_id,
|
||||
before=before,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
@@ -360,6 +297,25 @@ async def upload_attachment(
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/attachments/signed-url",
|
||||
response_model=AttachmentSignedUrlResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def create_attachment_signed_url(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
bucket: str = Query(min_length=1, max_length=100),
|
||||
path: str = Query(min_length=1, max_length=500),
|
||||
) -> AttachmentSignedUrlResponse:
|
||||
signed = await service.create_attachment_signed_url(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
current_user=current_user,
|
||||
)
|
||||
return AttachmentSignedUrlResponse(**signed)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcribe",
|
||||
response_model=AsrTranscribeResponse,
|
||||
|
||||
@@ -27,3 +27,9 @@ class AttachmentReference(BaseModel):
|
||||
|
||||
class AttachmentUploadResponse(BaseModel):
|
||||
attachment: AttachmentReference
|
||||
|
||||
|
||||
class AttachmentSignedUrlResponse(BaseModel):
|
||||
bucket: str
|
||||
path: str
|
||||
url: str
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from datetime import date
|
||||
import hashlib
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dashscope
|
||||
from ag_ui.core import RunAgentInput, StateSnapshotEvent
|
||||
@@ -23,19 +24,6 @@ _MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
|
||||
|
||||
def _normalize_bearer_token(value: str | None) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
lower = normalized.lower()
|
||||
if lower.startswith("bearer "):
|
||||
token = normalized[7:].strip()
|
||||
return token or None
|
||||
return normalized
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
@@ -70,14 +58,6 @@ class AgentRepositoryLike(Protocol):
|
||||
metadata: dict[str, object] | None,
|
||||
) -> None: ...
|
||||
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
@@ -148,7 +128,6 @@ class AgentService:
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
@@ -188,7 +167,6 @@ class AgentService:
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
@@ -226,19 +204,28 @@ class AgentService:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
if self._attachment_storage is None:
|
||||
continue
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Attachment storage unavailable",
|
||||
)
|
||||
|
||||
try:
|
||||
bucket, path = self._attachment_storage.parse_signed_url(url)
|
||||
bucket, path = self._validate_binary_signed_url(
|
||||
url=url,
|
||||
thread_id=run_input.thread_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
user_attachments = UserMessageAttachments(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
break
|
||||
except Exception: # noqa: BLE001
|
||||
logger.warning("Failed to parse signed URL", url=url)
|
||||
continue
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to parse signed URL", url=url, error=str(exc))
|
||||
raise HTTPException(status_code=422, detail="Invalid signed image url")
|
||||
|
||||
metadata: dict[str, object] | None = None
|
||||
if user_attachments is not None:
|
||||
@@ -329,13 +316,57 @@ class AgentService:
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def create_attachment_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, str]:
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Attachment storage unavailable"
|
||||
)
|
||||
normalized_bucket = bucket.strip()
|
||||
if normalized_bucket != config.storage.bucket:
|
||||
raise HTTPException(status_code=422, detail="Invalid attachment bucket")
|
||||
|
||||
normalized_path = path.strip()
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/"
|
||||
if not _is_safe_attachment_path(
|
||||
normalized_path, expected_prefix=expected_prefix
|
||||
):
|
||||
raise HTTPException(status_code=422, detail="Invalid attachment path scope")
|
||||
|
||||
try:
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=normalized_bucket,
|
||||
path=normalized_path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Attachment signed URL generation failed",
|
||||
extra={
|
||||
"bucket": normalized_bucket,
|
||||
"path": normalized_path,
|
||||
"user_id": str(current_user.id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(status_code=502, detail="Failed to generate signed URL")
|
||||
|
||||
return {
|
||||
"bucket": normalized_bucket,
|
||||
"path": normalized_path,
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
user_token: str | None = None,
|
||||
) -> TaskAccepted:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
@@ -345,7 +376,6 @@ class AgentService:
|
||||
command={
|
||||
"command": "resume",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _normalize_bearer_token(user_token),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
@@ -428,6 +458,37 @@ class AgentService:
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
def _validate_binary_signed_url(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
thread_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[str, str]:
|
||||
if self._attachment_storage is None:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Attachment storage unavailable"
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
expected_host = urlparse(config.supabase.url).netloc
|
||||
if parsed.netloc != expected_host:
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_HOST")
|
||||
|
||||
try:
|
||||
bucket, path = self._attachment_storage.parse_signed_url(url)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Invalid signed image url"
|
||||
) from exc
|
||||
|
||||
if bucket != config.storage.bucket:
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_BUCKET")
|
||||
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
if not _is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
raise HTTPException(status_code=422, detail="INVALID_BINARY_URL_PATH_SCOPE")
|
||||
return bucket, path
|
||||
|
||||
|
||||
class AsrService:
|
||||
def __init__(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user