feat(agent): complete task4-6 tool result persistence flow

This commit is contained in:
zl-q
2026-03-08 17:07:09 +08:00
parent 5ada60e834
commit daa1c86d02
15 changed files with 903 additions and 92 deletions
@@ -12,6 +12,10 @@ from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.agent.infrastructure.queue.tasks import run_agent_task
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from services.base.supabase import supabase_service
from core.db import AsyncSessionLocal, engine
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
@@ -242,6 +246,299 @@ async def test_run_then_resume_persists_messages_and_session_state(
await cleanup_session.commit()
@pytest.mark.asyncio
async def test_resume_tool_result_offloads_to_supabase_storage_for_calendar_tool(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Verify a large front-tool result is offloaded to Supabase Storage on resume.

    ``CrewAIRuntime.execute`` is patched so that its first call asks for the
    ``front.create_calendar_event`` front tool and every later call returns a
    final answer. A ``run`` command is followed by a ``resume`` command whose
    tool message carries a result with a 9000-character description. The test
    then checks that the newest persisted TOOL message has metadata pointing
    at an object under ``tool-results/`` in the ``private`` bucket, and that
    the JSON downloaded from that path contains the original tool name and
    result fields.

    The test is skipped when the Supabase service cannot be initialised or
    when a probe upload to the ``private`` bucket fails.

    Everything the test creates (auth user, seeded rows, uploaded object) is
    removed in a ``finally`` block. Only rows created by this test are
    deleted; a pre-existing ``dashscope`` factory is reused but never removed.
    """
    call_count = {"n": 0}
    event_args = {
        "title": "测试日程",
        "start": "2026-03-09T09:00:00+08:00",
        "end": "2026-03-09T10:00:00+08:00",
    }

    def _fake_execute(
        self,
        *,
        user_input: str,
        system_prompt: str | None = None,
        tools: list[dict[str, object]] | None = None,
    ) -> dict[str, object]:
        """Scripted runtime: first call requests the front tool, later calls finish."""
        del self, user_input, system_prompt, tools
        call_count["n"] += 1
        if call_count["n"] == 1:
            return {
                "assistant_text": "我来创建日历事件,请稍候确认。",
                "prompt_tokens": 10,
                "completion_tokens": 6,
                "total_tokens": 16,
                "cost": 0.002,
                "pending_front_tool": {
                    "name": "front.create_calendar_event",
                    # Copy so callers mutating the args cannot affect the
                    # payload sent back on resume.
                    "args": dict(event_args),
                    "target": "frontend",
                },
                "agui_events": [],
            }
        return {
            "assistant_text": "日历已创建。",
            "prompt_tokens": 2,
            "completion_tokens": 2,
            "total_tokens": 4,
            "cost": 0.001,
            "pending_front_tool": None,
            "agui_events": [],
        }

    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
        _fake_execute,
    )

    test_user_email = f"agent-it-{uuid.uuid4().hex[:8]}@example.com"
    session_uuid = uuid.uuid4()
    agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"

    # Cleanup bookkeeping. Each value stays ``None`` (or a random id that
    # matches no row) until the corresponding resource really exists, so the
    # ``finally`` block never touches data this test did not create.
    test_user_id: str | None = None
    owner_id = uuid.uuid4()
    created_factory_id: uuid.UUID | None = None
    created_llm_id: uuid.UUID | None = None
    uploaded_path: str | None = None

    initialized = await supabase_service.initialize()
    if not initialized:
        pytest.skip("Supabase service is unavailable")

    try:
        admin_client = supabase_service.get_admin_client()
        tool_result_storage = create_tool_result_storage()
        assert tool_result_storage is not None

        def _build_resume_service() -> ResumeService:
            """Return a fresh ResumeService wired to the real storage backend."""
            return ResumeService(
                tool_result_storage=tool_result_storage,
                tool_result_bucket="private",
                tool_result_prefix="tool-results",
            )

        created_user = admin_client.auth.admin.create_user(
            {
                "email": test_user_email,
                "password": "Passw0rd!123",
                "email_confirm": True,
                "user_metadata": {"source": "integration-test"},
            }
        )
        test_user_id = str(created_user.user.id)
        owner_id = uuid.UUID(test_user_id)

        # NOTE(review): presumably resets pooled connections left over from
        # an earlier test/event loop — confirm.
        await engine.dispose()
        async with AsyncSessionLocal() as lookup_session:
            llm_row = await lookup_session.execute(select(Llm.id).limit(1))
            llm_id = llm_row.scalar_one_or_none()

        if llm_id is None:
            # No LLM row exists at all: seed one, reusing the "dashscope"
            # factory when it is already present.
            async with AsyncSessionLocal() as seed_session:
                factory_row = await seed_session.execute(
                    select(LlmFactory.id).where(LlmFactory.name == "dashscope").limit(1)
                )
                factory_id = factory_row.scalar_one_or_none()
                if factory_id is None:
                    factory_id = uuid.uuid4()
                    seed_session.add(
                        LlmFactory(
                            id=factory_id,
                            name="dashscope",
                            request_url="https://dashscope.example",
                        )
                    )
                    await seed_session.commit()
                    created_factory_id = factory_id

            async with AsyncSessionLocal() as seed_session:
                llm_id = uuid.uuid4()
                seed_session.add(
                    Llm(
                        id=llm_id,
                        factory_id=factory_id,
                        model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
                    )
                )
                await seed_session.commit()
                created_llm_id = llm_id

        storage = admin_client.storage
        try:
            storage.get_bucket("private")
        except Exception:
            # Any lookup failure is treated as "bucket missing".
            storage.create_bucket("private", "private", {"public": False})

        # Probe the upload API first so an unusable Storage backend results
        # in a skip instead of a failure deep inside the resume flow.
        probe_path = f"tool-results/probe/{uuid.uuid4().hex}.json"
        try:
            storage.from_("private").upload(probe_path, b"{}")
            storage.from_("private").remove([probe_path])
        except Exception:
            pytest.skip(
                "Supabase Storage upload API unavailable in current environment"
            )

        async with AsyncSessionLocal() as seed_session:
            existing_profile = await seed_session.get(Profile, owner_id)
            if existing_profile is None:
                seed_session.add(
                    Profile(
                        id=owner_id,
                        username=f"it_{uuid.uuid4().hex[:8]}",
                    )
                )
            seed_session.add(
                SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
            )
            seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
            await seed_session.commit()

        run_result = await run_agent_task(
            {
                "command": "run",
                "run_input": {
                    "threadId": str(session_uuid),
                    "runId": "run-storage-1",
                    "state": {},
                    "messages": [
                        {
                            "id": "u1",
                            "role": "user",
                            "content": "帮我创建明天9点到10点的日历",
                        }
                    ],
                    "tools": [
                        {
                            "name": "front.create_calendar_event",
                            "description": "Create calendar event",
                            "parameters": {"type": "object"},
                        }
                    ],
                    "context": [],
                    "forwardedProps": {},
                },
            },
            run_service=RunService(),
            resume_service=_build_resume_service(),
        )
        pending_tool_call_id = str(run_result["pending_tool_call_id"])
        snapshot = run_result["state_snapshot"]
        assert isinstance(snapshot, dict)
        pending_tool_nonce = snapshot.get("pending_tool_nonce")
        assert isinstance(pending_tool_nonce, str)

        tool_message_content = json.dumps(
            {
                "toolName": "front.create_calendar_event",
                "toolArgs": {**event_args, "__nonce": pending_tool_nonce},
                "nonce": pending_tool_nonce,
                "result": {
                    "ok": True,
                    "type": "calendar_card.v1",
                    "version": "v1",
                    "data": {
                        "id": "evt-test",
                        "title": "测试日程",
                        # Large field meant to push the payload over the
                        # service's offload threshold.
                        "description": "x" * 9000,
                    },
                    "actions": [
                        {
                            "type": "link",
                            "label": "查看详情",
                            "target": "/calendar/events/evt-test",
                        }
                    ],
                },
            },
            ensure_ascii=True,
            separators=(",", ":"),
        )
        await run_agent_task(
            {
                "command": "resume",
                "run_input": {
                    "threadId": str(session_uuid),
                    "runId": "run-storage-2",
                    "state": {},
                    "messages": [
                        {
                            "id": "tool-1",
                            "role": "tool",
                            "toolCallId": pending_tool_call_id,
                            "content": tool_message_content,
                        }
                    ],
                    "tools": [],
                    "context": [],
                    "forwardedProps": {},
                },
            },
            run_service=RunService(),
            resume_service=_build_resume_service(),
        )

        await engine.dispose()
        async with AsyncSessionLocal() as verify_session:
            rows = await verify_session.execute(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.role == AgentChatMessageRole.TOOL)
                .order_by(AgentChatMessage.seq.desc())
            )
            tool_message = rows.scalars().first()
            assert tool_message is not None
            metadata = tool_message.metadata_json or {}
            storage_bucket = metadata.get("storage_bucket")
            storage_path = metadata.get("storage_path")
            assert storage_bucket == "private"
            assert isinstance(storage_path, str)
            assert storage_path.startswith("tool-results/")
            uploaded_path = storage_path

        downloaded = storage.from_("private").download(uploaded_path)
        if isinstance(downloaded, bytes):
            downloaded_payload = json.loads(downloaded.decode("utf-8"))
        else:
            downloaded_payload = json.loads(str(downloaded))
        assert downloaded_payload["toolName"] == "front.create_calendar_event"
        result_payload = downloaded_payload["result"]
        assert result_payload["type"] == "calendar_card.v1"
        assert result_payload["data"]["id"] == "evt-test"
    finally:
        try:
            if uploaded_path:
                try:
                    storage.from_("private").remove([uploaded_path])
                except Exception:
                    # Best-effort: a leftover object must not fail the test.
                    pass
            async with AsyncSessionLocal() as cleanup_session:
                await cleanup_session.execute(
                    delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
                )
                await cleanup_session.execute(
                    delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
                )
                await cleanup_session.execute(
                    delete(Profile).where(Profile.id == owner_id)
                )
                # Remove only the LLM/factory rows seeded above; a reused,
                # pre-existing factory is left untouched.
                if created_llm_id is not None:
                    await cleanup_session.execute(
                        delete(Llm).where(Llm.id == created_llm_id)
                    )
                if created_factory_id is not None:
                    await cleanup_session.execute(
                        delete(LlmFactory).where(LlmFactory.id == created_factory_id)
                    )
                await cleanup_session.commit()
        finally:
            # Runs even when the database cleanup itself fails, so the auth
            # user is not leaked and the Supabase client is always closed.
            if test_user_id is not None:
                try:
                    admin_client.auth.admin.delete_user(test_user_id)
                except Exception:
                    # Best-effort: user removal must not mask the test outcome.
                    pass
            await supabase_service.close()
@pytest.mark.asyncio
async def test_run_service_embeds_profile_settings_in_runtime_system_prompt(
monkeypatch: pytest.MonkeyPatch,
@@ -0,0 +1,65 @@
from __future__ import annotations
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import cast
from uuid import uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool import (
_execute_create_calendar_event,
)
@pytest.mark.asyncio
async def test_create_calendar_event_tool_returns_ui_schema_v1_top_level(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The calendar tool result exposes the ``calendar_card.v1`` envelope at top level.

    The schedule service and repository used by the tool module are replaced
    with stubs, so the tool only has to shape the stubbed schedule item into
    its result dictionary.
    """
    expected_event_id = uuid4()
    stub_schedule_item = SimpleNamespace(
        id=expected_event_id,
        title="晨会",
        description="同步计划",
        start_at=datetime(2026, 3, 8, 1, 0, tzinfo=timezone.utc),
        end_at=datetime(2026, 3, 8, 2, 0, tzinfo=timezone.utc),
        timezone="Asia/Shanghai",
    )

    class _StubScheduleService:
        """Accepts any constructor kwargs and always returns the stub item."""

        def __init__(self, **kwargs) -> None:
            del kwargs

        async def create_agent_generated(self, payload):
            del payload
            return stub_schedule_item

    class _StubScheduleRepository:
        """Placeholder repository; the DB session it receives is discarded."""

        def __init__(self, session) -> None:
            del session

    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool.ScheduleItemService",
        _StubScheduleService,
    )
    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.tools.backend.create_calendar_event_tool.SQLAlchemyScheduleItemRepository",
        _StubScheduleRepository,
    )

    # The session is never used by the stubs, so an empty namespace suffices.
    placeholder_session = cast(AsyncSession, SimpleNamespace())
    raw_result = await _execute_create_calendar_event(
        session=placeholder_session,
        owner_id=uuid4(),
        tool_args={"title": "晨会"},
    )
    card = cast(dict[str, object], raw_result)

    assert card["type"] == "calendar_card.v1"
    assert card["version"] == "v1"
    card_data = cast(dict[str, object], card["data"])
    card_actions = cast(list[dict[str, object]], card["actions"])
    assert card_data["id"] == str(expected_event_id)
    assert card_actions
