feat: 优化 Agent 运行时与聊天设置体验

This commit is contained in:
qzl
2026-03-16 18:32:09 +08:00
parent 3f79cf0df7
commit 5a34616287
41 changed files with 2603 additions and 1263 deletions
@@ -74,12 +74,34 @@ def _router_role_rules() -> list[str]:
def _worker_role_rules() -> list[str]:
return [
"Worker only: execute routed objective without changing router intent.",
"Treat router output as objective/constraints contract, not as a fully-materialized tool-args payload.",
"Infer deterministic required tool arguments from contract fields, tool schema, and runtime context.",
"Ask minimal clarification only when required arguments cannot be inferred safely.",
"Ground every claim in available evidence and tool results; never fabricate execution state.",
"Keep status/result_type/answer/key_points/suggested_actions/error internally consistent.",
"On partial/failed execution, return concise actionable error context.",
]
def build_worker_contract_prompt(*, router_output: RouterAgentOutput) -> str:
    """Render the router output as the worker's contract prompt.

    The router output is dumped in JSON mode with ``None`` fields omitted,
    serialized as compact JSON (non-ASCII characters kept verbatim), and
    placed under a fixed list of contract instructions.

    Args:
        router_output: Router result; only its ``model_dump`` method is used.

    Returns:
        The prompt text, one section per line.
    """
    router_fields = router_output.model_dump(mode="json", exclude_none=True)
    serialized_contract = json.dumps(
        router_fields, ensure_ascii=False, separators=(",", ":")
    )
    sections = (
        "[Worker Contract]",
        "- Keep routed objective unchanged.",
        "- Use normalized_task_input as objective text.",
        "- Use multimodal_summary/key_entities/constraints as execution evidence.",
        "- Infer deterministic missing required tool args from evidence + tool schema.",
        "- Ask clarification only when safe inference is impossible.",
        "[RouterAgentOutput]",
        serialized_contract,
    )
    return "\n".join(sections)
