feat(agent): 增强多模态链路与工具调用能力
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import httpx
|
||||
@@ -12,6 +14,9 @@ from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
|
||||
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
||||
FIXTURE_IMAGE_PATH = (
|
||||
Path(__file__).resolve().parents[3] / "fixtures" / "images" / "calendar_text_cn.png"
|
||||
)
|
||||
|
||||
|
||||
async def _live_access_token(client: httpx.AsyncClient) -> str:
|
||||
@@ -108,6 +113,8 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
|
||||
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
|
||||
|
||||
image_data = base64.b64encode(FIXTURE_IMAGE_PATH.read_bytes()).decode("ascii")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
token = await _live_access_token(client)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
@@ -128,7 +135,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
{"type": "text", "text": "请描述图片里的内容"},
|
||||
{
|
||||
"type": "binary",
|
||||
"data": "aGVsbG8=",
|
||||
"data": image_data,
|
||||
"mimeType": "image/png",
|
||||
},
|
||||
],
|
||||
@@ -142,19 +149,20 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
|
||||
sse_resp = await client.get(
|
||||
events_url,
|
||||
headers=headers,
|
||||
params={"idle_limit": 150},
|
||||
timeout=60.0,
|
||||
)
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith("text/event-stream")
|
||||
event_names = [
|
||||
line.split(":", 1)[1].strip()
|
||||
for line in sse_resp.text.splitlines()
|
||||
if line.startswith("event:")
|
||||
]
|
||||
event_names: list[str] = []
|
||||
async with client.stream(
|
||||
"GET", events_url, headers=headers, timeout=90.0
|
||||
) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith(
|
||||
"text/event-stream"
|
||||
)
|
||||
async for line in sse_resp.aiter_lines():
|
||||
if line.startswith("event:"):
|
||||
event_name = line.split(":", 1)[1].strip()
|
||||
event_names.append(event_name)
|
||||
if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
|
||||
break
|
||||
|
||||
assert "RUN_STARTED" in event_names
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
@@ -194,7 +202,14 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
)
|
||||
all_messages = list(rows.scalars().all())
|
||||
assert all_messages
|
||||
user_rows = [row for row in all_messages if str(row.role) == "user"]
|
||||
user_rows = [
|
||||
row
|
||||
for row in all_messages
|
||||
if (
|
||||
getattr(row.role, "value", row.role) == "user"
|
||||
or str(getattr(row.role, "value", row.role)) == "user"
|
||||
)
|
||||
]
|
||||
assert user_rows
|
||||
metadata = user_rows[0].metadata_json or {}
|
||||
attachments = metadata.get("attachments")
|
||||
|
||||
@@ -99,6 +99,16 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_START",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"role": "assistant",
|
||||
"stage": "report",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
@@ -128,6 +138,8 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
assert append_kwargs["output_tokens"] == 5
|
||||
assert append_kwargs["cost"] == Decimal("0.123")
|
||||
assert append_kwargs["metadata"]["latency_ms"] == 250
|
||||
assert append_kwargs["metadata"]["stage"] == "report"
|
||||
assert append_kwargs["latency_ms"] == 250
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 8
|
||||
assert captured["cost_delta"] == Decimal("0.123")
|
||||
@@ -255,6 +267,60 @@ async def test_store_clears_buffer_on_run_finished(
|
||||
assert "append_kwargs" not in captured
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_tool_call_result_as_tool_message(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TOOL_CALL_RESULT",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"toolName": "calendar_write",
|
||||
"taskId": "t1",
|
||||
"stage": "execution",
|
||||
"args": {"title": "A"},
|
||||
"result": {"event_id": "evt-1"},
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert getattr(append_kwargs["role"], "value", None) == "tool"
|
||||
assert append_kwargs["tool_name"] == "calendar_write"
|
||||
assert append_kwargs["metadata"]["task_id"] == "t1"
|
||||
assert captured["message_delta"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_drops_buffer_when_session_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -13,8 +13,9 @@ from core.agentscope.schemas.user_context import (
|
||||
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
|
||||
from core.agentscope.schemas import ReportOutput, RuntimeOutput
|
||||
from core.agentscope.schemas.agent_runtime import RunCommand
|
||||
from core.agentscope.schemas.execution import ExecutionBatchOutput
|
||||
from core.agentscope.schemas.intent import IntentOutput
|
||||
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
|
||||
from core.agentscope.schemas.execution import ExecutionToolCall
|
||||
from core.agentscope.schemas.intent import IntentOutput, IntentTask
|
||||
|
||||
|
||||
def _user_context() -> UserAgentContext:
|
||||
@@ -50,20 +51,43 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response="done",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={"latencyMs": 120},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[],
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={"latencyMs": 300},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result={"event_id": "evt-1"},
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
response_metadata={
|
||||
"model": "qwen3.5-flash",
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 5,
|
||||
"cost": 0.123,
|
||||
"latencyMs": 250,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@@ -86,6 +110,13 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"tool.result",
|
||||
"step.start",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
@@ -97,11 +128,19 @@ async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
assert calls[2]["data"]["stepName"] == "intent"
|
||||
assert calls[3]["data"]["stepName"] == "execution"
|
||||
assert calls[4]["data"]["stepName"] == "execution"
|
||||
assert calls[5]["data"]["stepName"] == "report"
|
||||
assert calls[7]["data"]["delta"] == "hello world"
|
||||
assert calls[6]["data"]["messageId"] == calls[7]["data"]["messageId"]
|
||||
assert calls[7]["data"]["messageId"] == calls[8]["data"]["messageId"]
|
||||
assert calls[9]["data"]["stepName"] == "report"
|
||||
assert calls[5]["data"]["stage"] == "intent"
|
||||
assert calls[8]["data"]["stage"] == "execution"
|
||||
assert calls[11]["data"]["toolName"] == "calendar_write"
|
||||
assert calls[12]["data"]["stepName"] == "report"
|
||||
assert calls[14]["data"]["delta"] == "hello world"
|
||||
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
|
||||
assert calls[14]["data"]["messageId"] == calls[15]["data"]["messageId"]
|
||||
assert calls[15]["data"]["model"] == "qwen3.5-flash"
|
||||
assert calls[15]["data"]["inputTokens"] == 10
|
||||
assert calls[15]["data"]["outputTokens"] == 5
|
||||
assert calls[15]["data"]["cost"] == 0.123
|
||||
assert calls[15]["data"]["latencyMs"] == 250
|
||||
assert calls[16]["data"]["stepName"] == "report"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -140,3 +179,129 @@ async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
|
||||
]
|
||||
assert calls[1]["data"]["stepName"] == "intent"
|
||||
assert calls[2]["data"]["message"] == "runtime execution failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_passes_binary_payload_to_orchestrator() -> None:
|
||||
captured_user_input: object | None = None
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
return str(event.get("type", ""))
|
||||
|
||||
class _CaptureOrchestrator:
|
||||
async def run(self, **kwargs: object) -> RuntimeOutput:
|
||||
nonlocal captured_user_input
|
||||
captured_user_input = kwargs.get("user_input")
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
intent_summary="summary",
|
||||
direct_response="done",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
),
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text="ok",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_CaptureOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand.model_validate(
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert isinstance(captured_user_input, list)
|
||||
first = captured_user_input[0]
|
||||
assert isinstance(first, dict)
|
||||
content = first.get("content")
|
||||
assert isinstance(content, list)
|
||||
binary = content[1]
|
||||
assert isinstance(binary, dict)
|
||||
assert binary.get("data") == "aGVsbG8="
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_direct_response_finishes_without_report_stage() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _DirectOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
intent_summary="summary",
|
||||
direct_response="direct-answer",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
response_metadata={"latencyMs": 88},
|
||||
),
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text="direct-answer",
|
||||
response_metadata={"latencyMs": 88},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_DirectOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert [item["type"] for item in calls] == [
|
||||
"run.started",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"run.finished",
|
||||
]
|
||||
assert calls[3]["data"]["stage"] == "intent"
|
||||
assert calls[4]["data"]["delta"] == "direct-answer"
|
||||
|
||||
@@ -68,6 +68,7 @@ class _FakeRunner:
|
||||
"direct_response": "你好",
|
||||
"tasks": [],
|
||||
"complexity": "simple",
|
||||
"response_metadata": {"model": "qwen3.5-flash", "latencyMs": 100},
|
||||
}
|
||||
self.report_calls += 1
|
||||
return {
|
||||
@@ -131,7 +132,7 @@ async def test_runtime_direct_response_skips_execution(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar.read",
|
||||
"name": "calendar_read",
|
||||
"description": "read",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
@@ -162,8 +163,10 @@ async def test_runtime_direct_response_skips_execution(
|
||||
|
||||
assert result.intent.route == "DIRECT_RESPONSE"
|
||||
assert result.execution is None
|
||||
assert result.report.assistant_text == "已完成"
|
||||
assert result.report.assistant_text == "你好"
|
||||
assert result.report.response_metadata["model"] == "qwen3.5-flash"
|
||||
assert fake_runner.execution_calls == 0
|
||||
assert fake_runner.report_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -183,7 +186,7 @@ async def test_runtime_complex_route_runs_execution(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar.read",
|
||||
"name": "calendar_read",
|
||||
"description": "read",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
@@ -191,7 +194,7 @@ async def test_runtime_complex_route_runs_execution(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar.write",
|
||||
"name": "calendar_write",
|
||||
"description": "write",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
|
||||
@@ -9,6 +9,8 @@ from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||
from core.agentscope.runtime.react_runner import (
|
||||
AgentScopeReActRunner,
|
||||
_chat_response_text,
|
||||
_merge_stage_response_metadata,
|
||||
_parse_json_text,
|
||||
_to_litellm_model,
|
||||
)
|
||||
@@ -32,10 +34,10 @@ def test_to_litellm_model_keeps_prefixed_model() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_to_litellm_model_builds_prefixed_model() -> None:
|
||||
def test_to_litellm_model_uses_plain_model_name_when_unprefixed() -> None:
|
||||
assert (
|
||||
_to_litellm_model(provider_name="dashscope", model_code="qwen3.5-flash")
|
||||
== "dashscope/qwen3.5-flash"
|
||||
== "qwen3.5-flash"
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +51,24 @@ def test_parse_json_text_rejects_non_json() -> None:
|
||||
_parse_json_text("not-json")
|
||||
|
||||
|
||||
def test_chat_response_text_falls_back_to_choice_message_content() -> None:
|
||||
response = SimpleNamespace(
|
||||
content=None,
|
||||
choices=[
|
||||
{
|
||||
"message": {
|
||||
"content": '{"assistant_text":"fallback","response_metadata":{}}'
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
_chat_response_text(response)
|
||||
== '{"assistant_text":"fallback","response_metadata":{}}'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_json_stage_wraps_json_decode_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -113,3 +133,88 @@ async def test_run_json_stage_wraps_runtime_error(
|
||||
user_prompt="user",
|
||||
toolkit=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_json_stage_report_merges_usage_metadata(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeLiteLLMService:
|
||||
def run_completion_with_cost(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
return SimpleNamespace(
|
||||
response={
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": '{"assistant_text":"ok","response_metadata":{}}'
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
usage=SimpleNamespace(
|
||||
prompt_tokens=9,
|
||||
completion_tokens=4,
|
||||
cost=0.006,
|
||||
),
|
||||
)
|
||||
|
||||
runner = AgentScopeReActRunner()
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_build_litellm_service",
|
||||
lambda: _FakeLiteLLMService(),
|
||||
)
|
||||
|
||||
report_stage = RuntimeStageConfig(
|
||||
stage="report",
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
llm_config=SystemAgentLLMConfig(
|
||||
temperature=0.1,
|
||||
max_tokens=128,
|
||||
timeout_seconds=30,
|
||||
),
|
||||
)
|
||||
payload = await runner.run_json_stage(
|
||||
stage_config=report_stage,
|
||||
agent_name="report-agent",
|
||||
system_prompt="sys",
|
||||
user_prompt="user",
|
||||
toolkit=None,
|
||||
)
|
||||
|
||||
metadata = payload["response_metadata"]
|
||||
assert metadata["model"] == "dashscope/qwen3.5-flash"
|
||||
assert metadata["inputTokens"] == 9
|
||||
assert metadata["outputTokens"] == 4
|
||||
assert metadata["cost"] == 0.006
|
||||
assert isinstance(metadata["latencyMs"], int)
|
||||
assert metadata["latencyMs"] >= 0
|
||||
|
||||
|
||||
def test_merge_stage_response_metadata_estimates_cost_from_pricing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.react_runner._estimate_cost_by_pricing",
|
||||
lambda **kwargs: 0.0025,
|
||||
)
|
||||
payload = _merge_stage_response_metadata(
|
||||
payload={"route": "DIRECT_RESPONSE", "response_metadata": {}},
|
||||
stage_config=_stage_config(),
|
||||
response=SimpleNamespace(
|
||||
usage=SimpleNamespace(
|
||||
prompt_tokens=12,
|
||||
completion_tokens=8,
|
||||
),
|
||||
model="qwen3.5-flash",
|
||||
),
|
||||
latency_ms=50,
|
||||
)
|
||||
|
||||
metadata = payload["response_metadata"]
|
||||
assert metadata["inputTokens"] == 12
|
||||
assert metadata["outputTokens"] == 8
|
||||
assert metadata["cost"] == 0.0025
|
||||
|
||||
@@ -71,6 +71,63 @@ async def test_run_agentscope_task_calls_runtime_run(
|
||||
assert called["resume"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_includes_recent_context_messages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured_messages: list[dict[str, Any]] = []
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
del kwargs
|
||||
|
||||
async def run(self, **kwargs: object) -> object:
|
||||
command = kwargs.get("command")
|
||||
if command is not None:
|
||||
raw_messages = getattr(command, "messages", [])
|
||||
if isinstance(raw_messages, list):
|
||||
captured_messages.extend(raw_messages)
|
||||
return object()
|
||||
|
||||
async def resume(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
return object()
|
||||
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return object()
|
||||
|
||||
async def _fake_context(**kwargs: object) -> list[dict[str, Any]]:
|
||||
del kwargs
|
||||
return [{"id": "ctx-1", "role": "assistant", "content": "历史上下文"}]
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"get_or_init_redis_client",
|
||||
_fake_get_redis_client,
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"_build_recent_context_messages",
|
||||
_fake_context,
|
||||
)
|
||||
|
||||
run_input = _run_input_payload()
|
||||
run_input["messages"] = [{"id": "u1", "role": "user", "content": "现在几点"}]
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": run_input,
|
||||
}
|
||||
)
|
||||
|
||||
assert len(captured_messages) == 2
|
||||
assert captured_messages[0]["id"] == "ctx-1"
|
||||
assert captured_messages[1]["id"] == "u1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_calls_runtime_resume(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -178,3 +178,89 @@ async def test_calendar_write_rejects_invalid_reminder_minutes(
|
||||
|
||||
assert result["data"]["ok"] is False
|
||||
assert result["data"]["code"] == "INVALID_ARGUMENT"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calendar_write_maps_invite_arguments(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
|
||||
captured.update(cast(dict[str, object], kwargs["tool_args"]))
|
||||
return {"type": "calendar_card.v1", "version": "v1", "data": {"ok": True}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
calendar_module,
|
||||
"_execute_mutate_calendar_event",
|
||||
_fake_execute,
|
||||
)
|
||||
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
await calendar_module.calendar_write(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-abc",
|
||||
operation="create",
|
||||
invite_user_emails=["a@example.com"],
|
||||
invite_user_names=["alice"],
|
||||
invite_user_ids=[str(uuid4())],
|
||||
invite_permission_view=True,
|
||||
invite_permission_edit=True,
|
||||
invite_permission_invite=True,
|
||||
)
|
||||
|
||||
assert captured["inviteUserEmails"] == ["a@example.com"]
|
||||
assert captured["inviteUserNames"] == ["alice"]
|
||||
assert isinstance(captured["inviteUserIds"], list)
|
||||
assert captured["invitePermissionView"] is True
|
||||
assert captured["invitePermissionEdit"] is True
|
||||
assert captured["invitePermissionInvite"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_resolve_maps_identity_arguments(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
|
||||
captured.update(cast(dict[str, object], kwargs["tool_args"]))
|
||||
return {"type": "user_lookup.v1", "version": "v1", "data": {"ok": True}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
calendar_module,
|
||||
"_execute_resolve_user_identity",
|
||||
_fake_execute,
|
||||
)
|
||||
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
result = await calendar_module.user_resolve(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-abc",
|
||||
user_email="a@example.com",
|
||||
)
|
||||
|
||||
assert result["type"] == "user_lookup.v1"
|
||||
assert captured == {"userEmail": "a@example.com", "userName": None}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_resolve_requires_valid_user_token(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: False)
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
result = await calendar_module.user_resolve(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="bad-token",
|
||||
user_name="alice",
|
||||
)
|
||||
|
||||
assert result["data"]["ok"] is False
|
||||
assert result["data"]["code"] == "UNAUTHORIZED"
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agentscope.prompts.runtime_prompt import build_intent_user_prompt
|
||||
|
||||
|
||||
def test_build_intent_user_prompt_keeps_multimodal_blocks() -> None:
|
||||
prompt = build_intent_user_prompt(
|
||||
user_input=[
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请识别图片内容"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(prompt, list)
|
||||
assert prompt
|
||||
assert prompt[0]["type"] == "text"
|
||||
assert "[Output Schema]" in prompt[0]["text"]
|
||||
image_blocks = [item for item in prompt if item.get("type") == "image"]
|
||||
assert len(image_blocks) == 1
|
||||
source = image_blocks[0]["source"]
|
||||
assert source["type"] == "base64"
|
||||
assert source["media_type"] == "image/png"
|
||||
assert source["data"] == "aGVsbG8="
|
||||
|
||||
|
||||
def test_build_intent_user_prompt_filters_non_image_binary_block() -> None:
|
||||
prompt = build_intent_user_prompt(
|
||||
user_input=[
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请处理这个输入"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "application/pdf",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(prompt, list)
|
||||
image_blocks = [item for item in prompt if item.get("type") == "image"]
|
||||
assert image_blocks == []
|
||||
@@ -20,11 +20,12 @@ async def test_build_toolkit_registers_calendar_tools() -> None:
|
||||
)
|
||||
schemas = toolkit.get_json_schemas()
|
||||
names = {item["function"]["name"] for item in schemas}
|
||||
assert "calendar.read" in names
|
||||
assert "calendar.write" in names
|
||||
assert "calendar_read" in names
|
||||
assert "calendar_write" in names
|
||||
assert "user_resolve" in names
|
||||
|
||||
write_schema = next(
|
||||
item for item in schemas if item["function"]["name"] == "calendar.write"
|
||||
item for item in schemas if item["function"]["name"] == "calendar_write"
|
||||
)
|
||||
params = write_schema["function"]["parameters"]["properties"]
|
||||
assert "user_token" not in params
|
||||
|
||||
@@ -33,11 +33,11 @@ def test_calculate_cost_uses_second_qwen_tier() -> None:
|
||||
|
||||
def test_run_completion_extracts_usage_and_cost() -> None:
|
||||
service = LiteLLMService()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
result = service.run_completion_with_cost(
|
||||
model="dashscope/qwen3.5-flash",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
completion_fn=lambda **_: {
|
||||
def _fake_completion(**kwargs: object) -> dict[str, object]:
|
||||
captured.update(kwargs)
|
||||
return {
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
@@ -46,10 +46,17 @@ def test_run_completion_extracts_usage_and_cost() -> None:
|
||||
"prompt_tokens_details": {"cached_tokens": 500},
|
||||
},
|
||||
"choices": [{"message": {"content": "ok"}}],
|
||||
},
|
||||
}
|
||||
|
||||
result = service.run_completion_with_cost(
|
||||
model="dashscope/qwen3.5-flash",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
response_format={"type": "json_object"},
|
||||
completion_fn=_fake_completion,
|
||||
)
|
||||
|
||||
assert result.usage.prompt_tokens == 2000
|
||||
assert result.usage.completion_tokens == 100
|
||||
assert result.usage.total_tokens == 2100
|
||||
assert result.usage.cost == pytest.approx(0.00051)
|
||||
assert captured["response_format"] == {"type": "json_object"}
|
||||
|
||||
@@ -10,6 +10,31 @@ from models.agent_chat_message import AgentChatMessageRole
|
||||
from v1.agent.repository import AgentRepository
|
||||
|
||||
|
||||
class _ExecuteResult:
|
||||
def __init__(self, value: object) -> None:
|
||||
self._value = value
|
||||
|
||||
def scalar_one_or_none(self) -> object:
|
||||
return self._value
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, session_row: object) -> None:
|
||||
self.session_row = session_row
|
||||
self.added: list[object] = []
|
||||
self.flushed = False
|
||||
|
||||
async def execute(self, stmt): # noqa: ANN001
|
||||
del stmt
|
||||
return _ExecuteResult(self.session_row)
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
self.flushed = True
|
||||
|
||||
|
||||
class _FakeToolResultStorage:
|
||||
def __init__(self, payload: dict[str, object] | None) -> None:
|
||||
self._payload = payload
|
||||
@@ -104,3 +129,48 @@ async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_user_message_sets_session_title_when_empty() -> None:
|
||||
session_id = str(uuid4())
|
||||
session_row = SimpleNamespace(
|
||||
message_count=0,
|
||||
title=None,
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
)
|
||||
fake_session = _FakeSession(session_row)
|
||||
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
await repository.persist_user_message(
|
||||
session_id=session_id,
|
||||
run_id="run-1",
|
||||
content_text=" 请帮我安排明天下午开会 ",
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
assert session_row.title == "请帮我安排明天下午开会"
|
||||
assert session_row.message_count == 1
|
||||
assert fake_session.flushed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_user_message_keeps_existing_session_title() -> None:
|
||||
session_id = str(uuid4())
|
||||
session_row = SimpleNamespace(
|
||||
message_count=1,
|
||||
title="已有标题",
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
)
|
||||
fake_session = _FakeSession(session_row)
|
||||
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
await repository.persist_user_message(
|
||||
session_id=session_id,
|
||||
run_id="run-2",
|
||||
content_text="新的消息内容",
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
assert session_row.title == "已有标题"
|
||||
assert session_row.message_count == 2
|
||||
|
||||
@@ -175,3 +175,53 @@ async def test_enqueue_resume_accepts_valid_tool_contract(
|
||||
assert result.task_id == "task-resume-1"
|
||||
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||
assert result.run_id == "run-resume-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_retries_on_redis_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _acquire(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
|
||||
monkeypatch.setattr(agent_router, "_acquire_sse_slot", _acquire)
|
||||
monkeypatch.setattr(agent_router, "_release_sse_slot", _release)
|
||||
|
||||
class _Request:
|
||||
async def is_disconnected(self) -> bool:
|
||||
return False
|
||||
|
||||
class _Service:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def stream_events(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
raise RuntimeError("Timeout reading from localhost:6379")
|
||||
if self.calls == 2:
|
||||
return [{"id": "1-0", "event": {"type": "RUN_FINISHED"}}]
|
||||
return []
|
||||
|
||||
response = await agent_router.stream_events(
|
||||
request=cast(Any, _Request()),
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
last_event_id=None,
|
||||
idle_limit=2,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(str(chunk))
|
||||
if any("RUN_FINISHED" in item for item in chunks):
|
||||
break
|
||||
|
||||
merged = "".join(chunks)
|
||||
assert "event: RUN_FINISHED" in merged
|
||||
|
||||
@@ -124,6 +124,19 @@ class _FakeAttachmentStorage:
|
||||
return path
|
||||
|
||||
|
||||
class _AlwaysFailAttachmentStorage:
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
del bucket, path, content, content_type
|
||||
raise RuntimeError("upload failed")
|
||||
|
||||
|
||||
def _user() -> CurrentUser:
|
||||
return CurrentUser(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
@@ -317,6 +330,54 @@ async def test_enqueue_run_uploads_user_image_to_supabase_and_injects_metadata(
|
||||
assert isinstance(attachments[0]["path"], str)
|
||||
|
||||
|
||||
async def test_enqueue_run_raises_when_attachment_upload_fails_without_fallback(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_AlwaysFailAttachmentStorage(),
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-image-fail",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"data": "aGVsbG8=",
|
||||
"mimeType": "image/png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
raise AssertionError("expected HTTPException")
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 502
|
||||
assert exc.detail == "Failed to upload attachment"
|
||||
|
||||
assert repository.persisted_user_messages == []
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
|
||||
Reference in New Issue
Block a user