feat(agent): persist pending tool call in session snapshot

This commit is contained in:
qzl
2026-03-03 15:39:56 +08:00
parent e03923e593
commit cff1436bc6
3 changed files with 171 additions and 1 deletions
+5 -1
View File
@@ -14,7 +14,7 @@ from sqlalchemy import (
func, func,
text, text,
) )
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin from core.db.base import Base, SoftDeleteMixin, TimestampMixin
@@ -75,3 +75,7 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
total_cost: Mapped[Decimal] = mapped_column( total_cost: Mapped[Decimal] = mapped_column(
Numeric(12, 6), nullable=False, server_default=text("0") Numeric(12, 6), nullable=False, server_default=text("0")
) )
state_snapshot: Mapped[dict | None] = mapped_column(
JSONB,
nullable=True,
)
+51
View File
@@ -284,3 +284,54 @@ class AgentChatService(BaseService):
"content": message, "content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"}, "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
} }
async def get_state_snapshot(self, session_id: UUID) -> dict | None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
return None
return session.state_snapshot
async def set_pending_tool_call(
self,
*,
session_id: UUID,
interrupt_id: str,
tool_name: str,
tool_args: dict,
expires_at: datetime,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot or {}
snapshot["pending_tool_call"] = {
"interrupt_id": interrupt_id,
"tool_name": tool_name,
"tool_args": tool_args,
"status": "PENDING_APPROVAL",
"expires_at": expires_at.isoformat(),
"decision": None,
"result": None,
}
session.state_snapshot = snapshot
async def update_pending_tool_call_status(
self,
*,
session_id: UUID,
interrupt_id: str,
status: str,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
raise ValueError("No pending tool call found")
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
raise ValueError("Interrupt ID mismatch")
snapshot["pending_tool_call"]["status"] = status
session.state_snapshot = snapshot
@@ -0,0 +1,115 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
class FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self._sessions: dict[UUID, AgentChatSession] = {}
def add(self, obj: object) -> None:
self.added.append(obj)
if isinstance(obj, AgentChatSession):
self._sessions[obj.id] = obj
async def flush(self) -> None:
return None
async def commit(self) -> None:
pass
async def rollback(self) -> None:
pass
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
return session
return None
@pytest.fixture
def fake_db() -> FakeAsyncSession:
return FakeAsyncSession()
@pytest.fixture
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
sess = AgentChatSession(
id=uuid4(),
user_id=uuid4(),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
)
fake_db.add(sess)
return sess
@pytest.fixture
def service(fake_db: FakeAsyncSession) -> AgentChatService:
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
class TestPendingToolCall:
@pytest.mark.asyncio
async def test_save_pending_tool_call_to_state_snapshot(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is not None
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
@pytest.mark.asyncio
async def test_get_state_snapshot_returns_none_when_empty(
self, service: AgentChatService, session: AgentChatSession
):
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is None
@pytest.mark.asyncio
async def test_update_pending_tool_call_status(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-2",
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
)
await service.update_pending_tool_call_status(
session_id=session.id,
interrupt_id="int-2",
status="APPROVED_EXECUTING",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"