feat(agent): persist pending tool call in session snapshot
This commit is contained in:
@@ -14,7 +14,7 @@ from sqlalchemy import (
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
@@ -75,3 +75,7 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
|
||||
total_cost: Mapped[Decimal] = mapped_column(
|
||||
Numeric(12, 6), nullable=False, server_default=text("0")
|
||||
)
|
||||
state_snapshot: Mapped[dict | None] = mapped_column(
|
||||
JSONB,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
@@ -284,3 +284,54 @@ class AgentChatService(BaseService):
|
||||
"content": message,
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
|
||||
}
|
||||
|
||||
async def get_state_snapshot(self, session_id: UUID) -> dict | None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
return None
|
||||
return session.state_snapshot
|
||||
|
||||
async def set_pending_tool_call(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
tool_name: str,
|
||||
tool_args: dict,
|
||||
expires_at: datetime,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot or {}
|
||||
snapshot["pending_tool_call"] = {
|
||||
"interrupt_id": interrupt_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
}
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
async def update_pending_tool_call_status(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
status: str,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot
|
||||
if snapshot is None or "pending_tool_call" not in snapshot:
|
||||
raise ValueError("No pending tool call found")
|
||||
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
|
||||
raise ValueError("Interrupt ID mismatch")
|
||||
snapshot["pending_tool_call"]["status"] = status
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self._sessions: dict[UUID, AgentChatSession] = {}
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
if isinstance(obj, AgentChatSession):
|
||||
self._sessions[obj.id] = obj
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
pass
|
||||
|
||||
async def execute(self, stmt: object) -> None:
|
||||
pass
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db() -> FakeAsyncSession:
|
||||
return FakeAsyncSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
|
||||
sess = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
fake_db.add(sess)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(fake_db: FakeAsyncSession) -> AgentChatService:
|
||||
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestPendingToolCall:
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_pending_tool_call_to_state_snapshot(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
)
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot is not None
|
||||
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
|
||||
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
|
||||
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_state_snapshot_returns_none_when_empty(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pending_tool_call_status(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
tool_name="srv.delete_file",
|
||||
tool_args={"file_id": "f1"},
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
await service.update_pending_tool_call_status(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
status="APPROVED_EXECUTING",
|
||||
)
|
||||
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
|
||||
Reference in New Issue
Block a user