refactor: 简化 AgentScope 运行时模块与事件处理
- 移除冗余的 user_token 参数传递 - 重构 tool.result 事件使用 ToolAgentOutput 模型 - 重构 text.end 事件使用 WorkerAgentOutput 模型 - 简化 store 模块的 tool result 处理逻辑 - 更新 router/service 适配新事件结构 - 清理废弃的测试文件与设计文档 - 新增 AgentRuns 多模态存储设计文档
This commit is contained in:
@@ -50,10 +50,13 @@ def test_tool_result_wire_event_filters_sensitive_fields() -> None:
|
||||
"data": {
|
||||
"messageId": "tool-result-1",
|
||||
"toolCallId": "call-1",
|
||||
"callId": "call-1",
|
||||
"toolName": "calendar_write",
|
||||
"content": "summary",
|
||||
"ui": {"type": "calendar_operation.v1", "data": {"ok": True}},
|
||||
"toolAgentOutput": {
|
||||
"tool_name": "calendar_write",
|
||||
"tool_call_id": "call-1",
|
||||
"status": "success",
|
||||
"result_summary": "summary",
|
||||
"tool_call_args": {},
|
||||
},
|
||||
"args": {"token": "secret"},
|
||||
"result": {"raw": "secret"},
|
||||
"error": "stack trace",
|
||||
@@ -65,9 +68,32 @@ def test_tool_result_wire_event_filters_sensitive_fields() -> None:
|
||||
assert result["type"] == "TOOL_CALL_RESULT"
|
||||
assert result["messageId"] == "tool-result-1"
|
||||
assert result["toolCallId"] == "call-1"
|
||||
assert result["toolName"] == "calendar_write"
|
||||
assert result["content"] == "summary"
|
||||
assert isinstance(result.get("ui"), dict)
|
||||
assert isinstance(result.get("toolAgentOutput"), dict)
|
||||
assert "args" not in result
|
||||
assert "result" not in result
|
||||
assert "error" not in result
|
||||
|
||||
|
||||
def test_text_end_event_only_keeps_protocol_fields() -> None:
|
||||
internal = {
|
||||
"type": "text.end",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"data": {
|
||||
"messageId": "assistant-run-1",
|
||||
"workerAgentOutput": {"answer": "done", "status": "success"},
|
||||
"stage": "worker",
|
||||
"model": "qwen",
|
||||
"inputTokens": 1,
|
||||
"outputTokens": 2,
|
||||
},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "TEXT_MESSAGE_END"
|
||||
assert result["messageId"] == "assistant-run-1"
|
||||
assert isinstance(result.get("workerAgentOutput"), dict)
|
||||
assert "stage" not in result
|
||||
assert "model" not in result
|
||||
assert "inputTokens" not in result
|
||||
|
||||
@@ -49,49 +49,11 @@ class _FakeToolResultStorage:
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_marks_session_running_on_run_started(
|
||||
def _patch_repositories(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
captured: dict[str, object],
|
||||
fake_chat_session: Any,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot=None)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
captured["session_id"] = session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "RUN_STARTED",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
}
|
||||
)
|
||||
|
||||
assert captured["status"] == _SessionStatus.RUNNING
|
||||
assert captured["message_delta"] == 0
|
||||
assert captured["token_delta"] == 0
|
||||
assert captured["cost_delta"] == Decimal("0")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_assistant_message_and_aggregates(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={"k": "v"}, message_count=6)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
@@ -118,6 +80,14 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_worker_output_with_answer_as_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=6)
|
||||
_patch_repositories(monkeypatch, captured, fake_chat_session)
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
|
||||
await store.persist(
|
||||
@@ -127,7 +97,7 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"role": "assistant",
|
||||
"stage": "report",
|
||||
"stage": "worker",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
@@ -136,7 +106,7 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "hello",
|
||||
"delta": "legacy-text",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
@@ -149,177 +119,34 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
"outputTokens": 5,
|
||||
"cost": "0.123",
|
||||
"latencyMs": 250,
|
||||
"workerAgentOutput": {
|
||||
"status": "success",
|
||||
"answer": "worker-answer",
|
||||
"key_points": [],
|
||||
"result_type": "summary",
|
||||
"suggested_actions": [],
|
||||
"error": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert append_kwargs["seq"] == 7
|
||||
assert append_kwargs["content"] == "hello"
|
||||
assert append_kwargs["input_tokens"] == 3
|
||||
assert append_kwargs["output_tokens"] == 5
|
||||
assert append_kwargs["content"] == "worker-answer"
|
||||
metadata = cast(dict[str, Any], append_kwargs["metadata"])
|
||||
assert metadata["worker_agent_output"]["answer"] == "worker-answer"
|
||||
assert append_kwargs["cost"] == Decimal("0.123")
|
||||
assert append_kwargs["metadata"]["latency_ms"] == 250
|
||||
assert append_kwargs["metadata"]["stage"] == "report"
|
||||
assert append_kwargs["latency_ms"] == 250
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 8
|
||||
assert captured["cost_delta"] == Decimal("0.123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_uses_canonical_thread_id_for_buffer_keys(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=1)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
compact_thread_id = "00000000000000000000000000000001"
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"threadId": compact_thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "hello",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_END",
|
||||
"threadId": compact_thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert append_kwargs["content"] == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_clears_buffer_on_run_finished(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=0)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
thread_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "stale",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "RUN_FINISHED",
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_END",
|
||||
"threadId": thread_id,
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
}
|
||||
)
|
||||
|
||||
assert "append_kwargs" not in captured
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_tool_call_result_as_tool_message(
|
||||
async def test_store_persists_tool_output_with_summary_as_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
_patch_repositories(monkeypatch, captured, fake_chat_session)
|
||||
|
||||
fake_storage = _FakeToolResultStorage()
|
||||
store = store_module.SqlAlchemyEventStore(
|
||||
@@ -334,128 +161,23 @@ async def test_store_persists_tool_call_result_as_tool_message(
|
||||
"runId": "run-1",
|
||||
"toolName": "calendar_write",
|
||||
"taskId": "t1",
|
||||
"stage": "execution",
|
||||
"args": {"title": "A"},
|
||||
"result": {"event_id": "evt-1", "token": "secret"},
|
||||
"stage": "worker",
|
||||
"toolAgentOutput": {
|
||||
"tool_name": "calendar_write",
|
||||
"tool_call_id": "call-1",
|
||||
"tool_call_args": {"title": "A"},
|
||||
"status": "success",
|
||||
"result_summary": "已创建日程 A",
|
||||
"ui_hints": None,
|
||||
"error": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert getattr(append_kwargs["role"], "value", None) == "tool"
|
||||
assert append_kwargs["tool_name"] == "calendar_write"
|
||||
assert append_kwargs["metadata"]["task_id"] == "t1"
|
||||
tool_call_id = append_kwargs["metadata"]["tool_call_id"]
|
||||
assert isinstance(tool_call_id, str)
|
||||
assert tool_call_id.startswith("run-1-t1-")
|
||||
assert append_kwargs["metadata"]["storage_bucket"] == "agent-tool-results"
|
||||
assert isinstance(append_kwargs["metadata"]["storage_path"], str)
|
||||
assert append_kwargs["content"].startswith("已创建日程")
|
||||
assert append_kwargs["content"] == "已创建日程 A"
|
||||
metadata = cast(dict[str, Any], append_kwargs["metadata"])
|
||||
assert metadata["tool_agent_output"]["result_summary"] == "已创建日程 A"
|
||||
assert metadata["storage_bucket"] == "agent-tool-results"
|
||||
assert len(fake_storage.upload_calls) == 1
|
||||
uploaded = fake_storage.upload_calls[0]
|
||||
assert uploaded["bucket"] == "agent-tool-results"
|
||||
payload = cast(dict[str, Any], uploaded["payload"])
|
||||
assert payload["toolName"] == "calendar_write"
|
||||
assert "args" not in payload
|
||||
assert isinstance(payload.get("result"), dict)
|
||||
assert payload["result"]["token"] == "[REDACTED]"
|
||||
assert captured["message_delta"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_sanitizes_nested_sensitive_fields_in_result_payload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=0)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
fake_storage = _FakeToolResultStorage()
|
||||
store = store_module.SqlAlchemyEventStore(
|
||||
session_factory=lambda: _FakeSessionCtx(),
|
||||
tool_result_storage=fake_storage,
|
||||
tool_result_bucket="agent-tool-results",
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TOOL_CALL_RESULT",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"toolName": "calendar_write",
|
||||
"result": {
|
||||
"data": {
|
||||
"ok": True,
|
||||
"accessToken": "secret-a",
|
||||
"nested": {
|
||||
"refresh_token": "secret-b",
|
||||
},
|
||||
"items": [
|
||||
{"authorizationHeader": "secret-c"},
|
||||
],
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
payload = cast(dict[str, Any], fake_storage.upload_calls[0]["payload"])
|
||||
stored_result = cast(dict[str, Any], payload["result"])
|
||||
data = cast(dict[str, Any], stored_result["data"])
|
||||
assert data["accessToken"] == "[REDACTED]"
|
||||
nested = cast(dict[str, Any], data["nested"])
|
||||
assert nested["refresh_token"] == "[REDACTED]"
|
||||
items = cast(list[Any], data["items"])
|
||||
assert isinstance(items[0], dict)
|
||||
assert items[0]["authorizationHeader"] == "[REDACTED]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_drops_buffer_when_session_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
thread_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"threadId": thread_id,
|
||||
"messageId": "assistant-run-1",
|
||||
"delta": "orphan",
|
||||
}
|
||||
)
|
||||
|
||||
assert store._message_buffers == {}
|
||||
|
||||
@@ -1,608 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.schemas.user_context import (
|
||||
UserAgentContext,
|
||||
parse_profile_settings,
|
||||
)
|
||||
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
|
||||
from core.agentscope.schemas import ReportOutput, RuntimeOutput
|
||||
from core.agentscope.schemas.agent_runtime import RunCommand
|
||||
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
|
||||
from core.agentscope.schemas.execution import ExecutionToolCall
|
||||
from core.agentscope.schemas.intent import IntentOutput, IntentTask
|
||||
|
||||
|
||||
def _user_context() -> UserAgentContext:
|
||||
return UserAgentContext(
|
||||
user_id=uuid4(),
|
||||
username="tester",
|
||||
bio=None,
|
||||
settings=parse_profile_settings(
|
||||
{
|
||||
"version": 1,
|
||||
"preferences": {
|
||||
"interface_language": "zh-CN",
|
||||
"ai_language": "zh-CN",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"country": "CN",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={"latencyMs": 120},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={"latencyMs": 300},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result={"event_id": "evt-1"},
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={
|
||||
"model": "qwen3.5-flash",
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 5,
|
||||
"cost": 0.123,
|
||||
"latencyMs": 250,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert [item["type"] for item in calls] == [
|
||||
"run.started",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"tool.result",
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"step.finish",
|
||||
"run.finished",
|
||||
]
|
||||
assert calls[1]["data"]["stepName"] == "intent"
|
||||
assert calls[2]["data"]["stepName"] == "intent"
|
||||
assert calls[3]["data"]["stepName"] == "execution"
|
||||
assert calls[4]["data"]["stage"] == "intent"
|
||||
assert calls[7]["data"]["stage"] == "execution"
|
||||
assert calls[10]["data"]["toolName"] == "calendar_write"
|
||||
assert calls[10]["data"]["toolCallId"] == "run-1-t1-1"
|
||||
assert calls[10]["data"]["messageId"] == "tool-result-run-1-t1-1"
|
||||
tool_content = calls[10]["data"]["content"]
|
||||
assert tool_content == "calendar_write 执行完成"
|
||||
assert calls[11]["data"]["stepName"] == "execution"
|
||||
assert calls[12]["data"]["stepName"] == "report"
|
||||
assert calls[14]["data"]["delta"] == "hello world"
|
||||
assert calls[13]["data"]["messageId"] == calls[14]["data"]["messageId"]
|
||||
assert calls[14]["data"]["messageId"] == calls[15]["data"]["messageId"]
|
||||
assert calls[15]["data"]["model"] == "qwen3.5-flash"
|
||||
assert calls[15]["data"]["inputTokens"] == 10
|
||||
assert calls[15]["data"]["outputTokens"] == 5
|
||||
assert calls[15]["data"]["cost"] == 0.123
|
||||
assert calls[15]["data"]["latencyMs"] == 250
|
||||
assert calls[16]["data"]["stepName"] == "report"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FailOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FailOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert [item["type"] for item in calls] == [
|
||||
"run.started",
|
||||
"step.start",
|
||||
"run.error",
|
||||
]
|
||||
assert calls[1]["data"]["stepName"] == "intent"
|
||||
assert calls[2]["data"]["message"] == "runtime execution failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_passes_binary_payload_to_orchestrator() -> None:
|
||||
captured_user_input: object | None = None
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
return str(event.get("type", ""))
|
||||
|
||||
class _CaptureOrchestrator:
|
||||
async def run(self, **kwargs: object) -> RuntimeOutput:
|
||||
nonlocal captured_user_input
|
||||
captured_user_input = kwargs.get("user_input")
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
intent_summary="summary",
|
||||
direct_response="done",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
),
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text="ok",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_CaptureOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand.model_validate(
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"data": "aGVsbG8=",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert isinstance(captured_user_input, list)
|
||||
first = captured_user_input[0]
|
||||
assert isinstance(first, dict)
|
||||
content = first.get("content")
|
||||
assert isinstance(content, list)
|
||||
binary = content[1]
|
||||
assert isinstance(binary, dict)
|
||||
assert binary.get("data") == "aGVsbG8="
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_direct_response_finishes_without_report_stage() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _DirectOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
intent_summary="summary",
|
||||
direct_response="direct-answer",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
response_metadata={"latencyMs": 88},
|
||||
),
|
||||
execution=None,
|
||||
report=ReportOutput(
|
||||
assistant_text="direct-answer",
|
||||
response_metadata={"latencyMs": 88},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_DirectOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert [item["type"] for item in calls] == [
|
||||
"run.started",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"run.finished",
|
||||
]
|
||||
assert calls[3]["data"]["stage"] == "intent"
|
||||
assert calls[4]["data"]["delta"] == "direct-answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_parses_json_string_ui_payload() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result='{"type":"calendar_card.v1","version":"v1","data":{"ok":true,"title":"A"},"actions":[]}',
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert isinstance(data.get("ui"), dict)
|
||||
assert data["ui"]["type"] == "calendar_card.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_keeps_plain_text_content() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result="created successfully",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert data["content"] == "created successfully"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_sanitizes_sensitive_payload() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={
|
||||
"title": "A",
|
||||
"accessToken": "arg-secret",
|
||||
"author": "alice",
|
||||
},
|
||||
result={
|
||||
"ok": True,
|
||||
"accessToken": "secret-token",
|
||||
"message": "Authorization: Bearer inline-token",
|
||||
"nested": [
|
||||
{
|
||||
"authorizationHeader": "Bearer abc",
|
||||
}
|
||||
],
|
||||
},
|
||||
error="failed authorization=Bearer abc123 detail",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert isinstance(data["result"], dict)
|
||||
assert data["result"]["accessToken"] == "[REDACTED]"
|
||||
assert data["result"]["message"] == "Authorization=[REDACTED]"
|
||||
nested = data["result"]["nested"]
|
||||
assert isinstance(nested, list)
|
||||
assert nested[0]["authorizationHeader"] == "[REDACTED]"
|
||||
assert isinstance(data["args"], dict)
|
||||
assert data["args"]["accessToken"] == "[REDACTED]"
|
||||
assert data["args"]["author"] == "alice"
|
||||
assert data["error"] == "failed authorization=[REDACTED] detail"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_tool_result_keeps_non_object_result() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="TASK_EXECUTION",
|
||||
intent_summary="summary",
|
||||
direct_response=None,
|
||||
tasks=[IntentTask(task_id="t1", title="exec", objective="do")],
|
||||
complexity="complex",
|
||||
response_metadata={},
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[
|
||||
ExecutionTaskOutput(
|
||||
task_id="t1",
|
||||
status="SUCCESS",
|
||||
execution_summary="execution-ok",
|
||||
execution_data={},
|
||||
user_feedback_needs=[],
|
||||
response_metadata={},
|
||||
tool_calls=[
|
||||
ExecutionToolCall(
|
||||
tool_name="calendar_write",
|
||||
args={"title": "A"},
|
||||
result=["evt-1", "evt-2"],
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
tool_events = [item for item in calls if item.get("type") == "tool.result"]
|
||||
assert len(tool_events) == 1
|
||||
data = tool_events[0]["data"]
|
||||
assert isinstance(data, dict)
|
||||
assert isinstance(data["result"], dict)
|
||||
assert data["result"]["value"] == ["evt-1", "evt-2"]
|
||||
@@ -1,229 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.schemas.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agentscope.schemas.user_context import (
|
||||
UserAgentContext,
|
||||
parse_profile_settings,
|
||||
)
|
||||
from core.agentscope.runtime.config_loader import RuntimeStageConfig
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from schemas.user import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
def _ctx() -> UserAgentContext:
|
||||
return UserAgentContext(
|
||||
user_id=uuid4(),
|
||||
username="alice",
|
||||
bio=None,
|
||||
settings=parse_profile_settings(
|
||||
{
|
||||
"version": 1,
|
||||
"preferences": {
|
||||
"interface_language": "zh-CN",
|
||||
"ai_language": "zh-CN",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"country": "CN",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
class _FakePipeline:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
def _stage_config() -> dict[str, RuntimeStageConfig]:
|
||||
llm = SystemAgentLLMConfig(temperature=0.1, max_tokens=256, timeout_seconds=30)
|
||||
return {
|
||||
"intent": RuntimeStageConfig("intent", "qwen3.5-flash", "dashscope", llm),
|
||||
"execution": RuntimeStageConfig("execution", "deepseek-chat", "deepseek", llm),
|
||||
"report": RuntimeStageConfig("report", "deepseek-chat", "deepseek", llm),
|
||||
}
|
||||
async def emit(self, *, session_id: str, event: dict[str, Any]) -> str:
|
||||
self.events.append({"session_id": session_id, "event": event})
|
||||
return "1-0"
|
||||
|
||||
|
||||
class _FakeRunner:
|
||||
def __init__(self) -> None:
|
||||
self.intent_calls = 0
|
||||
self.execution_calls = 0
|
||||
self.report_calls = 0
|
||||
self.last_user_input: str | list[dict[str, Any]] | None = None
|
||||
|
||||
async def run_json_stage(
|
||||
async def run_router_then_worker(
|
||||
self,
|
||||
*,
|
||||
stage_config: RuntimeStageConfig,
|
||||
agent_name: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
toolkit: Any | None,
|
||||
session,
|
||||
user_context,
|
||||
user_input,
|
||||
router_toolkit,
|
||||
worker_toolkit,
|
||||
extra_context=None,
|
||||
) -> dict[str, Any]:
|
||||
del agent_name, system_prompt, user_prompt, toolkit
|
||||
if stage_config.stage == "intent":
|
||||
self.intent_calls += 1
|
||||
return {
|
||||
"route": "DIRECT_RESPONSE",
|
||||
"intent_summary": "直接问候",
|
||||
"direct_response": "你好",
|
||||
"tasks": [],
|
||||
"complexity": "simple",
|
||||
"response_metadata": {"model": "qwen3.5-flash", "latencyMs": 100},
|
||||
}
|
||||
self.report_calls += 1
|
||||
del session, user_context, router_toolkit, worker_toolkit, extra_context
|
||||
self.last_user_input = user_input
|
||||
return {
|
||||
"assistant_text": "已完成",
|
||||
"response_metadata": {"source": "report-agent"},
|
||||
"worker": {
|
||||
"status": "success",
|
||||
"answer": "done",
|
||||
"key_points": [],
|
||||
"result_type": "summary",
|
||||
"suggested_actions": [],
|
||||
"error": None,
|
||||
"response_metadata": {
|
||||
"model": "qwen3.5-flash",
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 5,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class _ComplexRunner(_FakeRunner):
|
||||
async def run_json_stage(
|
||||
self,
|
||||
*,
|
||||
stage_config: RuntimeStageConfig,
|
||||
agent_name: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
toolkit: Any | None,
|
||||
) -> dict[str, Any]:
|
||||
del agent_name, system_prompt, user_prompt, toolkit
|
||||
if stage_config.stage == "intent":
|
||||
self.intent_calls += 1
|
||||
return {
|
||||
"route": "TASK_EXECUTION",
|
||||
"intent_summary": "需要写入日历",
|
||||
"direct_response": None,
|
||||
"tasks": [
|
||||
{"task_id": "t1", "title": "创建事件", "objective": "写入明天会议"}
|
||||
],
|
||||
"complexity": "complex",
|
||||
}
|
||||
if stage_config.stage == "execution":
|
||||
self.execution_calls += 1
|
||||
return {
|
||||
"task_id": "t1",
|
||||
"status": "SUCCESS",
|
||||
"execution_summary": "done",
|
||||
"execution_data": {},
|
||||
"user_feedback_needs": [],
|
||||
}
|
||||
self.report_calls += 1
|
||||
return {
|
||||
"assistant_text": "任务执行完成",
|
||||
"response_metadata": {"source": "report-agent"},
|
||||
}
|
||||
def _user_context() -> UserContext:
|
||||
return UserContext(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_direct_response_skips_execution(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_runner = _FakeRunner()
|
||||
def _run_command_with_binary() -> Any:
|
||||
from ag_ui.core import RunAgentInput
|
||||
|
||||
async def _fake_config_loader(
|
||||
_session: AsyncSession,
|
||||
) -> dict[str, RuntimeStageConfig]:
|
||||
return _stage_config()
|
||||
|
||||
class _FakeToolkit:
|
||||
def get_json_schemas(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000010",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar_read",
|
||||
"description": "read",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "看这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://example.com/signed.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
async def call_tool_function(self, tool_call: dict[str, Any]):
|
||||
del tool_call
|
||||
if False:
|
||||
yield None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.orchestrator.build_stage_toolkit",
|
||||
lambda **_: _FakeToolkit(),
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
orchestrator = AgentScopeRuntimeOrchestrator(
|
||||
runner=fake_runner,
|
||||
config_loader=_fake_config_loader,
|
||||
)
|
||||
result = await orchestrator.run(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_ctx(),
|
||||
user_input="你好",
|
||||
)
|
||||
|
||||
assert result.intent.route == "DIRECT_RESPONSE"
|
||||
assert result.execution is None
|
||||
assert result.report.assistant_text == "你好"
|
||||
assert result.report.response_metadata["model"] == "qwen3.5-flash"
|
||||
assert fake_runner.execution_calls == 0
|
||||
assert fake_runner.report_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_complex_route_runs_execution(
|
||||
async def test_orchestrator_maps_binary_to_model_image_url(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_runner = _ComplexRunner()
|
||||
|
||||
async def _fake_config_loader(
|
||||
_session: AsyncSession,
|
||||
) -> dict[str, RuntimeStageConfig]:
|
||||
return _stage_config()
|
||||
|
||||
class _FakeToolkit:
|
||||
def get_json_schemas(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar_read",
|
||||
"description": "read",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calendar_write",
|
||||
"description": "write",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
async def call_tool_function(self, tool_call: dict[str, Any]):
|
||||
del tool_call
|
||||
if False:
|
||||
yield None
|
||||
|
||||
pipeline = _FakePipeline()
|
||||
runner = _FakeRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.orchestrator.build_stage_toolkit",
|
||||
lambda **_: _FakeToolkit(),
|
||||
lambda **_: None,
|
||||
)
|
||||
orchestrator = AgentScopeRuntimeOrchestrator(pipeline=pipeline, runner=runner)
|
||||
|
||||
await orchestrator.run(
|
||||
command=_run_command_with_binary(),
|
||||
owner_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
user_context=_user_context(),
|
||||
session=None,
|
||||
)
|
||||
|
||||
orchestrator = AgentScopeRuntimeOrchestrator(
|
||||
runner=fake_runner,
|
||||
config_loader=_fake_config_loader,
|
||||
)
|
||||
result = await orchestrator.run(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_ctx(),
|
||||
user_input="帮我安排明天会议",
|
||||
assert isinstance(runner.last_user_input, list)
|
||||
assert runner.last_user_input[0]["type"] == "text"
|
||||
assert runner.last_user_input[1]["type"] == "image_url"
|
||||
assert (
|
||||
runner.last_user_input[1]["image_url"]["url"]
|
||||
== "https://example.com/signed.png"
|
||||
)
|
||||
|
||||
assert result.intent.route == "TASK_EXECUTION"
|
||||
assert result.execution is not None
|
||||
assert result.execution.overall_status == "SUCCESS"
|
||||
assert fake_runner.execution_calls == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_emits_worker_output_on_text_end(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
pipeline = _FakePipeline()
|
||||
runner = _FakeRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.orchestrator.build_stage_toolkit",
|
||||
lambda **_: None,
|
||||
)
|
||||
orchestrator = AgentScopeRuntimeOrchestrator(pipeline=pipeline, runner=runner)
|
||||
|
||||
await orchestrator.run(
|
||||
command=_run_command_with_binary(),
|
||||
owner_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
user_context=_user_context(),
|
||||
session=None,
|
||||
)
|
||||
|
||||
emitted = [item["event"] for item in pipeline.events]
|
||||
text_end = next(item for item in emitted if item.get("type") == "text.end")
|
||||
assert text_end["data"]["workerAgentOutput"]["answer"] == "done"
|
||||
assert any(item.get("type") == "run.finished" for item in emitted)
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from schemas.messages.chat_message import AgentChatMessage
|
||||
|
||||
|
||||
def test_agent_chat_message_schema_matches_messages_columns() -> None:
|
||||
now = datetime.now(UTC)
|
||||
payload = {
|
||||
"id": uuid4(),
|
||||
"seq": 3,
|
||||
"role": "assistant",
|
||||
"content": "hello",
|
||||
"metadata": {"run_id": "run-1"},
|
||||
"timestamp": now,
|
||||
}
|
||||
|
||||
message = AgentChatMessage.model_validate(payload)
|
||||
|
||||
assert message.seq == 3
|
||||
assert message.role == "assistant"
|
||||
assert message.content == "hello"
|
||||
assert message.metadata is not None
|
||||
if isinstance(message.metadata, dict):
|
||||
assert message.metadata == {"run_id": "run-1"}
|
||||
else:
|
||||
assert message.metadata.model_dump(exclude_none=True) == {"run_id": "run-1"}
|
||||
@@ -6,7 +6,6 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.config.settings import config
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from v1.agent.repository import AgentRepository
|
||||
|
||||
@@ -36,243 +35,27 @@ class _FakeSession:
|
||||
self.flushed = True
|
||||
|
||||
|
||||
class _FakeToolResultStorage:
|
||||
def __init__(self, payload: dict[str, object] | None) -> None:
|
||||
self._payload = payload
|
||||
|
||||
async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
|
||||
del bucket, path
|
||||
return self._payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_hydrates_content_from_object_storage() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"toolName": "front.navigate_to_route",
|
||||
"result": {"ok": True, "applied": True, "content": "已跳转"},
|
||||
}
|
||||
),
|
||||
)
|
||||
async def test_snapshot_message_returns_raw_db_columns() -> None:
|
||||
repository = AgentRepository(session=SimpleNamespace()) # type: ignore[arg-type]
|
||||
now = datetime.now(timezone.utc)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
seq=7,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content='{"offloaded":true}',
|
||||
metadata_json={
|
||||
"tool_call_id": "call-1",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-1.json",
|
||||
},
|
||||
metadata_json={"tool_call_id": "call-1"},
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-1"
|
||||
assert payload["content"] == "已跳转"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_hydrates_ui_from_ui_schema_field() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"toolName": "calendar_write",
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True, "operation": "create"},
|
||||
"actions": [],
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="已创建日程:项目评审(明天 10:00)",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-3",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-3.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-3"
|
||||
assert payload["content"] == "已创建日程:项目评审(明天 10:00)"
|
||||
ui = payload.get("ui")
|
||||
assert isinstance(ui, dict)
|
||||
assert ui["type"] == "calendar_operation.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(None),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="inline-tool-content",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-2",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/run-1/call-2.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-2"
|
||||
assert payload["content"] == "inline-tool-content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_skips_storage_when_path_not_matching_session() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="summary",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-x",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/foreign-session/call-y.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "summary"
|
||||
assert "ui" not in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_rejects_path_traversal() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="summary",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-z",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/ok/../../evil/call-z.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "summary"
|
||||
assert "ui" not in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_supports_legacy_storage_path() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"ui_schema": {
|
||||
"type": "calendar_operation.v1",
|
||||
"version": "v1",
|
||||
"data": {"ok": True},
|
||||
"actions": [],
|
||||
},
|
||||
"content": "legacy content",
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content='{"offloaded":true}',
|
||||
metadata_json={
|
||||
"tool_call_id": "call-legacy",
|
||||
"storage_bucket": config.storage.bucket,
|
||||
"storage_path": "tool-results/old-run/call-legacy.json",
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["content"] == "legacy content"
|
||||
ui = payload.get("ui")
|
||||
assert isinstance(ui, dict)
|
||||
assert ui["type"] == "calendar_operation.v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_message_snapshot_includes_renderable_attachments() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
session_id=uuid4(),
|
||||
role=AgentChatMessageRole.USER,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="请分析这张图",
|
||||
metadata_json={
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-chat-attachments",
|
||||
"path": "agent-inputs/u1/t1/r1/m1/att-1.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["role"] == "user"
|
||||
assert payload["content"] == "请分析这张图"
|
||||
attachments = payload.get("attachments")
|
||||
assert isinstance(attachments, list)
|
||||
assert len(attachments) == 1
|
||||
first = attachments[0]
|
||||
assert isinstance(first, dict)
|
||||
assert first["mimeType"] == "image/png"
|
||||
assert isinstance(first.get("previewPath"), str)
|
||||
assert payload["seq"] == 7
|
||||
assert payload["role"] == "tool"
|
||||
assert payload["content"] == '{"offloaded":true}'
|
||||
assert payload["metadata"] == {"tool_call_id": "call-1"}
|
||||
assert "timestamp" in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -318,32 +101,3 @@ async def test_persist_user_message_keeps_existing_session_title() -> None:
|
||||
|
||||
assert session_row.title == "已有标题"
|
||||
assert session_row.message_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_message_attachment_reference_returns_item() -> None:
|
||||
session_id = str(uuid4())
|
||||
message_id = str(uuid4())
|
||||
message = SimpleNamespace(
|
||||
metadata_json={
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "bucket-test",
|
||||
"path": "agent-inputs/u/t/r/a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
fake_session = _FakeSession(message)
|
||||
repository = AgentRepository(session=fake_session) # type: ignore[arg-type]
|
||||
|
||||
ref = await repository.get_message_attachment_reference(
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
attachment_index=0,
|
||||
)
|
||||
|
||||
assert ref is not None
|
||||
assert ref["bucket"] == "bucket-test"
|
||||
assert ref["mimeType"] == "image/png"
|
||||
|
||||
@@ -12,48 +12,6 @@ from core.auth.models import CurrentUser
|
||||
from v1.agent import router as agent_router
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_run_request_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._allow_run_request(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_sse_slot_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._acquire_sse_slot(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_transcribe_request_fails_closed_when_redis_unavailable(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _raise_redis_error():
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
monkeypatch.setattr(agent_router, "get_or_init_redis_client", _raise_redis_error)
|
||||
|
||||
allowed = await agent_router._allow_transcribe_request(user_id="user-1")
|
||||
|
||||
assert allowed is False
|
||||
|
||||
|
||||
def _resume_input_with_tool_message() -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
@@ -82,13 +40,7 @@ async def test_enqueue_resume_rejects_without_tool_contract() -> None:
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-resume-invalid",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": "continue",
|
||||
}
|
||||
],
|
||||
"messages": [{"id": "u1", "role": "user", "content": "continue"}],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
@@ -109,10 +61,6 @@ async def test_enqueue_resume_rejects_without_tool_contract() -> None:
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert (
|
||||
exc_info.value.detail
|
||||
== "RunAgentInput.messages requires a tool message with toolCallId for resume"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -141,7 +89,6 @@ async def test_enqueue_resume_rejects_when_rate_limited(
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
assert exc_info.value.detail == "Too many run requests"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -173,96 +120,4 @@ async def test_enqueue_resume_accepts_valid_tool_contract(
|
||||
)
|
||||
|
||||
assert result.task_id == "task-resume-1"
|
||||
assert result.thread_id == "00000000-0000-0000-0000-000000000001"
|
||||
assert result.run_id == "run-resume-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_retries_on_redis_timeout(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def _acquire(*, user_id: str) -> bool:
|
||||
del user_id
|
||||
return True
|
||||
|
||||
async def _release(*, user_id: str) -> None:
|
||||
del user_id
|
||||
|
||||
monkeypatch.setattr(agent_router, "_acquire_sse_slot", _acquire)
|
||||
monkeypatch.setattr(agent_router, "_release_sse_slot", _release)
|
||||
|
||||
class _Request:
|
||||
async def is_disconnected(self) -> bool:
|
||||
return False
|
||||
|
||||
class _Service:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def stream_events(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
raise RuntimeError("Timeout reading from localhost:6379")
|
||||
if self.calls == 2:
|
||||
return [{"id": "1-0", "event": {"type": "RUN_FINISHED"}}]
|
||||
return []
|
||||
|
||||
response = await agent_router.stream_events(
|
||||
request=cast(Any, _Request()),
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
last_event_id=None,
|
||||
idle_limit=2,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(str(chunk))
|
||||
if any("RUN_FINISHED" in item for item in chunks):
|
||||
break
|
||||
|
||||
merged = "".join(chunks)
|
||||
assert "event: RUN_FINISHED" in merged
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_attachment_preview_rejects_negative_index() -> None:
|
||||
class _Service:
|
||||
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
raise AssertionError("get_attachment_preview should not be called")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await agent_router.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=-1,
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_attachment_preview_returns_streaming_response() -> None:
|
||||
class _Service:
|
||||
async def get_attachment_preview(self, **kwargs): # noqa: ANN003
|
||||
del kwargs
|
||||
return b"png-bytes", "image/png"
|
||||
|
||||
response = await agent_router.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
service=cast(Any, _Service()),
|
||||
current_user=CurrentUser(id=uuid4(), email="user@example.com"),
|
||||
)
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(cast(bytes, chunk))
|
||||
|
||||
assert response.media_type == "image/png"
|
||||
assert b"".join(chunks) == b"png-bytes"
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from fastapi import HTTPException
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
import v1.agent.service as agent_service_module
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
import v1.agent.service as agent_service_module
|
||||
from v1.agent.service import AgentService, AsrService
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self) -> None:
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
self.deleted_session_id: str | None = None
|
||||
self.created_with_session_id: str | None = None
|
||||
self.persisted_user_messages: list[dict[str, object]] = []
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
@@ -31,33 +28,23 @@ class _FakeRepository:
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id
|
||||
self.created_with_session_id = session_id
|
||||
return session_id or "00000000-0000-0000-0000-000000000999"
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
self.deleted_session_id = session_id
|
||||
return None
|
||||
|
||||
async def get_history_day(
|
||||
self, *, session_id: str, before: date | None
|
||||
) -> dict[str, object] | None:
|
||||
del session_id
|
||||
if before is not None and before <= date(2026, 3, 6):
|
||||
return None
|
||||
return {
|
||||
"day": "2026-03-06",
|
||||
"hasMore": False,
|
||||
"messages": [{"id": "m1", "role": "assistant", "content": "hello"}],
|
||||
}
|
||||
del session_id, before
|
||||
return None
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
return None
|
||||
|
||||
async def persist_user_message(
|
||||
self,
|
||||
@@ -76,22 +63,6 @@ class _FakeRepository:
|
||||
}
|
||||
)
|
||||
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None:
|
||||
del session_id, message_id
|
||||
if attachment_index != 0:
|
||||
return None
|
||||
return {
|
||||
"bucket": config.storage.bucket,
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/run-1/attachment-0-a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
def __init__(self) -> None:
|
||||
@@ -100,33 +71,20 @@ class _FakeQueue:
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
self.commands.append(command)
|
||||
del dedup_key
|
||||
self.commands.append(command)
|
||||
return "task-1"
|
||||
|
||||
|
||||
class _FailingQueue:
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
del command, dedup_key
|
||||
raise RuntimeError("enqueue failed")
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
async def read(
|
||||
self, *, session_id: str, last_event_id: str | None
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
]
|
||||
del session_id, last_event_id
|
||||
return []
|
||||
|
||||
|
||||
class _FakeAttachmentStorage:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, object]] = []
|
||||
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
@@ -135,65 +93,12 @@ class _FakeAttachmentStorage:
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"content": content,
|
||||
"content_type": content_type,
|
||||
}
|
||||
)
|
||||
del bucket, content, content_type
|
||||
return path
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"download": True,
|
||||
}
|
||||
)
|
||||
return b"png-bytes"
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
self.calls.append(
|
||||
{
|
||||
"bucket": bucket,
|
||||
"path": path,
|
||||
"signed": True,
|
||||
"expires_in_seconds": expires_in_seconds,
|
||||
}
|
||||
)
|
||||
return f"https://signed.example/{path}?exp={expires_in_seconds}"
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]:
|
||||
if url.startswith("https://signed.example/"):
|
||||
path = url.replace("https://signed.example/", "").split("?")[0]
|
||||
return "agent-test-bucket", path
|
||||
raise RuntimeError("Invalid signed URL")
|
||||
|
||||
|
||||
class _AlwaysFailAttachmentStorage:
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str:
|
||||
del bucket, path, content, content_type
|
||||
raise RuntimeError("upload failed")
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes:
|
||||
del bucket, path
|
||||
raise RuntimeError("download failed")
|
||||
return b""
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
@@ -202,12 +107,16 @@ class _AlwaysFailAttachmentStorage:
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str:
|
||||
del bucket, path, expires_in_seconds
|
||||
raise RuntimeError("sign failed")
|
||||
del expires_in_seconds
|
||||
return f"https://signed.example/{bucket}/{path}"
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]:
|
||||
del url
|
||||
raise RuntimeError("parse failed")
|
||||
parsed = url.split("/storage/v1/object/sign/")
|
||||
if len(parsed) != 2:
|
||||
raise RuntimeError("invalid")
|
||||
bucket, path = parsed[1].split("/", 1)
|
||||
path = path.split("?", 1)[0]
|
||||
return bucket, path
|
||||
|
||||
|
||||
def _user() -> CurrentUser:
|
||||
@@ -217,13 +126,22 @@ def _user() -> CurrentUser:
|
||||
)
|
||||
|
||||
|
||||
def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
|
||||
def _build_run_input(*, url: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"state": {},
|
||||
"messages": [{"id": "u1", "role": "user", "content": "hello"}],
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "binary", "mimeType": "image/png", "url": url},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
@@ -231,454 +149,69 @@ def _build_run_input(*, thread_id: str, run_id: str) -> RunAgentInput:
|
||||
)
|
||||
|
||||
|
||||
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_rejects_non_project_host_signed_url(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
user = _user()
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
first = await service.enqueue_resume(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
second = await service.enqueue_resume(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_input=run_input,
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
assert first.task_id == second.task_id
|
||||
|
||||
|
||||
async def test_enqueue_run_creates_missing_thread_session() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.thread_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.run_id == "run-1"
|
||||
assert accepted.created is True
|
||||
assert repository.created_with_session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert repository.committed is True
|
||||
assert queue.commands[0]["user_token"] is None
|
||||
|
||||
|
||||
async def test_enqueue_run_uses_explicit_user_token() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
user_token="Bearer access-token-1",
|
||||
)
|
||||
|
||||
assert queue.commands
|
||||
assert queue.commands[0]["user_token"] == "access-token-1"
|
||||
|
||||
|
||||
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FailingQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
raise AssertionError("expected RuntimeError")
|
||||
except RuntimeError as exc:
|
||||
assert str(exc) == "enqueue failed"
|
||||
|
||||
assert repository.deleted_session_id is None
|
||||
|
||||
|
||||
async def test_enqueue_run_handles_session_create_race() -> None:
|
||||
class _RaceRepository(_FakeRepository):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.create_calls = 0
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
if self.create_calls == 0:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
del user_id, session_id
|
||||
self.create_calls += 1
|
||||
raise IntegrityError("insert", {}, Exception("duplicate key"))
|
||||
|
||||
repository = _RaceRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
thread_id="00000000-0000-0000-0000-000000000999",
|
||||
run_id="run-1",
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
run_input=run_input,
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.created is False
|
||||
assert repository.rolled_back is True
|
||||
|
||||
|
||||
async def test_enqueue_run_parses_signed_url_and_injects_metadata(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/agent-inputs/u/t/r/file.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
assert repository.persisted_user_messages
|
||||
persisted = repository.persisted_user_messages[0]
|
||||
assert persisted["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
assert persisted["run_id"] == "run-with-image"
|
||||
metadata = persisted["metadata"]
|
||||
assert isinstance(metadata, dict)
|
||||
attachments = metadata.get("user_message_attachments")
|
||||
assert isinstance(attachments, dict)
|
||||
assert attachments["bucket"] == "agent-test-bucket"
|
||||
assert attachments["path"] == "agent-inputs/u/t/r/file.png"
|
||||
assert attachments["mime_type"] == "image/png"
|
||||
|
||||
|
||||
async def test_enqueue_run_with_invalid_signed_url_still_succeeds(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-invalid-url",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "帮我看下这张图"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "invalid-url-format",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
assert repository.persisted_user_messages
|
||||
persisted = repository.persisted_user_messages[0]
|
||||
metadata = persisted["metadata"]
|
||||
assert metadata is None
|
||||
|
||||
|
||||
async def test_enqueue_run_rejects_unsupported_attachment_type(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-bad-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请看附件"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/gif",
|
||||
"url": "https://signed.example/upload.gif",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.gif",
|
||||
"mimeType": "image/gif",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
url="https://evil.example.com/storage/v1/object/sign/agent-test-bucket/a.png?token=1"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
assert exc_info.value.detail == "Unsupported attachment type"
|
||||
assert attachment_storage.calls == []
|
||||
assert exc_info.value.detail == "INVALID_BINARY_URL_HOST"
|
||||
|
||||
|
||||
async def test_enqueue_run_rejects_attachment_too_large(
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_run_persists_attachment_and_queue_without_user_token(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(agent_service_module, "_MAX_ATTACHMENT_BYTES", 4)
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
repository = _FakeRepository()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-big-image",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请看附件"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": "agent-test-bucket",
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert exc_info.value.status_code == 413
|
||||
assert exc_info.value.detail == "Attachment too large"
|
||||
assert len(attachment_storage.calls) == 1
|
||||
assert attachment_storage.calls[0]["download"] is True
|
||||
|
||||
|
||||
async def test_enqueue_run_accepts_binary_url_and_persists_metadata() -> None:
|
||||
repository = _FakeRepository()
|
||||
queue = _FakeQueue()
|
||||
attachment_storage = _FakeAttachmentStorage()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=queue,
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=attachment_storage,
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-with-binary-url",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{
|
||||
"type": "binary",
|
||||
"mimeType": "image/png",
|
||||
"url": "https://signed.example/upload-1.png",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"attachments": [
|
||||
{
|
||||
"bucket": config.storage.bucket,
|
||||
"path": "agent-inputs/00000000-0000-0000-0000-000000000001/00000000-0000-0000-0000-000000000001/upload-1.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
base_url = str(config.supabase.url).rstrip("/")
|
||||
safe_path = quote(
|
||||
"agent-inputs/00000000-0000-0000-0000-000000000001/"
|
||||
"00000000-0000-0000-0000-000000000001/uploads/a.png"
|
||||
)
|
||||
run_input = _build_run_input(
|
||||
url=f"{base_url}/storage/v1/object/sign/agent-test-bucket/{safe_path}?token=1"
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(run_input=run_input, current_user=_user())
|
||||
|
||||
assert accepted.task_id == "task-1"
|
||||
persisted = repository.persisted_user_messages[-1]
|
||||
persisted = repository.persisted_user_messages[0]
|
||||
metadata = persisted["metadata"]
|
||||
assert isinstance(metadata, dict)
|
||||
attachments = metadata.get("attachments")
|
||||
assert isinstance(attachments, list)
|
||||
assert attachments[0]["path"].endswith("upload-1.png")
|
||||
queue_input = queue.commands[-1]["run_input"]
|
||||
assert isinstance(queue_input, dict)
|
||||
content = queue_input["messages"][0]["content"]
|
||||
assert isinstance(content, list)
|
||||
assert content[1]["type"] == "binary"
|
||||
assert content[1]["url"] == "https://signed.example/upload-1.png"
|
||||
attachment = metadata["user_message_attachments"]
|
||||
assert attachment["bucket"] == "agent-test-bucket"
|
||||
command = queue.commands[0]
|
||||
assert "user_token" not in command
|
||||
|
||||
|
||||
async def test_get_history_snapshot_wraps_history_day_as_state_snapshot_event() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_attachment_signed_url_returns_url(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
|
||||
event = await service.get_history_snapshot(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
before=date(2026, 3, 7),
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
snapshot = event["snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["scope"] == "history_day"
|
||||
assert snapshot["day"] == "2026-03-06"
|
||||
assert snapshot["messages"][0]["id"] == "m1"
|
||||
|
||||
|
||||
async def test_get_user_history_snapshot_uses_latest_thread_when_absent() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
event = await service.get_user_history_snapshot(
|
||||
current_user=_user(),
|
||||
thread_id=None,
|
||||
before=None,
|
||||
)
|
||||
assert event["type"] == "STATE_SNAPSHOT"
|
||||
assert event["threadId"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
async def test_get_attachment_preview_returns_payload_and_mime() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
@@ -686,120 +219,36 @@ async def test_get_attachment_preview_returns_payload_and_mime() -> None:
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
|
||||
payload, mime_type = await service.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
payload = await service.create_attachment_signed_url(
|
||||
bucket="agent-test-bucket",
|
||||
path="agent-inputs/00000000-0000-0000-0000-000000000001/thread-x/uploads/a.png",
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert payload == b"png-bytes"
|
||||
assert mime_type == "image/png"
|
||||
assert payload["bucket"] == "agent-test-bucket"
|
||||
assert payload["path"].endswith("/a.png")
|
||||
assert payload["url"].startswith("https://signed.example/")
|
||||
|
||||
|
||||
async def test_get_attachment_preview_rejects_invalid_path() -> None:
|
||||
class _BadPathRepository(_FakeRepository):
|
||||
async def get_message_attachment_reference(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
attachment_index: int,
|
||||
) -> dict[str, str] | None:
|
||||
del session_id, message_id, attachment_index
|
||||
return {
|
||||
"bucket": "bucket-test",
|
||||
"path": "agent-inputs/other-user/other-thread/run-1/a.png",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_attachment_signed_url_rejects_out_of_scope_path(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
agent_service_module.config.storage, "bucket", "agent-test-bucket"
|
||||
)
|
||||
service = AgentService(
|
||||
repository=_BadPathRepository(),
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
attachment_storage=_FakeAttachmentStorage(),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.get_attachment_preview(
|
||||
thread_id="00000000-0000-0000-0000-000000000001",
|
||||
message_id="00000000-0000-0000-0000-000000000010",
|
||||
attachment_index=0,
|
||||
await service.create_attachment_signed_url(
|
||||
bucket="agent-test-bucket",
|
||||
path="agent-inputs/other-user/thread-x/uploads/a.png",
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
async def test_asr_service_parses_dict_output_sentence(monkeypatch) -> None:
|
||||
result = SimpleNamespace(
|
||||
status_code=200,
|
||||
message="ok",
|
||||
output={"sentence": {"text": "你好,世界"}},
|
||||
request_id="req-test",
|
||||
)
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == "你好,世界"
|
||||
|
||||
|
||||
async def test_asr_service_parses_sentence_when_result_is_dict(monkeypatch) -> None:
|
||||
result = {
|
||||
"status_code": 200,
|
||||
"message": "ok",
|
||||
"output": {"sentence": {"text": "字典结果"}},
|
||||
"request_id": "req-dict",
|
||||
}
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == "字典结果"
|
||||
|
||||
|
||||
async def test_asr_service_returns_empty_when_sentence_missing(monkeypatch) -> None:
|
||||
result = {
|
||||
"status_code": 200,
|
||||
"message": "ok",
|
||||
"output": {},
|
||||
}
|
||||
|
||||
class _FakeRecognition:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
def call(self, *, file: str):
|
||||
del file
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(agent_service_module, "Recognition", _FakeRecognition)
|
||||
monkeypatch.setattr(AsrService, "_get_api_key", lambda self: "test-key")
|
||||
service = AsrService()
|
||||
|
||||
transcript = await service.transcribe_file("/tmp/test.wav", "test.wav")
|
||||
|
||||
assert transcript == ""
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
Reference in New Issue
Block a user