feat(agent): 增强多模态链路与工具调用能力
This commit is contained in:
@@ -99,6 +99,16 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_START",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"messageId": "assistant-run-1",
|
||||
"role": "assistant",
|
||||
"stage": "report",
|
||||
}
|
||||
)
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
@@ -128,6 +138,8 @@ async def test_store_persists_assistant_message_and_aggregates(
|
||||
assert append_kwargs["output_tokens"] == 5
|
||||
assert append_kwargs["cost"] == Decimal("0.123")
|
||||
assert append_kwargs["metadata"]["latency_ms"] == 250
|
||||
assert append_kwargs["metadata"]["stage"] == "report"
|
||||
assert append_kwargs["latency_ms"] == 250
|
||||
assert captured["message_delta"] == 1
|
||||
assert captured["token_delta"] == 8
|
||||
assert captured["cost_delta"] == Decimal("0.123")
|
||||
@@ -255,6 +267,60 @@ async def test_store_clears_buffer_on_run_finished(
|
||||
assert "append_kwargs" not in captured
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_persists_tool_call_result_as_tool_message(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
fake_chat_session = SimpleNamespace(state_snapshot={}, message_count=2)
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def get_session(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id): # noqa: ANN001
|
||||
del session_id
|
||||
return fake_chat_session
|
||||
|
||||
async def update_runtime_state(self, **kwargs): # noqa: ANN003
|
||||
captured.update(kwargs)
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs): # noqa: ANN003
|
||||
captured["append_kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(store_module, "SessionRepository", _FakeSessionRepository)
|
||||
monkeypatch.setattr(store_module, "MessageRepository", _FakeMessageRepository)
|
||||
monkeypatch.setattr(store_module, "AgentChatSessionStatus", _SessionStatus)
|
||||
|
||||
store = store_module.SqlAlchemyEventStore(session_factory=lambda: _FakeSessionCtx())
|
||||
await store.persist(
|
||||
{
|
||||
"type": "TOOL_CALL_RESULT",
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-1",
|
||||
"toolName": "calendar_write",
|
||||
"taskId": "t1",
|
||||
"stage": "execution",
|
||||
"args": {"title": "A"},
|
||||
"result": {"event_id": "evt-1"},
|
||||
}
|
||||
)
|
||||
|
||||
append_kwargs = cast(dict[str, Any], captured["append_kwargs"])
|
||||
assert getattr(append_kwargs["role"], "value", None) == "tool"
|
||||
assert append_kwargs["tool_name"] == "calendar_write"
|
||||
assert append_kwargs["metadata"]["task_id"] == "t1"
|
||||
assert captured["message_delta"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_drops_buffer_when_session_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
Reference in New Issue
Block a user