feat(agent): complete task4-6 tool result persistence flow
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool import (
|
||||
_execute_create_calendar_event,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_calendar_event_tool_returns_ui_schema_v1_top_level(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
event_id = uuid4()
|
||||
created = SimpleNamespace(
|
||||
id=event_id,
|
||||
title="晨会",
|
||||
description="同步计划",
|
||||
start_at=datetime(2026, 3, 8, 1, 0, tzinfo=timezone.utc),
|
||||
end_at=datetime(2026, 3, 8, 2, 0, tzinfo=timezone.utc),
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
class _FakeService:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
async def create_agent_generated(self, payload):
|
||||
del payload
|
||||
return created
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self, session) -> None:
|
||||
del session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool.ScheduleItemService",
|
||||
_FakeService,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool.SQLAlchemyScheduleItemRepository",
|
||||
_FakeRepository,
|
||||
)
|
||||
|
||||
result = cast(
|
||||
dict[str, object],
|
||||
await _execute_create_calendar_event(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
tool_args={"title": "晨会"},
|
||||
),
|
||||
)
|
||||
|
||||
assert result["type"] == "calendar_card.v1"
|
||||
assert result["version"] == "v1"
|
||||
data = cast(dict[str, object], result["data"])
|
||||
actions = cast(list[dict[str, object]], result["actions"])
|
||||
assert data["id"] == str(event_id)
|
||||
assert actions
|
||||
@@ -116,7 +116,7 @@ def test_runtime_needs_execution_and_collects_front_tool_call() -> None:
|
||||
tools = item["tools"]
|
||||
assert isinstance(tools, list)
|
||||
assert any(t.get("name") == "front.navigate_to_route" for t in tools)
|
||||
execution_tools = calls[1]["tools"]
|
||||
execution_tools = cast(list[dict[str, object]], calls[1]["tools"])
|
||||
assert any(t.get("name") == "back.create_calendar_event" for t in execution_tools)
|
||||
assert result["assistant_text"] == "do it"
|
||||
assert result["pending_front_tool"] == {
|
||||
@@ -131,3 +131,51 @@ def test_runtime_backend_registry_check() -> None:
|
||||
runtime = _build_runtime()
|
||||
assert runtime.is_registered_backend_tool("back.create_calendar_event") is True
|
||||
assert runtime.is_registered_backend_tool("back.unknown") is False
|
||||
|
||||
|
||||
def test_runtime_emits_step_started_finished_for_all_three_stages() -> None:
|
||||
runtime = _build_runtime()
|
||||
|
||||
def _fake_run_stage(self, **kwargs):
|
||||
stage = kwargs["stage"]
|
||||
if stage == "intent":
|
||||
return (
|
||||
'{"route":"NEEDS_EXECUTION","intent_summary":"need tool","execution_brief":"do it","safety_flags":[]}',
|
||||
UsageCost(1, 1, 2, 0.01),
|
||||
[],
|
||||
None,
|
||||
)
|
||||
if stage == "execution":
|
||||
return (
|
||||
'{"status":"SUCCESS","execution_summary":"done","execution_data":{},"report_brief":"ok"}',
|
||||
UsageCost(2, 2, 4, 0.02),
|
||||
[],
|
||||
None,
|
||||
)
|
||||
return (
|
||||
'{"assistant_text":"final answer","response_metadata":{"source":"organization"}}',
|
||||
UsageCost(3, 3, 6, 0.03),
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
runtime._run_stage_with_crewai = MethodType(_fake_run_stage, runtime) # type: ignore[method-assign]
|
||||
result = runtime.execute(user_input="go", tools=[])
|
||||
|
||||
agui_events = cast(list[dict[str, object]], result["agui_events"])
|
||||
step_events = [
|
||||
event
|
||||
for event in agui_events
|
||||
if event.get("type") in {"STEP_STARTED", "STEP_FINISHED"}
|
||||
]
|
||||
assert len(step_events) == 6
|
||||
assert [
|
||||
cast(dict[str, object], event["data"])["stage"] for event in step_events
|
||||
] == [
|
||||
"intent",
|
||||
"intent",
|
||||
"execution",
|
||||
"execution",
|
||||
"organization",
|
||||
"organization",
|
||||
]
|
||||
|
||||
@@ -406,6 +406,114 @@ async def test_resume_service_rejects_tool_result_when_not_ok(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_offloads_large_tool_result_payload_to_object_storage(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: list[dict[str, object]] = []
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "front.navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object) -> int:
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.append(kwargs)
|
||||
|
||||
class _FakeStorage:
|
||||
async def upload_json(
|
||||
self, *, bucket: str, path: str, payload: dict[str, object]
|
||||
) -> str:
|
||||
del bucket, path, payload
|
||||
return "etag-1"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.resume_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
|
||||
service = ResumeService( # type: ignore[call-arg]
|
||||
session_factory=_FakeSessionFactory(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeStorage(),
|
||||
tool_result_offload_threshold_bytes=1,
|
||||
tool_result_bucket="private",
|
||||
tool_result_prefix="tool-results",
|
||||
)
|
||||
await service.resume(
|
||||
run_input=_build_resume_input(
|
||||
thread_id=str(session_id),
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "front.navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": True, "payload": "x" * 4096},
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
metadata = captured[0]["metadata"]
|
||||
assert isinstance(metadata, dict)
|
||||
assert metadata["storage_bucket"] == "private"
|
||||
assert metadata["storage_path"].startswith("tool-results/")
|
||||
assert isinstance(metadata["payload_sha256"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_agent_model_selection_returns_validated_llm_config() -> None:
|
||||
run_service = RunService()
|
||||
|
||||
@@ -1,43 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import HTTPException
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from v1.agent.repository import AgentRepository
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
class _FakeToolResultStorage:
|
||||
def __init__(self, payload: dict[str, object] | None) -> None:
|
||||
self._payload = payload
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def refresh(self, _obj: object) -> None:
|
||||
return None
|
||||
async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
|
||||
del bucket, path
|
||||
return self._payload
|
||||
|
||||
|
||||
async def test_create_session_for_user_creates_session_row() -> None:
|
||||
session = _FakeSession()
|
||||
repository = AgentRepository(session=session) # type: ignore[arg-type]
|
||||
|
||||
await repository.create_session_for_user(
|
||||
user_id="00000000-0000-0000-0000-000000000001"
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_hydrates_content_from_object_storage() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(
|
||||
{
|
||||
"toolName": "front.navigate_to_route",
|
||||
"result": {"ok": True, "applied": True, "content": "已跳转"},
|
||||
}
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content='{"offloaded":true}',
|
||||
metadata_json={
|
||||
"tool_call_id": "call-1",
|
||||
"storage_bucket": "private",
|
||||
"storage_path": "tool-results/run-1/call-1.json",
|
||||
},
|
||||
)
|
||||
|
||||
session_row = session.added[0]
|
||||
assert str(getattr(session_row, "user_id")) == "00000000-0000-0000-0000-000000000001"
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-1"
|
||||
assert payload["content"] == "已跳转"
|
||||
|
||||
|
||||
async def test_create_session_for_user_rejects_invalid_uuid() -> None:
|
||||
session = _FakeSession()
|
||||
repository = AgentRepository(session=session) # type: ignore[arg-type]
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
|
||||
repository = AgentRepository(
|
||||
session=SimpleNamespace(), # type: ignore[arg-type]
|
||||
tool_result_storage=_FakeToolResultStorage(None),
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
id=uuid4(),
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
content="inline-tool-content",
|
||||
metadata_json={
|
||||
"tool_call_id": "call-2",
|
||||
"storage_bucket": "private",
|
||||
"storage_path": "tool-results/run-1/call-2.json",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
await repository.create_session_for_user(user_id="invalid-uuid")
|
||||
raise AssertionError("expected invalid user_id")
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 422
|
||||
assert exc.detail == "Invalid user_id"
|
||||
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
|
||||
|
||||
assert payload["toolCallId"] == "call-2"
|
||||
assert payload["content"] == "inline-tool-content"
|
||||
|
||||
Reference in New Issue
Block a user