feat(agent): 增强多模态链路与工具调用能力

This commit is contained in:
zl-q
2026-03-12 00:18:45 +08:00
parent 18db6c50e7
commit 21ba8e4a44
35 changed files with 2057 additions and 829 deletions
@@ -13,8 +13,9 @@ from core.agentscope.schemas.user_context import (
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
from core.agentscope.schemas import ReportOutput, RuntimeOutput
from core.agentscope.schemas.agent_runtime import RunCommand
from core.agentscope.schemas.execution import ExecutionBatchOutput
from core.agentscope.schemas.intent import IntentOutput
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
from core.agentscope.schemas.execution import ExecutionToolCall
from core.agentscope.schemas.intent import IntentOutput, IntentTask
def _user_context() -> UserAgentContext:
@@ -50,20 +51,43 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="DIRECT_RESPONSE",
route="TASK_EXECUTION",
intent_summary="summary",
direct_response="done",
tasks=[],
complexity="simple",
direct_response=None,
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
complexity="complex",
response_metadata={"latencyMs": 120},
),
execution=ExecutionBatchOutput(
task_results=[],
task_results=[
ExecutionTaskOutput(
task_id="t1",
status="SUCCESS",
execution_summary="execution-ok",
execution_data={},
user_feedback_needs=[],
response_metadata={"latencyMs": 300},
tool_calls=[
ExecutionToolCall(
tool_name="calendar_write",
args={"title": "A"},
result={"event_id": "evt-1"},
)
],
)
],
overall_status="SUCCESS",
aggregate_summary="ok",
),
report=ReportOutput(
assistant_text="hello world",
response_metadata={},
response_metadata={
"model": "qwen3.5-flash",
"inputTokens": 10,
"outputTokens": 5,
"cost": 0.123,
"latencyMs": 250,
},
),
)
@@ -86,6 +110,13 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
"step.finish",
"step.start",
"step.finish",
"text.start",
"text.delta",
"text.end",
"text.start",
"text.delta",
"text.end",
"tool.result",
"step.start",
"text.start",
"text.delta",
@@ -97,11 +128,19 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
assert calls[2]["data"]["stepName"] == "intent"
assert calls[3]["data"]["stepName"] == "execution"
assert calls[4]["data"]["stepName"] == "execution"
assert calls[5]["data"]["stepName"] == "report"
assert calls[7]["data"]["delta"] == "hello world"
assert calls[6]["data"]["messageId"] == calls[7]["data"]["messageId"]
assert calls[7]["data"]["messageId"] == calls[8]["data"]["messageId"]
assert calls[9]["data"]["stepName"] == "report"
assert calls[5]["data"]["stage"] == "intent"
assert calls[8]["data"]["stage"] == "execution"
assert calls[11]["data"]["toolName"] == "calendar_write"
assert calls[12]["data"]["stepName"] == "report"
assert calls[14]["data"]["delta"] == "hello world"
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
assert calls[14]["data"]["messageId"] == calls[15]["data"]["messageId"]
assert calls[15]["data"]["model"] == "qwen3.5-flash"
assert calls[15]["data"]["inputTokens"] == 10
assert calls[15]["data"]["outputTokens"] == 5
assert calls[15]["data"]["cost"] == 0.123
assert calls[15]["data"]["latencyMs"] == 250
assert calls[16]["data"]["stepName"] == "report"
@pytest.mark.asyncio
@@ -140,3 +179,129 @@ async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
]
assert calls[1]["data"]["stepName"] == "intent"
assert calls[2]["data"]["message"] == "runtime execution failed"
@pytest.mark.asyncio
async def test_runtime_passes_binary_payload_to_orchestrator() -> None:
captured_user_input: object | None = None
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
return str(event.get("type", ""))
class _CaptureOrchestrator:
async def run(self, **kwargs: object) -> RuntimeOutput:
nonlocal captured_user_input
captured_user_input = kwargs.get("user_input")
return RuntimeOutput(
intent=IntentOutput(
route="DIRECT_RESPONSE",
intent_summary="summary",
direct_response="done",
tasks=[],
complexity="simple",
),
execution=None,
report=ReportOutput(
assistant_text="ok",
response_metadata={},
),
)
runtime = AgentRouteRuntime(
orchestrator=_CaptureOrchestrator(),
pipeline=_FakePipeline(),
)
command = RunCommand.model_validate(
{
"threadId": "thread-1",
"runId": "run-1",
"messages": [
{
"id": "u1",
"role": "user",
"content": [
{"type": "text", "text": "hello"},
{
"type": "binary",
"mimeType": "image/png",
"data": "aGVsbG8=",
},
],
}
],
}
)
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
assert isinstance(captured_user_input, list)
first = captured_user_input[0]
assert isinstance(first, dict)
content = first.get("content")
assert isinstance(content, list)
binary = content[1]
assert isinstance(binary, dict)
assert binary.get("data") == "aGVsbG8="
@pytest.mark.asyncio
async def test_runtime_direct_response_finishes_without_report_stage() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _DirectOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="DIRECT_RESPONSE",
intent_summary="summary",
direct_response="direct-answer",
tasks=[],
complexity="simple",
response_metadata={"latencyMs": 88},
),
execution=None,
report=ReportOutput(
assistant_text="direct-answer",
response_metadata={"latencyMs": 88},
),
)
runtime = AgentRouteRuntime(
orchestrator=_DirectOrchestrator(),
pipeline=_FakePipeline(),
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
assert [item["type"] for item in calls] == [
"run.started",
"step.start",
"step.finish",
"text.start",
"text.delta",
"text.end",
"run.finished",
]
assert calls[3]["data"]["stage"] == "intent"
assert calls[4]["data"]["delta"] == "direct-answer"
@@ -68,6 +68,7 @@ class _FakeRunner:
"direct_response": "你好",
"tasks": [],
"complexity": "simple",
"response_metadata": {"model": "qwen3.5-flash", "latencyMs": 100},
}
self.report_calls += 1
return {
@@ -131,7 +132,7 @@ async def test_runtime_direct_response_skips_execution(
{
"type": "function",
"function": {
"name": "calendar.read",
"name": "calendar_read",
"description": "read",
"parameters": {"type": "object", "properties": {}},
},
@@ -162,8 +163,10 @@ async def test_runtime_direct_response_skips_execution(
assert result.intent.route == "DIRECT_RESPONSE"
assert result.execution is None
assert result.report.assistant_text == "已完成"
assert result.report.assistant_text == "你好"
assert result.report.response_metadata["model"] == "qwen3.5-flash"
assert fake_runner.execution_calls == 0
assert fake_runner.report_calls == 0
@pytest.mark.asyncio
@@ -183,7 +186,7 @@ async def test_runtime_complex_route_runs_execution(
{
"type": "function",
"function": {
"name": "calendar.read",
"name": "calendar_read",
"description": "read",
"parameters": {"type": "object", "properties": {}},
},
@@ -191,7 +194,7 @@ async def test_runtime_complex_route_runs_execution(
{
"type": "function",
"function": {
"name": "calendar.write",
"name": "calendar_write",
"description": "write",
"parameters": {"type": "object", "properties": {}},
},
@@ -9,6 +9,8 @@ from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
from core.agentscope.runtime.config_loader import RuntimeStageConfig
from core.agentscope.runtime.react_runner import (
AgentScopeReActRunner,
_chat_response_text,
_merge_stage_response_metadata,
_parse_json_text,
_to_litellm_model,
)
@@ -32,10 +34,10 @@ def test_to_litellm_model_keeps_prefixed_model() -> None:
)
def test_to_litellm_model_builds_prefixed_model() -> None:
def test_to_litellm_model_uses_plain_model_name_when_unprefixed() -> None:
assert (
_to_litellm_model(provider_name="dashscope", model_code="qwen3.5-flash")
== "dashscope/qwen3.5-flash"
== "qwen3.5-flash"
)
@@ -49,6 +51,24 @@ def test_parse_json_text_rejects_non_json() -> None:
_parse_json_text("not-json")
def test_chat_response_text_falls_back_to_choice_message_content() -> None:
response = SimpleNamespace(
content=None,
choices=[
{
"message": {
"content": '{"assistant_text":"fallback","response_metadata":{}}'
}
}
],
)
assert (
_chat_response_text(response)
== '{"assistant_text":"fallback","response_metadata":{}}'
)
@pytest.mark.asyncio
async def test_run_json_stage_wraps_json_decode_error(
monkeypatch: pytest.MonkeyPatch,
@@ -113,3 +133,88 @@ async def test_run_json_stage_wraps_runtime_error(
user_prompt="user",
toolkit=None,
)
@pytest.mark.asyncio
async def test_run_json_stage_report_merges_usage_metadata(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _FakeLiteLLMService:
def run_completion_with_cost(self, **kwargs: object) -> object:
del kwargs
return SimpleNamespace(
response={
"model": "dashscope/qwen3.5-flash",
"choices": [
{
"message": {
"content": '{"assistant_text":"ok","response_metadata":{}}'
}
}
],
},
usage=SimpleNamespace(
prompt_tokens=9,
completion_tokens=4,
cost=0.006,
),
)
runner = AgentScopeReActRunner()
monkeypatch.setattr(
runner,
"_build_litellm_service",
lambda: _FakeLiteLLMService(),
)
report_stage = RuntimeStageConfig(
stage="report",
model_code="qwen3.5-flash",
provider_name="dashscope",
llm_config=SystemAgentLLMConfig(
temperature=0.1,
max_tokens=128,
timeout_seconds=30,
),
)
payload = await runner.run_json_stage(
stage_config=report_stage,
agent_name="report-agent",
system_prompt="sys",
user_prompt="user",
toolkit=None,
)
metadata = payload["response_metadata"]
assert metadata["model"] == "dashscope/qwen3.5-flash"
assert metadata["inputTokens"] == 9
assert metadata["outputTokens"] == 4
assert metadata["cost"] == 0.006
assert isinstance(metadata["latencyMs"], int)
assert metadata["latencyMs"] >= 0
def test_merge_stage_response_metadata_estimates_cost_from_pricing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
"core.agentscope.runtime.react_runner._estimate_cost_by_pricing",
lambda **kwargs: 0.0025,
)
payload = _merge_stage_response_metadata(
payload={"route": "DIRECT_RESPONSE", "response_metadata": {}},
stage_config=_stage_config(),
response=SimpleNamespace(
usage=SimpleNamespace(
prompt_tokens=12,
completion_tokens=8,
),
model="qwen3.5-flash",
),
latency_ms=50,
)
metadata = payload["response_metadata"]
assert metadata["inputTokens"] == 12
assert metadata["outputTokens"] == 8
assert metadata["cost"] == 0.0025
@@ -71,6 +71,63 @@ async def test_run_agentscope_task_calls_runtime_run(
assert called["resume"] == 0
@pytest.mark.asyncio
async def test_run_agentscope_task_includes_recent_context_messages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured_messages: list[dict[str, Any]] = []
class _FakeRuntime:
def __init__(self, **kwargs: object) -> None:
del kwargs
async def run(self, **kwargs: object) -> object:
command = kwargs.get("command")
if command is not None:
raw_messages = getattr(command, "messages", [])
if isinstance(raw_messages, list):
captured_messages.extend(raw_messages)
return object()
async def resume(self, **kwargs: object) -> object:
del kwargs
return object()
async def _fake_get_redis_client() -> object:
return object()
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
del kwargs
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
"get_or_init_redis_client",
_fake_get_redis_client,
)
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
monkeypatch.setattr(
tasks_module,
"_build_recent_context_messages",
_fake_context,
)
run_input = _run_input_payload()
run_input["messages"] = [{"id": "u1", "role": "user", "content": "现在几点"}]
await tasks_module.run_agentscope_task(
{
"command": "run",
"owner_id": str(uuid4()),
"run_input": run_input,
}
)
assert len(captured_messages) == 2
assert captured_messages[0]["id"] == "ctx-1"
assert captured_messages[1]["id"] == "u1"
@pytest.mark.asyncio
async def test_run_agentscope_task_calls_runtime_resume(
monkeypatch: pytest.MonkeyPatch,