refactor: 重整 schemas 作用域并统一用户上下文模型
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from pydantic import ValidationError
|
||||
|
||||
MAX_RUN_INPUT_BYTES = 256_000
|
||||
MAX_RUN_ID_LENGTH = 128
|
||||
MAX_MESSAGES = 200
|
||||
MAX_TEXT_CHARS = 10_000
|
||||
|
||||
|
||||
def _safe_len(value: str | None) -> int:
|
||||
if value is None:
|
||||
return 0
|
||||
return len(value)
|
||||
|
||||
|
||||
def _user_text_chars(run_input: RunAgentInput) -> int:
|
||||
total = 0
|
||||
for message in run_input.messages:
|
||||
if getattr(message, "role", None) != "user":
|
||||
continue
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
total += len(content)
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if getattr(item, "type", None) != "text":
|
||||
continue
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str):
|
||||
total += len(text)
|
||||
return total
|
||||
|
||||
|
||||
def parse_run_input(payload: dict[str, Any]) -> RunAgentInput:
|
||||
payload_bytes = len(
|
||||
json.dumps(payload, ensure_ascii=True, separators=(",", ":")).encode("utf-8")
|
||||
)
|
||||
if payload_bytes > MAX_RUN_INPUT_BYTES:
|
||||
raise ValueError("RunAgentInput payload exceeds size limit")
|
||||
try:
|
||||
run_input = RunAgentInput.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
raise ValueError("invalid AG-UI RunAgentInput payload") from exc
|
||||
try:
|
||||
UUID(run_input.thread_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError("threadId must be a valid UUID") from exc
|
||||
if _safe_len(run_input.run_id) > MAX_RUN_ID_LENGTH:
|
||||
raise ValueError("runId exceeds length limit")
|
||||
if len(run_input.messages) > MAX_MESSAGES:
|
||||
raise ValueError("RunAgentInput.messages exceeds limit")
|
||||
if _user_text_chars(run_input) > MAX_TEXT_CHARS:
|
||||
raise ValueError("RunAgentInput user message text exceeds limit")
|
||||
return run_input
|
||||
|
||||
|
||||
def validate_run_request_messages_contract(run_input: RunAgentInput) -> None:
|
||||
if len(run_input.messages) != 1:
|
||||
raise ValueError("RunAgentInput.messages must contain exactly one user message")
|
||||
message = run_input.messages[0]
|
||||
if getattr(message, "role", None) != "user":
|
||||
raise ValueError("RunAgentInput.messages[0].role must be user")
|
||||
_validate_user_content_blocks(getattr(message, "content", None))
|
||||
extract_latest_user_payload(run_input)
|
||||
|
||||
|
||||
def extract_latest_user_text(run_input: RunAgentInput) -> str:
|
||||
text, _ = extract_latest_user_payload(run_input)
|
||||
return text
|
||||
|
||||
|
||||
def extract_latest_user_content(
|
||||
run_input: RunAgentInput,
|
||||
) -> list[dict[str, Any]]:
|
||||
_, content_blocks = extract_latest_user_payload(run_input)
|
||||
return content_blocks
|
||||
|
||||
|
||||
def extract_latest_user_payload(
|
||||
run_input: RunAgentInput,
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
for message in reversed(run_input.messages):
|
||||
role = getattr(message, "role", None)
|
||||
if role != "user":
|
||||
continue
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
text = content.strip()
|
||||
if text:
|
||||
return text, [{"type": "text", "text": text}]
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
item_type = getattr(item, "type", None)
|
||||
if item_type == "text":
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str) and text:
|
||||
text_parts.append(text)
|
||||
blocks.append({"type": "text", "text": text})
|
||||
continue
|
||||
if item_type != "binary":
|
||||
continue
|
||||
source_url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
if isinstance(source_url, str) and source_url:
|
||||
blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": source_url}}
|
||||
)
|
||||
combined = "".join(text_parts).strip()
|
||||
if combined or blocks:
|
||||
return combined, blocks
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
|
||||
|
||||
def _validate_user_content_blocks(content: Any) -> None:
|
||||
if isinstance(content, str):
|
||||
if content.strip():
|
||||
return
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
if not isinstance(content, list):
|
||||
raise ValueError("RunAgentInput.messages[0].content must be string or list")
|
||||
|
||||
has_text = False
|
||||
has_binary = False
|
||||
for item in content:
|
||||
item_type = getattr(item, "type", None)
|
||||
if item_type == "text":
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str) and text.strip():
|
||||
has_text = True
|
||||
continue
|
||||
if item_type == "binary":
|
||||
mime_type = (
|
||||
item.get("mimeType")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "mime_type", None)
|
||||
)
|
||||
url = (
|
||||
item.get("url")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "url", None)
|
||||
)
|
||||
data = (
|
||||
item.get("data")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "data", None)
|
||||
)
|
||||
if not isinstance(mime_type, str) or not mime_type.startswith("image/"):
|
||||
raise ValueError("binary content requires image mimeType")
|
||||
if not isinstance(url, str) or not url:
|
||||
raise ValueError("binary content requires url")
|
||||
if isinstance(data, str) and data:
|
||||
raise ValueError("binary content data is not allowed")
|
||||
has_binary = True
|
||||
continue
|
||||
raise ValueError("unsupported content block type")
|
||||
|
||||
if not has_text and not has_binary:
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires at least one non-empty user message"
|
||||
)
|
||||
|
||||
|
||||
def extract_latest_tool_result(
|
||||
run_input: RunAgentInput,
|
||||
) -> tuple[str, dict[str, object]]:
|
||||
for message in reversed(run_input.messages):
|
||||
role = getattr(message, "role", None)
|
||||
if role != "tool":
|
||||
continue
|
||||
tool_call_id = getattr(message, "tool_call_id", None)
|
||||
content = getattr(message, "content", None)
|
||||
if not isinstance(tool_call_id, str) or not tool_call_id:
|
||||
continue
|
||||
if not isinstance(content, str):
|
||||
break
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
except (TypeError, ValueError):
|
||||
return tool_call_id, {"content": content}
|
||||
if isinstance(parsed, dict):
|
||||
return tool_call_id, parsed
|
||||
return tool_call_id, {"content": content}
|
||||
raise ValueError(
|
||||
"RunAgentInput.messages requires a tool message with toolCallId for resume"
|
||||
)
|
||||
Reference in New Issue
Block a user