def build_agent_prompt(*, agent_type: AgentType) -> str:
lines = [
"[Agent Identity]",
@@ -1,13 +1,12 @@
from __future__ import annotations
import json
from typing import Any
from agentscope.agent import ReActAgent
from agentscope.message import Msg
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from core.agentscope.runtime.utils import extract_text_content, parse_json_dict
from core.agentscope.utils import finalize_json_response
class JsonReActAgent(ReActAgent):
@@ -47,77 +46,14 @@ class JsonReActAgent(ReActAgent):
*,
output_model: type[BaseModel],
) -> dict[str, Any]:
schema_json = json.dumps(
output_model.model_json_schema(),
ensure_ascii=True,
separators=(",", ":"),
)
last_error = ""
for attempt in range(1, self._finalize_retries + 2):
prompt = await self.formatter.format(
msgs=[
Msg("system", self.sys_prompt, "system"),
*await self.memory.get_memory(),
Msg(
"user",
self._build_finalize_instruction(
schema_json=schema_json,
validation_error=last_error,
attempt=attempt,
),
"user",
),
],
)
original_stream = self.model.stream
self.model.stream = False
try:
response = await self.model(
prompt,
tool_choice="none",
response_format={"type": "json_object"},
)
finally:
self.model.stream = original_stream
raw_text = extract_text_content(getattr(response, "content", []))
payload = parse_json_dict(raw_text)
if payload is None:
last_error = "Model output is not a valid JSON object."
continue
try:
validated = output_model.model_validate(payload)
return validated.model_dump(mode="json", exclude_none=True)
except ValidationError as exc:
last_error = str(exc)
raise RuntimeError(
f"failed to finalize structured output for {output_model.__name__}: {last_error}"
)
@staticmethod
def _build_finalize_instruction(
*,
schema_json: str,
validation_error: str,
attempt: int,
) -> str:
error_part = (
""
if not validation_error
else (
"\n\n[Validation Error From Previous Attempt]\n"
f"{validation_error}\n"
"Fix all missing/invalid fields and regenerate."
)
)
return (
"Return JSON only. Do not output markdown, prose, or code fences. "
"Follow this JSON Schema exactly and include all required fields. "
"Do not call tools.\n\n"
f"[Schema]\n{schema_json}\n\n"
f"[Attempt]\n{attempt}{error_part}"
_, payload = await finalize_json_response(
model=self.model,
formatter=self.formatter,
base_messages=[
Msg("system", self.sys_prompt, "system"),
*await self.memory.get_memory(),
],
output_model=output_model,
retries=self._finalize_retries,
)
return payload
@@ -0,0 +1,125 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import Any
from agentscope.model import OpenAIChatModel
from core.logging import get_logger
logger = get_logger("core.agentscope.runtime.runner")
class TrackingChatModel:
    """Proxy around a chat model that accumulates usage totals.

    Calls are forwarded to the wrapped model. Usage found on a response, or
    on the chunks of a streamed (async generator) response, is added to
    running totals readable through :meth:`usage_summary`. Attributes not
    defined on the proxy are looked up on the wrapped model.
    """

    def __init__(self, inner: OpenAIChatModel) -> None:
        self._inner = inner
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_latency_ms = 0
        self._cached_prompt_tokens = 0

    @property
    def stream(self) -> bool:
        """Streaming flag of the wrapped model (read and written through)."""
        return self._inner.stream

    @stream.setter
    def stream(self, value: bool) -> None:
        self._inner.stream = value

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs after normal lookup has failed. If ``_inner``
        # itself is missing (instance built via __new__, e.g. by copy/pickle),
        # delegating through ``self._inner`` would re-enter this method until
        # RecursionError; report the attribute as absent instead.
        if name == "_inner":
            raise AttributeError(name)
        return getattr(self._inner, name)

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Log the call, forward it, and track usage of the response.

        Returns the wrapped model's response, or a tracking async generator
        when the wrapped model returned an async generator.
        """
        self._log_model_call(kwargs)
        response = await self._inner(*args, **kwargs)
        if isinstance(response, AsyncGenerator):
            return self._track_stream(response)
        self._record_usage(getattr(response, "usage", None))
        return response

    def usage_summary(self) -> dict[str, int]:
        """Return the totals accumulated so far."""
        return {
            "input_tokens": self._total_input_tokens,
            "output_tokens": self._total_output_tokens,
            "latency_ms": self._total_latency_ms,
            "cached_prompt_tokens": self._cached_prompt_tokens,
        }

    def _log_model_call(self, kwargs: dict[str, Any]) -> None:
        """Log tool choice and tool names of one model call."""
        tool_names, generate_response_schema = self._extract_tool_debug_info(
            kwargs.get("tools")
        )
        logger.info(
            "model_call_debug",
            tool_choice=kwargs.get("tool_choice"),
            tool_count=len(tool_names),
            tool_names=tool_names,
            generate_response_schema=generate_response_schema,
        )

    @staticmethod
    def _extract_tool_debug_info(
        tools: Any,
    ) -> tuple[list[str], dict[str, Any] | None]:
        """Collect tool names and a summary of the ``generate_response`` schema.

        Entries that are not ``{"function": {"name": <str>, ...}}`` dicts are
        ignored. The schema summary is ``None`` unless a tool named
        ``generate_response`` with dict ``parameters`` is present.
        """
        tool_names: list[str] = []
        generate_response_schema: dict[str, Any] | None = None
        if not isinstance(tools, list):
            return tool_names, generate_response_schema

        for tool in tools:
            function = tool.get("function") if isinstance(tool, dict) else None
            name = function.get("name") if isinstance(function, dict) else None
            if not isinstance(name, str):
                continue
            tool_names.append(name)
            if name != "generate_response":
                continue
            parameters = function.get("parameters")
            if not isinstance(parameters, dict):
                continue
            props = parameters.get("properties", {})
            generate_response_schema = {
                "required": parameters.get("required"),
                "properties": list(props.keys()) if isinstance(props, dict) else [],
            }
        return tool_names, generate_response_schema

    async def _track_stream(
        self, response: AsyncGenerator[Any, None]
    ) -> AsyncGenerator[Any, None]:
        """Re-yield every chunk and record the last usage seen on a chunk."""
        latest_usage = None
        try:
            async for chunk in response:
                usage = getattr(chunk, "usage", None)
                if usage is not None:
                    latest_usage = usage
                yield chunk
        finally:
            # Record in ``finally`` so usage already reported by the stream is
            # not lost when the consumer closes the generator early or the
            # stream raises part-way through.
            self._record_usage(latest_usage)

    def _record_usage(self, usage: Any) -> None:
        """Add one usage object to the totals; negative values count as 0."""
        if usage is None:
            return
        self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
        self._total_output_tokens += max(
            int(getattr(usage, "output_tokens", 0) or 0), 0
        )
        # ``time`` is multiplied by 1000 to obtain milliseconds
        # (presumably seconds upstream -- confirm against the model library).
        self._total_latency_ms += max(
            int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
        )
        metadata = getattr(usage, "metadata", None)
        if metadata is None:
            return
        self._cached_prompt_tokens += max(self._extract_cached_tokens(metadata), 0)

    @staticmethod
    def _extract_cached_tokens(metadata: Any) -> int:
        """Read ``prompt_tokens_details.cached_tokens`` from a dict or object."""
        if isinstance(metadata, dict):
            prompt_details = metadata.get("prompt_tokens_details")
            if isinstance(prompt_details, dict):
                return int(prompt_details.get("cached_tokens", 0) or 0)
            return 0
        prompt_details = getattr(metadata, "prompt_tokens_details", None)
        return int(getattr(prompt_details, "cached_tokens", 0) or 0)
@@ -0,0 +1,96 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID, uuid4
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from schemas.agent.runtime_models import RouterAgentOutput
from schemas.agent.system_agent import AgentType
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
from sqlalchemy.ext.asyncio import AsyncSession
def _to_int(value: object) -> int:
if value is None:
return 0
if isinstance(value, bool):
return int(value)
if isinstance(value, int):
return value
if isinstance(value, Decimal):
return int(value)
if isinstance(value, float):
return int(value)
if isinstance(value, str):
text = value.strip()
if not text:
return 0
try:
return int(text)
except ValueError:
return int(float(text))
return 0
async def persist_router_message(
    *,
    session: AsyncSession,
    thread_id: str,
    run_id: str,
    model_code: str,
    router_output: RouterAgentOutput,
    response_metadata: dict[str, object],
) -> None:
    """Append the router's assistant message to a chat session.

    The session is fetched with ``lock_session_for_update``, the message is
    appended with ``seq = message_count + 1`` and empty content, the session
    runtime state is updated (status RUNNING, one more message, token and
    cost deltas), and the database session is flushed. No commit is issued
    here.

    Args:
        session: Database session used by both repositories.
        thread_id: Chat session id; parsed with ``UUID``.
        run_id: Run id stored in the message metadata.
        model_code: Model code stored on the message.
        router_output: Router result stored in the message metadata.
        response_metadata: Usage values read from the keys ``inputTokens``,
            ``outputTokens``, ``cost`` and ``latencyMs``.

    Raises:
        RuntimeError: If the chat session cannot be found.
        ValueError: If ``thread_id`` is not a valid UUID string.
    """
    chat_session_id = UUID(thread_id)
    messages = MessageRepository(session)
    sessions = SessionRepository(session)

    chat_session = await sessions.lock_session_for_update(session_id=chat_session_id)
    if chat_session is None:
        raise RuntimeError("chat session not found for router persistence")

    router_metadata = AgentChatMessageMetadata(
        run_id=run_id,
        agent_type=AgentType.ROUTER,
        router_agent_output=router_output,
    )
    # Built as a schema object first so the values passed to the repositories
    # are the ones held by the AgentChatMessage instance.
    record = AgentChatMessage(
        id=uuid4(),
        seq=_to_int(getattr(chat_session, "message_count", 0)) + 1,
        role=AgentChatMessageRole.ASSISTANT.value,
        content="",
        model_code=model_code,
        tool_name=None,
        input_tokens=_to_int(response_metadata.get("inputTokens", 0)),
        output_tokens=_to_int(response_metadata.get("outputTokens", 0)),
        cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
        latency_ms=_to_int(response_metadata.get("latencyMs", 0)),
        metadata=router_metadata,
        timestamp=datetime.now(timezone.utc),
    )

    await messages.append_message(
        session_id=chat_session_id,
        seq=record.seq,
        role=AgentChatMessageRole.ASSISTANT,
        content=record.content,
        model_code=record.model_code,
        tool_name=record.tool_name,
        metadata=router_metadata.model_dump(mode="json", exclude_none=True),
        input_tokens=record.input_tokens,
        output_tokens=record.output_tokens,
        cost=record.cost,
        latency_ms=record.latency_ms,
    )

    total_tokens = record.input_tokens + record.output_tokens
    await sessions.update_runtime_state(
        chat_session=chat_session,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot=chat_session.state_snapshot or {},
        message_delta=1,
        token_delta=total_tokens,
        cost_delta=record.cost,
    )
    await session.flush()
+139 -432
View File
@@ -1,30 +1,28 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4
from uuid import UUID
from ag_ui.core.types import RunAgentInput
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from core.agentscope.events.persistence import MessageRepository, SessionRepository
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.prompts.agent_prompt import build_worker_contract_prompt
from core.agentscope.prompts.system_prompt import build_system_prompt
from core.agentscope.tools.toolkit import build_stage_toolkit, build_toolkit
from core.agentscope.runtime.utils import (
normalize_tool_name,
parse_tool_agent_output,
from core.agentscope.runtime.json_react_agent import JsonReActAgent
from core.agentscope.runtime.model_tracking import TrackingChatModel
from core.agentscope.runtime.router_persistence import persist_router_message
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
from core.agentscope.tools.toolkit import build_stage_toolkit
from core.agentscope.utils import (
finalize_json_response,
patch_agentscope_json_repair_compat,
)
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessageRole
from models.agent_chat_session import AgentChatSessionStatus
from models.llm import Llm
from models.system_agents import SystemAgents
from schemas.agent.runtime_models import (
@@ -33,7 +31,6 @@ from schemas.agent.runtime_models import (
resolve_worker_output_model,
)
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
from schemas.messages.chat_message import AgentChatMessage, AgentChatMessageMetadata
from schemas.user import UserContext
from services.litellm.service import LiteLLMService
from sqlalchemy import select
@@ -59,246 +56,9 @@ class StageExecutionResult:
response_metadata: dict[str, Any]
class _TrackingChatModel:
def __init__(self, inner: OpenAIChatModel) -> None:
self._inner = inner
self._total_input_tokens = 0
self._total_output_tokens = 0
self._total_latency_ms = 0
self._cached_prompt_tokens = 0
@property
def stream(self) -> bool:
return self._inner.stream
@stream.setter
def stream(self, value: bool) -> None:
self._inner.stream = value
def __getattr__(self, name: str) -> Any:
return getattr(self._inner, name)
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
tools = kwargs.get("tools")
tool_names: list[str] = []
generate_response_schema: dict[str, Any] | None = None
if isinstance(tools, list):
for tool in tools:
if not isinstance(tool, dict):
continue
function = tool.get("function")
if isinstance(function, dict):
name = function.get("name")
if isinstance(name, str):
tool_names.append(name)
if name == "generate_response":
parameters = function.get("parameters")
if isinstance(parameters, dict):
generate_response_schema = {
"required": parameters.get("required"),
"properties": list(
(
parameters.get("properties", {})
if isinstance(
parameters.get("properties", {}), dict
)
else {}
).keys()
),
}
logger.info(
"model_call_debug",
tool_choice=kwargs.get("tool_choice"),
tool_count=len(tool_names),
tool_names=tool_names,
generate_response_schema=generate_response_schema,
)
response = await self._inner(*args, **kwargs)
if isinstance(response, AsyncGenerator):
return self._track_stream(response)
self._record_usage(getattr(response, "usage", None))
return response
async def _track_stream(
self, response: AsyncGenerator[Any, None]
) -> AsyncGenerator[Any, None]:
latest_usage = None
async for chunk in response:
usage = getattr(chunk, "usage", None)
if usage is not None:
latest_usage = usage
yield chunk
self._record_usage(latest_usage)
def _record_usage(self, usage: Any) -> None:
if usage is None:
return
self._total_input_tokens += max(int(getattr(usage, "input_tokens", 0) or 0), 0)
self._total_output_tokens += max(
int(getattr(usage, "output_tokens", 0) or 0), 0
)
self._total_latency_ms += max(
int(round(float(getattr(usage, "time", 0) or 0) * 1000)), 0
)
metadata = getattr(usage, "metadata", None)
if metadata is not None:
cached_tokens = 0
if isinstance(metadata, dict):
prompt_details = metadata.get("prompt_tokens_details")
if isinstance(prompt_details, dict):
cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
else:
prompt_details = getattr(metadata, "prompt_tokens_details", None)
cached_tokens = int(getattr(prompt_details, "cached_tokens", 0) or 0)
self._cached_prompt_tokens += max(cached_tokens, 0)
def usage_summary(self) -> dict[str, int]:
return {
"input_tokens": self._total_input_tokens,
"output_tokens": self._total_output_tokens,
"latency_ms": self._total_latency_ms,
"cached_prompt_tokens": self._cached_prompt_tokens,
}
class _PipelineStageEmitter:
def __init__(
self,
*,
pipeline: PipelineLike,
session_id: str,
run_id: str,
stage: str,
emit_text_events: bool,
emit_tool_events: bool,
) -> None:
self._pipeline = pipeline
self._session_id = session_id
self._run_id = run_id
self._stage = stage
self._emit_text_events = emit_text_events
self._emit_tool_events = emit_tool_events
self._text_by_message_id: dict[str, str] = {}
self._emitted_tool_calls: set[str] = set()
self._emitted_tool_results: set[str] = set()
self.latest_text_message_id: str | None = None
self.latest_text: str = ""
async def handle_print(self, *, msg: Msg, last: bool) -> None:
del last
if self._emit_tool_events:
await self._emit_tool_events_from_msg(msg)
if self._emit_text_events:
await self._emit_text_events_from_msg(msg)
async def _emit_text_events_from_msg(self, msg: Msg) -> None:
text = msg.get_text_content(separator="") or ""
if not text:
return
message_id = str(msg.id)
self._text_by_message_id[message_id] = text
self.latest_text_message_id = message_id
self.latest_text = text
async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
for block in msg.get_content_blocks("tool_use"):
tool_call_id = str(block.get("id", "")).strip()
tool_name = str(block.get("name", "")).strip()
if (
not tool_call_id
or not tool_name
or tool_call_id in self._emitted_tool_calls
):
continue
payload = {
"messageId": str(msg.id),
"toolCallId": tool_call_id,
"toolCallName": tool_name,
"stage": self._stage,
}
await self._emit("TOOL_CALL_START", payload)
await self._emit(
"TOOL_CALL_ARGS",
{
**payload,
"args": block.get("input", {}),
},
)
await self._emit("TOOL_CALL_END", payload)
self._emitted_tool_calls.add(tool_call_id)
for block in msg.get_content_blocks("tool_result"):
tool_call_id = str(block.get("id", "")).strip()
if not tool_call_id or tool_call_id in self._emitted_tool_results:
continue
tool_output = parse_tool_agent_output(block.get("output"))
if tool_output is None:
continue
tool_output_dict = tool_output.model_dump(mode="json", exclude_none=True)
result_data = {
"messageId": str(msg.id),
"role": "tool",
"stage": self._stage,
"tool_name": tool_output.tool_name,
"tool_call_id": tool_output.tool_call_id,
"tool_call_args": tool_output.tool_call_args,
"status": tool_output.status.value,
"result_summary": tool_output.result_summary,
}
ui_hints = tool_output_dict.get("ui_hints")
if ui_hints is not None:
result_data["ui_hints"] = ui_hints
if tool_output.error:
result_data["error"] = tool_output.error.model_dump(mode="json")
await self._emit("TOOL_CALL_RESULT", result_data)
self._emitted_tool_results.add(tool_call_id)
async def emit_final_text_end(
self,
*,
worker_output: dict[str, Any],
response_metadata: dict[str, Any],
) -> None:
message_id = (
self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
)
output_data = {
"messageId": message_id,
"role": "assistant",
"stage": self._stage,
"status": worker_output.get("status"),
"answer": worker_output.get("answer", ""),
"key_points": worker_output.get("key_points", []),
"result_type": worker_output.get("result_type"),
"suggested_actions": worker_output.get("suggested_actions", []),
"error": worker_output.get("error"),
}
ui_hints = worker_output.get("ui_hints")
if ui_hints is not None:
output_data["ui_hints"] = ui_hints
output_data.update(response_metadata)
await self._emit("TEXT_MESSAGE_END", output_data)
async def _emit(self, event_type: str, payload: dict[str, Any]) -> None:
await self._pipeline.emit(
session_id=self._session_id,
event={
"type": event_type,
"threadId": self._session_id,
"runId": self._run_id,
**payload,
},
)
class AgentScopeRunner:
def __init__(self, *, litellm_service: LiteLLMService | None = None) -> None:
    """Create a runner.

    Args:
        litellm_service: Service to use; when omitted (or falsy) a new
            ``LiteLLMService()`` is created.
    """
    # Runs on every construction, before the service is set up.
    # NOTE(review): presumably installs an agentscope JSON-repair
    # compatibility patch -- confirm in core.agentscope.utils.
    patch_agentscope_json_repair_compat()
    self._litellm_service = litellm_service or LiteLLMService()
async def execute(
@@ -310,76 +70,30 @@ class AgentScopeRunner:
run_input: RunAgentInput,
) -> dict[str, Any]:
owner_id = UUID(user_context.id)
enabled_tool_names = self._extract_tool_names(run_input)
async with AsyncSessionLocal() as session:
router_toolkit, worker_toolkit = self._build_toolkits(
session=session,
owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
worker_toolkit = self._build_worker_toolkit(
session=session, owner_id=owner_id
)
router_config, worker_config = await self._load_stage_configs(
session=session
)
router_config = await self._load_system_agent_config(
router_output = await self._execute_router_step(
session=session,
agent_type=AgentType.ROUTER,
)
worker_config = await self._load_system_agent_config(
session=session,
agent_type=AgentType.WORKER,
)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="STEP_STARTED",
)
router_result = await self._run_router_stage(
user_context=user_context,
context_messages=context_messages,
toolkit=router_toolkit,
run_input=run_input,
stage_config=router_config,
)
router_output = RouterAgentOutput.model_validate(router_result.payload)
await self._persist_router_message(
session=session,
thread_id=run_input.thread_id,
run_id=run_input.run_id,
model_code=router_config.model_code,
router_output=router_output,
response_metadata=router_result.response_metadata,
)
await session.commit()
await self._emit_step_event(
worker_output = await self._execute_worker_step(
pipeline=pipeline,
run_input=run_input,
step_name="router",
event_type="STEP_FINISHED",
)
worker_output_model = resolve_worker_output_model(router_output.ui.ui_mode)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="worker",
event_type="STEP_STARTED",
)
worker_result = await self._run_worker_stage(
user_context=user_context,
router_output=router_output,
toolkit=worker_toolkit,
run_input=run_input,
stage_config=worker_config,
worker_output_model=worker_output_model,
pipeline=pipeline,
)
worker_output = worker_output_model.model_validate(worker_result.payload)
await self._emit_step_event(
pipeline=pipeline,
run_input=run_input,
step_name="worker",
event_type="STEP_FINISHED",
)
return {
@@ -387,40 +101,107 @@ class AgentScopeRunner:
"worker": worker_output.model_dump(mode="json", exclude_none=True),
}
def _build_toolkits(
def _build_worker_toolkit(
self,
*,
session: AsyncSession,
owner_id: UUID,
enabled_tool_names: set[str] | None,
) -> tuple[Any, Any]:
return (
build_toolkit(
session=session,
owner_id=owner_id,
enabled_tool_names=set(),
),
build_stage_toolkit(
agent_type=AgentType.WORKER,
session=session,
owner_id=owner_id,
enabled_tool_names=enabled_tool_names,
),
) -> Any:
return build_stage_toolkit(
agent_type=AgentType.WORKER,
session=session,
owner_id=owner_id,
)
def _extract_tool_names(self, run_input: RunAgentInput) -> set[str] | None:
raw_tools = getattr(run_input, "tools", None)
if not isinstance(raw_tools, list):
return None
selected: set[str] = set()
for item in raw_tools:
if isinstance(item, dict):
name = item.get("name")
else:
name = getattr(item, "name", None)
if isinstance(name, str) and name.strip():
selected.add(normalize_tool_name(name))
return selected
async def _load_stage_configs(
    self,
    *,
    session: AsyncSession,
) -> tuple[SystemAgentRuntimeConfig, SystemAgentRuntimeConfig]:
    """Load the router and worker configurations, in that order.

    Returns:
        ``(router_config, worker_config)``.
    """
    loaded: list[SystemAgentRuntimeConfig] = []
    # Loaded one after the other on the same session, router first.
    for agent_type in (AgentType.ROUTER, AgentType.WORKER):
        loaded.append(
            await self._load_system_agent_config(
                session=session,
                agent_type=agent_type,
            )
        )
    router_config, worker_config = loaded
    return router_config, worker_config
async def _execute_router_step(
    self,
    *,
    session: AsyncSession,
    pipeline: PipelineLike,
    run_input: RunAgentInput,
    user_context: UserContext,
    context_messages: list[Msg],
    stage_config: SystemAgentRuntimeConfig,
) -> RouterAgentOutput:
    """Run the router stage, persist its message, and return its output.

    ``STEP_STARTED`` and ``STEP_FINISHED`` events for the ``router`` step are
    emitted around the work. The router message is persisted and the session
    committed before ``STEP_FINISHED`` is emitted.
    """
    step_event: dict[str, Any] = {
        "pipeline": pipeline,
        "run_input": run_input,
        "step_name": "router",
    }

    await self._emit_step_event(event_type="STEP_STARTED", **step_event)

    stage_result = await self._run_router_stage(
        user_context=user_context,
        context_messages=context_messages,
        run_input=run_input,
        stage_config=stage_config,
    )
    validated_output = RouterAgentOutput.model_validate(stage_result.payload)

    await persist_router_message(
        session=session,
        thread_id=run_input.thread_id,
        run_id=run_input.run_id,
        model_code=stage_config.model_code,
        router_output=validated_output,
        response_metadata=stage_result.response_metadata,
    )
    await session.commit()

    await self._emit_step_event(event_type="STEP_FINISHED", **step_event)
    return validated_output
async def _execute_worker_step(
    self,
    *,
    pipeline: PipelineLike,
    run_input: RunAgentInput,
    user_context: UserContext,
    router_output: RouterAgentOutput,
    toolkit: Any,
    stage_config: SystemAgentRuntimeConfig,
) -> WorkerAgentOutputLite:
    """Run the worker stage and return its validated output.

    The output model is resolved from ``router_output.ui.ui_mode`` before any
    event is emitted; ``STEP_STARTED`` and ``STEP_FINISHED`` events for the
    ``worker`` step are then emitted around the stage run and validation.
    """
    # Resolved first: a failure here happens before STEP_STARTED is emitted.
    output_model = resolve_worker_output_model(router_output.ui.ui_mode)
    step_event: dict[str, Any] = {
        "pipeline": pipeline,
        "run_input": run_input,
        "step_name": "worker",
    }

    await self._emit_step_event(event_type="STEP_STARTED", **step_event)

    stage_result = await self._run_worker_stage(
        user_context=user_context,
        router_output=router_output,
        toolkit=toolkit,
        run_input=run_input,
        stage_config=stage_config,
        worker_output_model=output_model,
        pipeline=pipeline,
    )
    validated_output = output_model.model_validate(stage_result.payload)

    await self._emit_step_event(event_type="STEP_FINISHED", **step_event)
    return validated_output
async def _load_system_agent_config(
self,
@@ -451,7 +232,6 @@ class AgentScopeRunner:
*,
user_context: UserContext,
context_messages: list[Msg],
toolkit: Any,
run_input: RunAgentInput,
stage_config: SystemAgentRuntimeConfig,
) -> StageExecutionResult:
@@ -462,28 +242,26 @@ class AgentScopeRunner:
now_utc=datetime.now(timezone.utc),
tools=None,
)
agent = self._build_agent(
agent_name="router",
system_prompt=system_prompt,
toolkit=toolkit,
response, payload = await finalize_json_response(
model=tracking_model,
)
response_msg = await agent.reply_json(
context_messages,
formatter=OpenAIChatFormatter(),
base_messages=[Msg("system", system_prompt, "system"), *context_messages],
output_model=RouterAgentOutput,
retries=0,
)
response_msg = Msg(
name="router",
role="assistant",
content=list(getattr(response, "content", [])),
metadata=payload,
)
logger.info(
"router_reply_received",
run_id=run_input.run_id,
thread_id=run_input.thread_id,
message_id=str(response_msg.id),
)
payload = RouterAgentOutput.model_validate(
response_msg.metadata or {}
).model_dump(
mode="json",
exclude_none=True,
)
return StageExecutionResult(
message=response_msg,
payload=payload,
@@ -504,11 +282,9 @@ class AgentScopeRunner:
worker_output_model: type[WorkerAgentOutputLite],
pipeline: PipelineLike,
) -> StageExecutionResult:
worker_input = self._build_worker_input_messages(
router_output=router_output,
)
worker_input = self._build_worker_input_messages(router_output=router_output)
tracking_model = self._build_model(stage_config=stage_config)
emitter = _PipelineStageEmitter(
emitter = PipelineStageEmitter(
pipeline=pipeline,
session_id=run_input.thread_id,
run_id=run_input.run_id,
@@ -522,15 +298,14 @@ class AgentScopeRunner:
agent_type=AgentType.WORKER,
user_context=user_context,
now_utc=datetime.now(timezone.utc),
tools=run_input.tools,
tools=None,
),
toolkit=toolkit,
model=tracking_model,
emitter=emitter,
)
response_msg = await agent.reply_json(
worker_input,
output_model=worker_output_model,
worker_input, output_model=worker_output_model
)
worker_payload = worker_output_model.model_validate(response_msg.metadata or {})
response_metadata = self._litellm_service.build_usage_metadata(
@@ -552,24 +327,17 @@ class AgentScopeRunner:
*,
router_output: RouterAgentOutput,
) -> list[Msg]:
routing_contract = json.dumps(
router_output.model_dump(mode="json", exclude_none=True),
ensure_ascii=False,
separators=(",", ":"),
)
routing_msg = Msg(
name="router",
role="user",
content=(
"Use the following routing contract as the execution source of truth. "
f"Do not change the routed objective:\n{routing_contract}"
),
)
return [routing_msg]
return [
Msg(
name="router",
role="user",
content=build_worker_contract_prompt(router_output=router_output),
)
]
def _build_model(
self, *, stage_config: SystemAgentRuntimeConfig
) -> _TrackingChatModel:
) -> TrackingChatModel:
generate_kwargs: dict[str, Any] = {
"temperature": stage_config.llm_config.temperature,
"max_tokens": stage_config.llm_config.max_tokens,
@@ -585,7 +353,7 @@ class AgentScopeRunner:
client_kwargs={"base_url": self._litellm_service.proxy_base_url},
generate_kwargs=generate_kwargs,
)
return _TrackingChatModel(model)
return TrackingChatModel(model)
def _build_agent(
self,
@@ -593,8 +361,8 @@ class AgentScopeRunner:
agent_name: str,
system_prompt: str,
toolkit: Any,
model: _TrackingChatModel,
emitter: _PipelineStageEmitter | None = None,
model: TrackingChatModel,
emitter: PipelineStageEmitter | None = None,
) -> JsonReActAgent:
return JsonReActAgent(
name=agent_name,
@@ -624,66 +392,5 @@ class AgentScopeRunner:
},
)
async def _persist_router_message(
    self,
    *,
    session: AsyncSession,
    thread_id: str,
    run_id: str,
    model_code: str,
    router_output: RouterAgentOutput,
    response_metadata: dict[str, Any],
) -> None:
    """Append the router's assistant message to a chat session.

    The session is fetched with ``lock_session_for_update``, the message is
    appended with ``seq = message_count + 1`` and empty content, the session
    runtime state is updated (status RUNNING, one more message, token and
    cost deltas), and the database session is flushed without committing.

    Raises:
        RuntimeError: If the chat session cannot be found.
    """

    def usage_value(key: str) -> int:
        # Missing, None and other falsy values count as 0.
        return int(response_metadata.get(key, 0) or 0)

    chat_session_id = UUID(thread_id)
    messages = MessageRepository(session)
    sessions = SessionRepository(session)

    chat_session = await sessions.lock_session_for_update(session_id=chat_session_id)
    if chat_session is None:
        raise RuntimeError("chat session not found for router persistence")

    router_metadata = AgentChatMessageMetadata(
        run_id=run_id,
        agent_type=AgentType.ROUTER,
        router_agent_output=router_output,
    )
    record = AgentChatMessage(
        id=uuid4(),
        seq=int(getattr(chat_session, "message_count", 0) or 0) + 1,
        role=AgentChatMessageRole.ASSISTANT.value,
        content="",
        model_code=model_code,
        tool_name=None,
        input_tokens=usage_value("inputTokens"),
        output_tokens=usage_value("outputTokens"),
        cost=Decimal(str(response_metadata.get("cost", 0) or 0)),
        latency_ms=usage_value("latencyMs"),
        metadata=router_metadata,
        timestamp=datetime.now(timezone.utc),
    )

    await messages.append_message(
        session_id=chat_session_id,
        seq=record.seq,
        role=AgentChatMessageRole.ASSISTANT,
        content=record.content,
        model_code=record.model_code,
        tool_name=record.tool_name,
        metadata=router_metadata.model_dump(mode="json", exclude_none=True),
        input_tokens=record.input_tokens,
        output_tokens=record.output_tokens,
        cost=record.cost,
        latency_ms=record.latency_ms,
    )

    total_tokens = record.input_tokens + record.output_tokens
    await sessions.update_runtime_state(
        chat_session=chat_session,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot=chat_session.state_snapshot or {},
        message_delta=1,
        token_delta=total_tokens,
        cost_delta=record.cost,
    )
    await session.flush()
AgentScopeReActRunner = AgentScopeRunner
@@ -0,0 +1,137 @@
from __future__ import annotations
from typing import Any, Protocol
from uuid import uuid4
from agentscope.message import Msg
from core.agentscope.utils import parse_tool_agent_output
class PipelineLike(Protocol):
    """Structural type of the event sink that stage events are emitted to."""

    # Called with keyword arguments only: the session id and the event dict.
    async def emit(self, *, session_id: str, event: dict[str, Any]) -> str: ...
class PipelineStageEmitter:
    """Forward tool activity of one stage to a pipeline as events.

    Tool-call and tool-result blocks found on printed messages are emitted
    once per tool call id. Text is not emitted incrementally; the latest text
    and its message id are remembered for the final ``TEXT_MESSAGE_END``.
    """

    def __init__(
        self,
        *,
        pipeline: PipelineLike,
        session_id: str,
        run_id: str,
        stage: str,
        emit_text_events: bool,
        emit_tool_events: bool,
    ) -> None:
        self._pipeline = pipeline
        self._session_id = session_id
        self._run_id = run_id
        self._stage = stage
        self._emit_text_events = emit_text_events
        self._emit_tool_events = emit_tool_events
        self._emitted_tool_calls: set[str] = set()
        self._emitted_tool_results: set[str] = set()
        self.latest_text_message_id: str | None = None
        self.latest_text: str = ""

    async def handle_print(self, *, msg: Msg, last: bool) -> None:
        """Process one printed message; ``last`` is accepted but unused."""
        del last
        if self._emit_tool_events:
            await self._emit_tool_events_from_msg(msg)
        if self._emit_text_events:
            await self._emit_text_events_from_msg(msg)

    async def emit_final_text_end(
        self,
        *,
        worker_output: dict[str, Any],
        response_metadata: dict[str, Any],
    ) -> None:
        """Emit ``TEXT_MESSAGE_END`` carrying the worker output.

        Keys of ``response_metadata`` override the fields derived from
        ``worker_output``, except ``ui_hints``, which is applied last when
        the worker output provides one.
        """
        message_id = (
            self.latest_text_message_id or f"worker-{self._run_id}-{uuid4().hex[:8]}"
        )
        payload = {
            "messageId": message_id,
            "role": "assistant",
            "stage": self._stage,
            "status": worker_output.get("status"),
            "answer": worker_output.get("answer", ""),
            "key_points": worker_output.get("key_points", []),
            "result_type": worker_output.get("result_type"),
            "suggested_actions": worker_output.get("suggested_actions", []),
            "error": worker_output.get("error"),
            **response_metadata,
        }
        ui_hints = worker_output.get("ui_hints")
        if ui_hints is not None:
            payload["ui_hints"] = ui_hints
        await self._emit("TEXT_MESSAGE_END", payload)

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Return ``value`` as a stripped string, mapping ``None`` to ``""``.

        ``str(None)`` would yield the non-empty text ``"None"`` and slip past
        the "missing id/name" checks.
        """
        return "" if value is None else str(value).strip()

    async def _emit_text_events_from_msg(self, msg: Msg) -> None:
        """Remember the message's text; nothing is sent to the pipeline."""
        text = msg.get_text_content(separator="") or ""
        if not text:
            return
        self.latest_text_message_id = str(msg.id)
        self.latest_text = text

    async def _emit_tool_events_from_msg(self, msg: Msg) -> None:
        """Emit start/args/end for new tool calls and results for new outputs.

        Blocks without an id (or, for calls, without a name) are skipped, as
        are ids that were already emitted.
        """
        for block in msg.get_content_blocks("tool_use"):
            tool_call_id = self._clean_text(block.get("id"))
            tool_name = self._clean_text(block.get("name"))
            if (
                not tool_call_id
                or not tool_name
                or tool_call_id in self._emitted_tool_calls
            ):
                continue
            base_payload = {
                "messageId": str(msg.id),
                "toolCallId": tool_call_id,
                "toolCallName": tool_name,
                "stage": self._stage,
            }
            await self._emit("TOOL_CALL_START", base_payload)
            await self._emit(
                "TOOL_CALL_ARGS", {**base_payload, "args": block.get("input", {})}
            )
            await self._emit("TOOL_CALL_END", base_payload)
            self._emitted_tool_calls.add(tool_call_id)

        for block in msg.get_content_blocks("tool_result"):
            tool_call_id = self._clean_text(block.get("id"))
            if not tool_call_id or tool_call_id in self._emitted_tool_results:
                continue
            tool_output = parse_tool_agent_output(block.get("output"))
            if tool_output is None:
                # Unparseable output: not emitted and not marked as emitted.
                continue
            payload = {
                "messageId": str(msg.id),
                "role": "tool",
                "stage": self._stage,
                "tool_name": tool_output.tool_name,
                "tool_call_id": tool_output.tool_call_id,
                "tool_call_args": tool_output.tool_call_args,
                "status": tool_output.status.value,
                "result_summary": tool_output.result_summary,
            }
            ui_hints = tool_output.model_dump(mode="json", exclude_none=True).get(
                "ui_hints"
            )
            if ui_hints is not None:
                payload["ui_hints"] = ui_hints
            if tool_output.error:
                payload["error"] = tool_output.error.model_dump(mode="json")
            await self._emit("TOOL_CALL_RESULT", payload)
            self._emitted_tool_results.add(tool_call_id)

    async def _emit(self, event_type: str, payload: dict[str, Any]) -> None:
        """Send one event, adding the type, thread id and run id."""
        await self._pipeline.emit(
            session_id=self._session_id,
            event={
                "type": event_type,
                "threadId": self._session_id,
                "runId": self._run_id,
                **payload,
            },
        )
+40 -33
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
import base64
from typing import Any
from typing import Any, cast
from uuid import UUID
from agentscope.message import Msg
@@ -21,10 +21,12 @@ from core.taskiq.app import bulk_broker, critical_broker, default_broker
from schemas.user import UserContext
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
from schemas.messages.chat_message import extract_user_message_attachments
from v1.agent.dependencies import get_agent_service
from v1.users.dependencies import get_user_service
logger = get_logger("core.agentscope.runtime.tasks")
_MAX_CONTEXT_ATTACHMENTS = 3
def _load_runtime() -> type[Any]:
@@ -63,38 +65,43 @@ async def _build_recent_context_messages(
metadata = msg.get("metadata")
if role == "user" and metadata:
attachments = metadata.get("user_message_attachments")
if attachments:
bucket = attachments.get("bucket")
path = attachments.get("path")
mime_type = attachments.get("mime_type")
if bucket and path:
try:
image_bytes = await supabase_service.download_bytes(
bucket=bucket,
path=path,
)
b64_data = base64.b64encode(image_bytes).decode("utf-8")
converted.append(
Msg(
name="user",
role="user",
content=[
{"type": "text", "text": content},
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type or "image/png",
"data": b64_data,
},
},
],
)
)
continue
except Exception:
pass
image_blocks: list[dict[str, Any]] = []
attachments = extract_user_message_attachments(metadata)[
:_MAX_CONTEXT_ATTACHMENTS
]
for attachment in attachments:
try:
image_bytes = await supabase_service.download_bytes(
bucket=attachment.bucket,
path=attachment.path,
)
except Exception:
continue
b64_data = base64.b64encode(image_bytes).decode("utf-8")
image_blocks.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": attachment.mime_type or "image/png",
"data": b64_data,
},
}
)
if image_blocks:
multimodal_content: list[dict[str, Any]] = []
if isinstance(content, str) and content:
multimodal_content.append({"type": "text", "text": content})
multimodal_content.extend(image_blocks)
converted.append(
Msg(
name="user",
role="user",
content=cast(Any, multimodal_content),
)
)
continue
if role == "tool":
role = "assistant"
@@ -0,0 +1,23 @@
from core.agentscope.utils.compat import (
patch_agentscope_json_repair_compat,
safe_json_loads_with_repair,
)
from core.agentscope.utils.json_finalize import (
build_json_finalize_instruction,
finalize_json_response,
)
from core.agentscope.utils.parsing import (
extract_text_content,
parse_json_dict,
parse_tool_agent_output,
)
# Public API of this package: the names re-exported above from the
# ``compat``, ``json_finalize`` and ``parsing`` submodules.
__all__ = [
    "build_json_finalize_instruction",
    "extract_text_content",
    "finalize_json_response",
    "parse_json_dict",
    "parse_tool_agent_output",
    "patch_agentscope_json_repair_compat",
    "safe_json_loads_with_repair",
]
@@ -0,0 +1,48 @@
from __future__ import annotations
import json
from typing import Any
from core.logging import get_logger
# Module logger for the AgentScope compatibility helpers.
logger = get_logger("core.agentscope.utils.compat")
# Guard flag: set to True once the AgentScope JSON-repair hook has been
# replaced, so the patch is applied at most once per process.
_AGENTSCOPE_JSON_REPAIR_PATCHED = False
def safe_json_loads_with_repair(json_str: str) -> dict[str, Any]:
    """Parse possibly malformed JSON into a dict, never raising on bad input.

    ``json_repair.repair_json`` is tried with ``stream_stable=True`` first and
    retried without that keyword when the call raises ``TypeError``. A repaired
    string is decoded with ``json.loads``. Anything that does not end up as a
    dict yields ``{}``; any failure is logged with a short preview of the input
    and also yields ``{}``.
    """
    try:
        from json_repair import repair_json

        repair: Any = repair_json
        try:
            result = repair(json_str, stream_stable=True)
        except TypeError:
            result = repair(json_str)

        if isinstance(result, str):
            result = json.loads(result)
        return result if isinstance(result, dict) else {}
    except Exception:  # noqa: BLE001
        if len(json_str) > 100:
            preview = json_str[:100] + "..."
        else:
            preview = json_str
        logger.warning("failed_to_parse_tool_arguments", preview=preview)
        return {}
