feat(agent): persist pending tool call in session snapshot

This commit is contained in:
qzl
2026-03-03 15:39:56 +08:00
parent e03923e593
commit cff1436bc6
3 changed files with 171 additions and 1 deletions
@@ -0,0 +1,115 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
class FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self._sessions: dict[UUID, AgentChatSession] = {}
def add(self, obj: object) -> None:
self.added.append(obj)
if isinstance(obj, AgentChatSession):
self._sessions[obj.id] = obj
async def flush(self) -> None:
return None
async def commit(self) -> None:
pass
async def rollback(self) -> None:
pass
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
return session
return None
@pytest.fixture
def fake_db() -> FakeAsyncSession:
return FakeAsyncSession()
@pytest.fixture
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
sess = AgentChatSession(
id=uuid4(),
user_id=uuid4(),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
)
fake_db.add(sess)
return sess
@pytest.fixture
def service(fake_db: FakeAsyncSession) -> AgentChatService:
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
class TestPendingToolCall:
@pytest.mark.asyncio
async def test_save_pending_tool_call_to_state_snapshot(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is not None
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
@pytest.mark.asyncio
async def test_get_state_snapshot_returns_none_when_empty(
self, service: AgentChatService, session: AgentChatSession
):
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is None
@pytest.mark.asyncio
async def test_update_pending_tool_call_status(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-2",
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
)
await service.update_pending_tool_call_status(
session_id=session.id,
interrupt_id="int-2",
status="APPROVED_EXECUTING",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"