@@ -116,7 +116,7 @@ def test_runtime_needs_execution_and_collects_front_tool_call() -> None:
tools = item["tools"]
assert isinstance(tools, list)
assert any(t.get("name") == "front.navigate_to_route" for t in tools)
execution_tools = calls[1]["tools"]
execution_tools = cast(list[dict[str, object]], calls[1]["tools"])
assert any(t.get("name") == "back.create_calendar_event" for t in execution_tools)
assert result["assistant_text"] == "do it"
assert result["pending_front_tool"] == {
@@ -131,3 +131,51 @@ def test_runtime_backend_registry_check() -> None:
runtime = _build_runtime()
assert runtime.is_registered_backend_tool("back.create_calendar_event") is True
assert runtime.is_registered_backend_tool("back.unknown") is False
def test_runtime_emits_step_started_finished_for_all_three_stages() -> None:
    """A full run emits STEP_STARTED/STEP_FINISHED pairs for intent, execution, organization."""
    runtime = _build_runtime()

    # Scripted reply per stage: (raw JSON text, UsageCost constructor args).
    scripted_replies = {
        "intent": (
            '{"route":"NEEDS_EXECUTION","intent_summary":"need tool","execution_brief":"do it","safety_flags":[]}',
            (1, 1, 2, 0.01),
        ),
        "execution": (
            '{"status":"SUCCESS","execution_summary":"done","execution_data":{},"report_brief":"ok"}',
            (2, 2, 4, 0.02),
        ),
        "organization": (
            '{"assistant_text":"final answer","response_metadata":{"source":"organization"}}',
            (3, 3, 6, 0.03),
        ),
    }

    def _scripted_stage(self, **kwargs):
        """Return the canned reply for the stage; unknown stages get the organization reply."""
        stage_name = kwargs["stage"]
        raw_text, usage_args = scripted_replies.get(
            stage_name, scripted_replies["organization"]
        )
        return (raw_text, UsageCost(*usage_args), [], None)

    runtime._run_stage_with_crewai = MethodType(_scripted_stage, runtime)  # type: ignore[method-assign]

    outcome = runtime.execute(user_input="go", tools=[])
    emitted = cast(list[dict[str, object]], outcome["agui_events"])

    step_types = frozenset({"STEP_STARTED", "STEP_FINISHED"})
    step_events = [ev for ev in emitted if ev.get("type") in step_types]
    assert len(step_events) == 6

    observed_stages = [
        cast(dict[str, object], ev["data"])["stage"] for ev in step_events
    ]
    expected_stages = [
        "intent",
        "intent",
        "execution",
        "execution",
        "organization",
        "organization",
    ]
    assert observed_stages == expected_stages
@@ -406,6 +406,114 @@ async def test_resume_service_rejects_tool_result_when_not_ok(
)
@pytest.mark.asyncio
async def test_resume_service_offloads_large_tool_result_payload_to_object_storage(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With a 1-byte offload threshold, the stored message metadata references object storage.

    Repositories, the DB session factory and the storage backend are all
    replaced by in-memory stubs; the test inspects the keyword arguments of
    the first ``append_message`` call.
    """
    session_id = uuid4()
    user_id = uuid4()
    appended_messages: list[dict[str, object]] = []

    class _StubDbSession:
        """DB session whose commit does nothing."""

        async def commit(self) -> None:
            return None

    class _StubSessionFactory:
        """Callable async context manager that yields a stub DB session."""

        def __call__(self) -> "_StubSessionFactory":
            return self

        async def __aenter__(self) -> _StubDbSession:
            return _StubDbSession()

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            del exc_type, exc, tb
            return False

    class _StubSessionRepository:
        """Serves a RUNNING session that is waiting on tool call ``call-1``."""

        def __init__(self, session: object) -> None:
            del session

        async def lock_session_for_update(self, *, session_id: object):
            pending_state = {
                "pending_tool_call_id": "call-1",
                "pending_tool_name": "front.navigate_to_route",
                "pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
                "pending_tool_nonce": "nonce-1",
            }
            return SimpleNamespace(
                id=session_id,
                user_id=user_id,
                status=AgentChatSessionStatus.RUNNING,
                message_count=0,
                total_tokens=0,
                total_cost=0,
                state_snapshot=pending_state,
            )

        async def next_message_seq(self, *, session_id: object) -> int:
            del session_id
            return 1

        async def update_runtime_state(self, **kwargs) -> None:
            del kwargs

    class _StubMessageRepository:
        """Records every ``append_message`` call for later inspection."""

        def __init__(self, session: object) -> None:
            del session

        async def append_message(self, **kwargs) -> None:
            appended_messages.append(kwargs)

    class _StubStorage:
        """Storage backend that accepts any upload and returns a fixed etag."""

        async def upload_json(
            self, *, bucket: str, path: str, payload: dict[str, object]
        ) -> str:
            del bucket, path, payload
            return "etag-1"

    monkeypatch.setattr(
        "core.agent.application.resume_service.SessionRepository",
        _StubSessionRepository,
    )
    monkeypatch.setattr(
        "core.agent.application.resume_service.MessageRepository",
        _StubMessageRepository,
    )

    service = ResumeService(  # type: ignore[call-arg]
        session_factory=_StubSessionFactory(),  # type: ignore[arg-type]
        tool_result_storage=_StubStorage(),
        tool_result_offload_threshold_bytes=1,
        tool_result_bucket="private",
        tool_result_prefix="tool-results",
    )

    tool_payload = {
        "toolName": "front.navigate_to_route",
        "toolArgs": {
            "target": "/calendar/dayweek",
            "replace": False,
            "__nonce": "nonce-1",
        },
        "nonce": "nonce-1",
        "result": {"ok": True, "payload": "x" * 4096},
    }
    resume_input = _build_resume_input(
        thread_id=str(session_id),
        tool_call_id="call-1",
        content=json.dumps(
            tool_payload,
            ensure_ascii=True,
            separators=(",", ":"),
        ),
    )
    await service.resume(run_input=resume_input)

    metadata = appended_messages[0]["metadata"]
    assert isinstance(metadata, dict)
    assert metadata["storage_bucket"] == "private"
    assert metadata["storage_path"].startswith("tool-results/")
    assert isinstance(metadata["payload_sha256"], str)
@pytest.mark.asyncio
async def test_load_agent_model_selection_returns_validated_llm_config() -> None:
run_service = RunService()
+58 -29
View File
@@ -1,43 +1,72 @@
from __future__ import annotations
from fastapi import HTTPException
from datetime import datetime, timezone
from types import SimpleNamespace
from uuid import uuid4
import pytest
from models.agent_chat_message import AgentChatMessageRole
from v1.agent.repository import AgentRepository
class _FakeSession:
def __init__(self) -> None:
self.added: list[object] = []
class _FakeToolResultStorage:
def __init__(self, payload: dict[str, object] | None) -> None:
self._payload = payload
def add(self, obj: object) -> None:
self.added.append(obj)
async def flush(self) -> None:
return None
async def refresh(self, _obj: object) -> None:
return None
async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
del bucket, path
return self._payload
async def test_create_session_for_user_creates_session_row() -> None:
session = _FakeSession()
repository = AgentRepository(session=session) # type: ignore[arg-type]
await repository.create_session_for_user(
user_id="00000000-0000-0000-0000-000000000001"
@pytest.mark.asyncio
async def test_tool_message_hydrates_content_from_object_storage() -> None:
repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
tool_result_storage=_FakeToolResultStorage(
{
"toolName": "front.navigate_to_route",
"result": {"ok": True, "applied": True, "content": "已跳转"},
}
),
)
message = SimpleNamespace(
id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content='{"offloaded":true}',
metadata_json={
"tool_call_id": "call-1",
"storage_bucket": "private",
"storage_path": "tool-results/run-1/call-1.json",
},
)
session_row = session.added[0]
assert str(getattr(session_row, "user_id")) == "00000000-0000-0000-0000-000000000001"
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
assert payload["toolCallId"] == "call-1"
assert payload["content"] == "已跳转"
async def test_create_session_for_user_rejects_invalid_uuid() -> None:
session = _FakeSession()
repository = AgentRepository(session=session) # type: ignore[arg-type]
@pytest.mark.asyncio
async def test_tool_message_keeps_inline_content_when_storage_payload_missing() -> None:
repository = AgentRepository(
session=SimpleNamespace(), # type: ignore[arg-type]
tool_result_storage=_FakeToolResultStorage(None),
)
message = SimpleNamespace(
id=uuid4(),
role=AgentChatMessageRole.TOOL,
created_at=datetime.now(timezone.utc),
content="inline-tool-content",
metadata_json={
"tool_call_id": "call-2",
"storage_bucket": "private",
"storage_path": "tool-results/run-1/call-2.json",
},
)
try:
await repository.create_session_for_user(user_id="invalid-uuid")
raise AssertionError("expected invalid user_id")
except HTTPException as exc:
assert exc.status_code == 422
assert exc.detail == "Invalid user_id"
payload = await repository._to_snapshot_message(message) # type: ignore[arg-type]
assert payload["toolCallId"] == "call-2"
assert payload["content"] == "inline-tool-content"