def patch_agentscope_json_repair_compat() -> None:
    """Replace AgentScope's ``_json_loads_with_repair`` with the safe variant.

    The hook is swapped in both ``agentscope._utils._common`` and
    ``agentscope.model._openai_model``. The function does nothing when the
    patch was already applied or when either module cannot be imported.
    """
    global _AGENTSCOPE_JSON_REPAIR_PATCHED

    if not _AGENTSCOPE_JSON_REPAIR_PATCHED:
        try:
            from agentscope._utils import _common as agentscope_common
            from agentscope.model import _openai_model as agentscope_openai
        except Exception:  # noqa: BLE001
            return

        for target_module in (agentscope_common, agentscope_openai):
            target_module._json_loads_with_repair = safe_json_loads_with_repair

        _AGENTSCOPE_JSON_REPAIR_PATCHED = True
        logger.info("patched_agentscope_json_repair_compat")
@@ -0,0 +1,96 @@
from __future__ import annotations
import json
from collections.abc import Awaitable
from typing import Any, Protocol
from agentscope.message import Msg
from pydantic import BaseModel, ValidationError
from core.agentscope.utils.parsing import extract_text_content, parse_json_dict
class FormatterProtocol(Protocol):
    """Structural type for a formatter whose ``format`` result is awaited."""

    def format(self, *args: Any, **kwargs: Any) -> Awaitable[Any]:
        """Return an awaitable that resolves to the formatted prompt."""
        ...
