203 lines
7.0 KiB
Python
203 lines
7.0 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from ag_ui.core import RunAgentInput
|
|
from pydantic import ValidationError
|
|
|
|
# Largest accepted payload, measured in bytes of its compact, ASCII-escaped JSON form.
MAX_RUN_INPUT_BYTES = 256_000
|
|
# Longest accepted runId, in characters (a missing runId counts as length zero).
MAX_RUN_ID_LENGTH = 128
|
|
# Largest number of messages a single RunAgentInput may carry.
MAX_MESSAGES = 200
|
|
# Cap on the combined character count of user-role text across all messages.
MAX_TEXT_CHARS = 10_000
|
|
|
|
|
|
def _safe_len(value: str | None) -> int:
|
|
if value is None:
|
|
return 0
|
|
return len(value)
|
|
|
|
|
|
def _user_text_chars(run_input: RunAgentInput) -> int:
|
|
total = 0
|
|
for message in run_input.messages:
|
|
if getattr(message, "role", None) != "user":
|
|
continue
|
|
content = getattr(message, "content", None)
|
|
if isinstance(content, str):
|
|
total += len(content)
|
|
continue
|
|
if isinstance(content, list):
|
|
for item in content:
|
|
if getattr(item, "type", None) != "text":
|
|
continue
|
|
text = getattr(item, "text", None)
|
|
if isinstance(text, str):
|
|
total += len(text)
|
|
return total
|
|
|
|
|
|
def parse_run_input(payload: dict[str, Any]) -> RunAgentInput:
    """Validate a raw request payload and return it as a ``RunAgentInput``.

    Checks are applied in this order:

    1. The compact, ASCII-escaped JSON form of ``payload`` must not exceed
       ``MAX_RUN_INPUT_BYTES`` bytes.
    2. ``payload`` must validate against the ``RunAgentInput`` model.
    3. ``thread_id`` must parse as a UUID.
    4. ``run_id`` must not be longer than ``MAX_RUN_ID_LENGTH`` characters.
    5. There must be at most ``MAX_MESSAGES`` messages.
    6. The combined user-message text must not exceed ``MAX_TEXT_CHARS``.

    Args:
        payload: The decoded request body.

    Returns:
        The validated ``RunAgentInput`` instance.

    Raises:
        ValueError: If any of the checks above fails, or if ``payload`` cannot
            be serialized to JSON at all.
    """
    try:
        serialized = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
    except (TypeError, ValueError) as exc:
        # json.dumps raises TypeError for values it cannot encode (sets,
        # arbitrary objects, ...). Normalise that to ValueError so callers
        # only ever have to handle one exception type from this function.
        raise ValueError("RunAgentInput payload is not JSON-serializable") from exc

    if len(serialized.encode("utf-8")) > MAX_RUN_INPUT_BYTES:
        raise ValueError("RunAgentInput payload exceeds size limit")

    try:
        run_input = RunAgentInput.model_validate(payload)
    except ValidationError as exc:
        raise ValueError("invalid AG-UI RunAgentInput payload") from exc

    try:
        UUID(run_input.thread_id)
    except ValueError as exc:
        raise ValueError("threadId must be a valid UUID") from exc

    if _safe_len(run_input.run_id) > MAX_RUN_ID_LENGTH:
        raise ValueError("runId exceeds length limit")
    if len(run_input.messages) > MAX_MESSAGES:
        raise ValueError("RunAgentInput.messages exceeds limit")
    if _user_text_chars(run_input) > MAX_TEXT_CHARS:
        raise ValueError("RunAgentInput user message text exceeds limit")
    return run_input
|
|
|
|
|
|
def validate_run_request_messages_contract(run_input: RunAgentInput) -> None:
    """Enforce the run-request contract: exactly one usable user message.

    Raises:
        ValueError: If there is not exactly one message, if that message's
            role is not ``"user"``, if its content fails
            ``_validate_user_content_blocks``, or if
            ``extract_latest_user_payload`` finds nothing usable in it.
    """
    messages = run_input.messages
    if len(messages) != 1:
        raise ValueError("RunAgentInput.messages must contain exactly one user message")

    sole_message = messages[0]
    sole_role = getattr(sole_message, "role", None)
    if sole_role != "user":
        raise ValueError("RunAgentInput.messages[0].role must be user")

    _validate_user_content_blocks(getattr(sole_message, "content", None))
    # The result is discarded: the call is made only for the ValueError it
    # raises when the message yields no text and no blocks.
    extract_latest_user_payload(run_input)
|
|
|
|
|
|
def extract_latest_user_text(run_input: RunAgentInput) -> str:
    """Return only the text half of ``extract_latest_user_payload``.

    Raises:
        ValueError: Propagated from ``extract_latest_user_payload`` when no
            user message has usable content.
    """
    payload = extract_latest_user_payload(run_input)
    return payload[0]
|
|
|
|
|
|
def extract_latest_user_content(
    run_input: RunAgentInput,
) -> list[dict[str, Any]]:
    """Return only the content-block half of ``extract_latest_user_payload``.

    Raises:
        ValueError: Propagated from ``extract_latest_user_payload`` when no
            user message has usable content.
    """
    payload = extract_latest_user_payload(run_input)
    return payload[1]