def build_json_finalize_instruction(
    *,
    schema_json: str,
    attempt: int,
    validation_error: str = "",
) -> str:
    """Build the instruction that asks the model for schema-conforming JSON.

    The text consists of blank-line separated sections: the fixed rules, the
    schema, the attempt number and, only when ``validation_error`` is
    non-empty, the previous validation error with a request to fix it.
    """
    sections = [
        "Return JSON only. Do not output markdown, prose, or code fences. "
        "Follow this JSON Schema exactly and include all required fields. "
        "Do not call tools.",
        f"[Schema]\n{schema_json}",
        f"[Attempt]\n{attempt}",
    ]
    if validation_error:
        sections.append(
            "[Validation Error From Previous Attempt]\n"
            f"{validation_error}\n"
            "Fix all missing/invalid fields and regenerate."
        )
    return "\n\n".join(sections)
async def finalize_json_response(
    *,
    model: Any,
    formatter: FormatterProtocol,
    base_messages: list[Msg],
    output_model: type[BaseModel],
    retries: int,
) -> tuple[Any, dict[str, Any]]:
    """Ask the model for a JSON object matching ``output_model``.

    Up to ``retries + 1`` attempts are made. Each attempt appends a finalize
    instruction (schema, attempt number, previous error) to ``base_messages``,
    calls the model with streaming temporarily disabled and tools forbidden,
    then parses and validates the reply.

    Returns:
        The raw model response and the validated payload dumped in JSON mode
        with ``None`` values excluded.

    Raises:
        RuntimeError: when no attempt produced a valid payload; the message
            carries the last parsing/validation error.
    """
    schema_json = json.dumps(
        output_model.model_json_schema(),
        ensure_ascii=True,
        separators=(",", ":"),
    )
    total_attempts = retries + 1
    failure_reason = ""

    for attempt in range(1, total_attempts + 1):
        instruction = Msg(
            "user",
            build_json_finalize_instruction(
                schema_json=schema_json,
                attempt=attempt,
                validation_error=failure_reason,
            ),
            "user",
        )
        prompt = await formatter.format(msgs=[*base_messages, instruction])

        # Streaming is switched off for this call only and always restored.
        saved_stream = model.stream
        model.stream = False
        try:
            response = await model(
                prompt,
                tool_choice="none",
                response_format={"type": "json_object"},
            )
        finally:
            model.stream = saved_stream

        reply_text = extract_text_content(getattr(response, "content", []))
        candidate = parse_json_dict(reply_text)
        if candidate is None:
            failure_reason = "Model output is not a valid JSON object."
            continue

        try:
            checked = output_model.model_validate(candidate)
            return response, checked.model_dump(mode="json", exclude_none=True)
        except ValidationError as exc:
            failure_reason = str(exc)

    raise RuntimeError(
        f"failed to finalize structured output for {output_model.__name__}: {failure_reason}"
    )
@@ -4,23 +4,9 @@ import json
from collections.abc import Sequence
from typing import Any
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
from schemas.agent.runtime_models import ToolAgentOutput
def compile_ui_hints_safe(ui_hints: Any) -> dict[str, Any] | None:
    """Compile ``ui_hints`` best-effort.

    Returns ``None`` for falsy input and whenever ``compile_ui_hints`` raises;
    otherwise returns the compiled result.
    """
    compiled: dict[str, Any] | None = None
    if ui_hints:
        try:
            compiled = compile_ui_hints(ui_hints)
        except Exception:
            compiled = None
    return compiled
def normalize_tool_name(value: str) -> str:
    """Trim surrounding whitespace and turn ``.`` and ``-`` into ``_``."""
    separators_to_underscore = str.maketrans({".": "_", "-": "_"})
    return value.strip().translate(separators_to_underscore)
def parse_tool_agent_output(output: Any) -> ToolAgentOutput | None:
blocks = output if isinstance(output, Sequence) else []
for block in blocks:
+32 -3
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import ClassVar
from typing import Any, ClassVar
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
@@ -11,7 +11,7 @@ from schemas.agent.runtime_models import RouterAgentOutput, WorkerAgentOutputRic
from ..agent import AgentType, ToolAgentOutput
class UserMessageAttachments(BaseModel):
class UserMessageAttachment(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
bucket: str
@@ -23,7 +23,7 @@ class AgentChatMessageMetadata(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
run_id: str
agent_type: AgentType | None = None
user_message_attachments: UserMessageAttachments | None = None
user_message_attachments: list[UserMessageAttachment] | None = None
router_agent_output: RouterAgentOutput | None = None
tool_agent_output: ToolAgentOutput | None = None
worker_agent_output: WorkerAgentOutputRich | None = None
@@ -46,3 +46,32 @@ class AgentChatMessage(BaseModel):
latency_ms: int | None = Field(default=None, ge=0)
metadata: AgentChatMessageMetadata | dict[str, object] | None = None
timestamp: datetime
def extract_user_message_attachments(
    metadata: AgentChatMessageMetadata | dict[str, object] | None,
) -> list[UserMessageAttachment]:
    """Read ``user_message_attachments`` from metadata as a list of models.

    Accepts either the typed metadata model or a plain dict. A non-list value
    is treated as a one-element list. Items that fail validation are dropped.
    Returns an empty list when metadata or the attachments value is ``None``.
    """
    if metadata is None:
        return []

    raw_value: Any = (
        metadata.user_message_attachments
        if isinstance(metadata, AgentChatMessageMetadata)
        else metadata.get("user_message_attachments")
    )
    if raw_value is None:
        return []

    candidates: list[Any] = raw_value if isinstance(raw_value, list) else [raw_value]

    parsed: list[UserMessageAttachment] = []
    for candidate in candidates:
        try:
            attachment = UserMessageAttachment.model_validate(candidate)
        except Exception:
            # Invalid entries are skipped rather than failing the whole message.
            continue
        parsed.append(attachment)
    return parsed
+10 -3
View File
@@ -37,6 +37,13 @@ class AttachmentSignedUrlResponse(BaseModel):
url: str
class HistoryMessageAttachment(BaseModel):
    # One attachment entry of a history message: a URL plus its MIME type.
    # The config lets the model be populated by field name as well as by alias
    # and makes serialization use the alias (``mimeType``).
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)

    # MIME type of the attachment; aliased to ``mimeType``.
    mime_type: str = Field(alias="mimeType")
    # URL under which the attachment can be fetched.
    url: str
class HistoryMessage(BaseModel):
"""History message schema for /history endpoint response."""
@@ -46,9 +53,9 @@ class HistoryMessage(BaseModel):
seq: int = Field(description="Message sequence number")
role: str = Field(description="Message role: user | assistant | tool")
content: str = Field(description="Message text content")
url: str | None = Field(
default=None,
description="Temporary signed URL for user-attached images",
attachments: list[HistoryMessageAttachment] = Field(
default_factory=list,
description="Temporary signed URLs for user-attached images",
)
ui_schema: UiSchemaRenderer | None = Field(
default=None,
+38 -18
View File
@@ -19,7 +19,8 @@ from core.config.settings import config
from core.logging import get_logger
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
UserMessageAttachments,
UserMessageAttachment,
extract_user_message_attachments,
)
from v1.agent.schemas import HistorySnapshotResponse
@@ -27,6 +28,7 @@ logger = get_logger(__name__)
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
_MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
_MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
_MAX_ATTACHMENTS_PER_MESSAGE = 3
@dataclass(frozen=True)
@@ -230,7 +232,7 @@ class AgentService:
) -> tuple[str, AgentChatMessageMetadata | None]:
text, content_blocks = extract_latest_user_payload(run_input)
user_attachments: UserMessageAttachments | None = None
user_attachments: list[UserMessageAttachment] = []
for block in content_blocks:
if not isinstance(block, dict):
continue
@@ -257,12 +259,15 @@ class AgentService:
thread_id=run_input.thread_id,
current_user=current_user,
)
user_attachments = UserMessageAttachments(
bucket=bucket,
path=path,
mime_type=mime_type,
user_attachments.append(
UserMessageAttachment(
bucket=bucket,
path=path,
mime_type=mime_type,
)
)
break
if len(user_attachments) > _MAX_ATTACHMENTS_PER_MESSAGE:
raise HTTPException(status_code=422, detail="Too many attachments")
except HTTPException:
raise
except Exception as exc: # noqa: BLE001
@@ -270,7 +275,7 @@ class AgentService:
raise HTTPException(status_code=422, detail="Invalid signed image url")
metadata: AgentChatMessageMetadata | None = None
if user_attachments is not None:
if user_attachments:
metadata = AgentChatMessageMetadata(
run_id=run_input.run_id,
user_message_attachments=user_attachments,
@@ -438,23 +443,38 @@ class AgentService:
messages: list[HistoryMessage] = []
if day_payload:
raw_messages = day_payload.get("messages") or []
raw_messages_obj = day_payload.get("messages")
raw_messages = (
raw_messages_obj if isinstance(raw_messages_obj, list) else []
)
for msg_dict in raw_messages:
msg = AgentChatMessage.model_validate(msg_dict)
signed_url: str | None = None
if self._attachment_storage and msg.metadata:
att = msg.metadata.user_message_attachments
if att:
signed_urls: dict[str, str] = {}
attachments = extract_user_message_attachments(msg.metadata)
if self._attachment_storage and attachments:
expected_prefix = (
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
)
for attachment in attachments:
if not _is_safe_attachment_path(
attachment.path,
expected_prefix=expected_prefix,
):
continue
signed_url = await self._attachment_storage.create_signed_url(
bucket=att.bucket,
path=att.path,
bucket=attachment.bucket,
path=attachment.path,
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
)
key = f"{attachment.bucket}/{attachment.path}"
signed_urls[key] = signed_url
converted = convert_message_to_history(msg, None)
if signed_url:
converted["url"] = signed_url
def _get_signed_url(payload: dict[str, str]) -> str:
key = f"{payload['bucket']}/{payload['path']}"
return signed_urls[key]
converted = convert_message_to_history(msg, _get_signed_url)
messages.append(HistoryMessage.model_validate(converted))
return HistorySnapshotResponse(
+28 -23
View File
@@ -11,7 +11,7 @@ from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
from schemas.messages.chat_message import (
AgentChatMessage,
AgentChatMessageMetadata,
UserMessageAttachments,
extract_user_message_attachments,
)
@@ -23,7 +23,7 @@ def convert_message_to_history(
将 AgentChatMessage 转换为 HistoryMessage 格式
转换规则:
- role=user: 读取 metadata.user_message_attachments将 bucket 转临时访问 url
- role=user: 读取 metadata.user_message_attachments转换为 attachments[]
- role=tool: 读取 content 和 metadata.tool_agent_output.ui_hints,编译成 ui_schema
- role=assistant: 读取 metadata.worker_agent_output.ui_hints,编译成 ui_schema
"""
@@ -31,11 +31,11 @@ def convert_message_to_history(
content = message.content
metadata = message.metadata
url: str | None = None
attachments: list[dict[str, str]] = []
ui_schema: dict[str, Any] | None = None
if role == "user":
url = _convert_user_attachments(metadata, get_signed_url_fn)
attachments = _convert_user_attachments(metadata, get_signed_url_fn)
elif role == "tool":
ui_schema = _compile_tool_ui_hints(metadata)
@@ -51,8 +51,8 @@ def convert_message_to_history(
"timestamp": message.timestamp.isoformat(),
}
if url:
result["url"] = url
if attachments:
result["attachments"] = attachments
if ui_schema:
result["ui_schema"] = ui_schema
@@ -63,28 +63,33 @@ def convert_message_to_history(
def _convert_user_attachments(
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
get_signed_url_fn: Callable[[dict[str, str]], str] | None,
) -> str | None:
"""转换用户附件为临时访问 URL"""
if not metadata:
return None
) -> list[dict[str, str]]:
"""转换用户附件为临时访问 URL 列表"""
if not metadata or not get_signed_url_fn:
return []
if isinstance(metadata, AgentChatMessageMetadata):
attachments = metadata.user_message_attachments
resolved = extract_user_message_attachments(metadata)
elif isinstance(metadata, dict):
resolved = extract_user_message_attachments(metadata)
else:
attachments_data = metadata.get("user_message_attachments")
if not attachments_data:
return None
attachments = UserMessageAttachments.model_validate(attachments_data)
return []
if not attachments or not get_signed_url_fn:
return None
try:
return get_signed_url_fn(
{"bucket": attachments.bucket, "path": attachments.path}
signed_attachments: list[dict[str, str]] = []
for attachment in resolved:
try:
signed_url = get_signed_url_fn(
{"bucket": attachment.bucket, "path": attachment.path}
)
except Exception:
continue
signed_attachments.append(
{
"url": signed_url,
"mimeType": attachment.mime_type,
}
)
except Exception:
return None
return signed_attachments
def _compile_tool_ui_hints(
+21 -9
View File
@@ -25,13 +25,14 @@ class AppVersionInfo(BaseModel):
router = APIRouter(prefix="/app", tags=["app"])
def _parse_version(filename: str) -> tuple[str, int] | None:
def _parse_version(filename: str) -> tuple[str, int, tuple[int, ...]] | None:
pattern = r"app[-_]v?(\d+\.\d+\.\d+)\+(\d+)\.(?:apk|ipa)"
match = re.search(pattern, filename, re.IGNORECASE)
if match:
version = match.group(1)
build = int(match.group(2))
return (version, build)
version_tuple = tuple(int(x) for x in version.split("."))
return (version, build, version_tuple)
return None
@@ -44,21 +45,32 @@ def _get_latest_release(
if not base_path.exists():
return None
ext = "ipa" if platform == "ios" else "apk"
target_ext = "ipa" if platform == "ios" else "apk"
candidates = []
MIN_APK_SIZE = 1024 * 1024 # 1MB
MIN_IPA_SIZE = 1024 * 1024 # 1MB
for f in base_path.iterdir():
if f.is_file() and f.suffix.lstrip(".").lower() == ext.lower():
parsed = _parse_version(f.name)
if parsed:
version, build = parsed
candidates.append((version, build, f.name))
if not f.is_file():
continue
ext = f.suffix.lstrip(".").lower()
if ext != target_ext:
continue
# 简单校验文件大小,排除伪装文件
if f.stat().st_size < (MIN_APK_SIZE if ext == "apk" else MIN_IPA_SIZE):
continue
parsed = _parse_version(f.name)
if parsed:
version, build, version_tuple = parsed
candidates.append((version_tuple, build, f.name))
if not candidates:
return None
candidates.sort(key=lambda x: (x[0], x[1]), reverse=True)
return candidates[0][0], candidates[0][1], candidates[0][2]
result = candidates[0]
return result[2].replace("+", "."), result[1], result[2]
def _compare_versions(