|
|
|
|
|
|
def extract_latest_user_payload(
    run_input: RunAgentInput,
) -> tuple[str, list[dict[str, Any]]]:
    """Return the text and content blocks of the newest usable user message.

    Messages are scanned newest-first. Non-user messages are skipped, as are
    user messages that yield neither text nor blocks.

    * String content is stripped and returned with one matching text block.
    * List content is walked item by item. Every non-empty text item is
      appended to the text (joined without a separator, then stripped) and
      emitted as a ``{"type": "text"}`` block. Every binary item with a
      non-empty string ``url`` is emitted as a ``{"type": "image_url"}``
      block. Items of any other type are ignored.

    Content items may be attribute-style objects or plain dicts. The type of
    a dict item is read from its ``"type"`` key, so dict blocks are processed
    like their object counterparts instead of being silently dropped.

    Returns:
        A ``(text, blocks)`` tuple. ``text`` may be empty when the message
        carries only binary blocks.

    Raises:
        ValueError: If no user message yields any text or blocks.
    """

    def read_field(item: Any, name: str) -> Any:
        # getattr() on a dict never sees its keys, so dicts need a key lookup.
        if isinstance(item, dict):
            return item.get(name)
        return getattr(item, name, None)

    for message in reversed(run_input.messages):
        if getattr(message, "role", None) != "user":
            continue

        content = getattr(message, "content", None)
        if isinstance(content, str):
            text = content.strip()
            if text:
                return text, [{"type": "text", "text": text}]
            continue
        if not isinstance(content, list):
            continue

        text_parts: list[str] = []
        blocks: list[dict[str, Any]] = []
        for item in content:
            item_type = read_field(item, "type")
            if item_type == "text":
                text = read_field(item, "text")
                if isinstance(text, str) and text:
                    text_parts.append(text)
                    blocks.append({"type": "text", "text": text})
            elif item_type == "binary":
                source_url = read_field(item, "url")
                if isinstance(source_url, str) and source_url:
                    blocks.append(
                        {"type": "image_url", "image_url": {"url": source_url}}
                    )

        combined = "".join(text_parts).strip()
        if combined or blocks:
            return combined, blocks

    raise ValueError(
        "RunAgentInput.messages requires at least one non-empty user message"
    )
|
|
|
|
|
|
def _validate_user_content_blocks(content: Any) -> None:
|
|
if isinstance(content, str):
|
|
if content.strip():
|
|
return
|
|
raise ValueError(
|
|
"RunAgentInput.messages requires at least one non-empty user message"
|
|
)
|
|
if not isinstance(content, list):
|
|
raise ValueError("RunAgentInput.messages[0].content must be string or list")
|
|
|
|
has_text = False
|
|
has_binary = False
|
|
for item in content:
|
|
item_type = getattr(item, "type", None)
|
|
if item_type == "text":
|
|
text = getattr(item, "text", None)
|
|
if isinstance(text, str) and text.strip():
|
|
has_text = True
|
|
continue
|
|
if item_type == "binary":
|
|
mime_type = (
|
|
item.get("mimeType")
|
|
if isinstance(item, dict)
|
|
else getattr(item, "mime_type", None)
|
|
)
|
|
url = (
|
|
item.get("url")
|
|
if isinstance(item, dict)
|
|
else getattr(item, "url", None)
|
|
)
|
|
data = (
|
|
item.get("data")
|
|
if isinstance(item, dict)
|
|
else getattr(item, "data", None)
|
|
)
|
|
if not isinstance(mime_type, str) or not mime_type.startswith("image/"):
|
|
raise ValueError("binary content requires image mimeType")
|
|
if not isinstance(url, str) or not url:
|
|
raise ValueError("binary content requires url")
|
|
if isinstance(data, str) and data:
|
|
raise ValueError("binary content data is not allowed")
|
|
has_binary = True
|
|
continue
|
|
raise ValueError("unsupported content block type")
|
|
|
|
if not has_text and not has_binary:
|
|
raise ValueError(
|
|
"RunAgentInput.messages requires at least one non-empty user message"
|
|
)
|
|
|
|
|
|
def extract_latest_tool_result(
    run_input: RunAgentInput,
) -> tuple[str, dict[str, object]]:
    """Return the tool-call id and result of the newest usable tool message.

    Messages are scanned newest-first. Non-tool messages and tool messages
    without a non-empty string ``tool_call_id`` are skipped. The first tool
    message that has an id decides the outcome:

    * string content that decodes to a JSON object is returned as that dict;
    * any other string content is returned wrapped as ``{"content": ...}``;
    * non-string content stops the scan, and the function raises.

    Raises:
        ValueError: If no tool message with an id and string content is found.
    """
    for candidate in reversed(run_input.messages):
        if getattr(candidate, "role", None) != "tool":
            continue

        call_id = getattr(candidate, "tool_call_id", None)
        body = getattr(candidate, "content", None)
        if not (isinstance(call_id, str) and call_id):
            continue
        if not isinstance(body, str):
            # Older tool messages are deliberately not consulted.
            break

        try:
            decoded = json.loads(body)
        except (TypeError, ValueError):
            decoded = None
        # Only a JSON object is passed through; anything else is wrapped raw.
        result = decoded if isinstance(decoded, dict) else {"content": body}
        return call_id, result

    raise ValueError(
        "RunAgentInput.messages requires a tool message with toolCallId for resume"
    )
|