diff --git a/backend/src/models/agent_chat_session.py b/backend/src/models/agent_chat_session.py index 55ac364..a888ff1 100644 --- a/backend/src/models/agent_chat_session.py +++ b/backend/src/models/agent_chat_session.py @@ -14,7 +14,7 @@ from sqlalchemy import ( func, text, ) -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column from core.db.base import Base, SoftDeleteMixin, TimestampMixin @@ -75,3 +75,7 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base): total_cost: Mapped[Decimal] = mapped_column( Numeric(12, 6), nullable=False, server_default=text("0") ) + state_snapshot: Mapped[dict | None] = mapped_column( + JSONB, + nullable=True, + ) diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index 69ed9e9..a506eb6 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -284,3 +284,54 @@ class AgentChatService(BaseService): "content": message, "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"}, } + + async def get_state_snapshot(self, session_id: UUID) -> dict | None: + stmt = select(AgentChatSession).where(AgentChatSession.id == session_id) + session = await self._session.scalar(stmt) + if session is None: + return None + return session.state_snapshot + + async def set_pending_tool_call( + self, + *, + session_id: UUID, + interrupt_id: str, + tool_name: str, + tool_args: dict, + expires_at: datetime, + ) -> None: + stmt = select(AgentChatSession).where(AgentChatSession.id == session_id) + session = await self._session.scalar(stmt) + if session is None: + raise ValueError(f"Session {session_id} not found") + snapshot = session.state_snapshot or {} + snapshot["pending_tool_call"] = { + "interrupt_id": interrupt_id, + "tool_name": tool_name, + "tool_args": tool_args, + "status": "PENDING_APPROVAL", + "expires_at": expires_at.isoformat(), + "decision": None, + "result": None, + } + session.state_snapshot = snapshot + + async def update_pending_tool_call_status( + self, + *, + session_id: UUID, + interrupt_id: str, + status: str, + ) -> None: + stmt = select(AgentChatSession).where(AgentChatSession.id == session_id) + session = await self._session.scalar(stmt) + if session is None: + raise ValueError(f"Session {session_id} not found") + snapshot = session.state_snapshot + if snapshot is None or "pending_tool_call" not in snapshot: + raise ValueError("No pending tool call found") + if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id: + raise ValueError("Interrupt ID mismatch") + snapshot["pending_tool_call"]["status"] = status + session.state_snapshot = snapshot diff --git a/backend/tests/unit/v1/agent/test_service_pending_tool_call.py b/backend/tests/unit/v1/agent/test_service_pending_tool_call.py new file mode 100644 index 0000000..3a801a0 --- /dev/null +++ b/backend/tests/unit/v1/agent/test_service_pending_tool_call.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from uuid import UUID, uuid4 + +import pytest + +from models.agent_chat_session import ( + AgentChatSession, + AgentChatSessionStatus, + SessionType, +) +from v1.agent.service import AgentChatService + + +class FakeAsyncSession: + def __init__(self) -> None: + self.added: list[object] = [] + self._sessions: dict[UUID, AgentChatSession] = {} + + def add(self, obj: object) -> None: + self.added.append(obj) + if isinstance(obj, AgentChatSession): + self._sessions[obj.id] = obj + + async def flush(self) -> None: + return None + + async def commit(self) -> None: + pass + + async def rollback(self) -> None: + pass + + async def refresh(self, obj: object) -> None: + pass + + async def execute(self, stmt: object) -> None: + pass + + async def scalar(self, stmt: object) -> AgentChatSession | None: + for session in self._sessions.values(): + return session + return None + + +@pytest.fixture +def fake_db() -> FakeAsyncSession: + return FakeAsyncSession() + + +@pytest.fixture +def session(fake_db: FakeAsyncSession) -> AgentChatSession: + sess = AgentChatSession( + id=uuid4(), + user_id=uuid4(), + session_type=SessionType.CHAT, + status=AgentChatSessionStatus.RUNNING, + ) + fake_db.add(sess) + return sess + + +@pytest.fixture +def service(fake_db: FakeAsyncSession) -> AgentChatService: + return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type] + + +class TestPendingToolCall: + @pytest.mark.asyncio + async def test_save_pending_tool_call_to_state_snapshot( + self, service: AgentChatService, session: AgentChatSession + ): + expires_at = datetime.now(timezone.utc) + timedelta(hours=1) + await service.set_pending_tool_call( + session_id=session.id, + interrupt_id="int-1", + tool_name="srv.transfer_funds", + tool_args={"to": "u2", "amount": 100}, + expires_at=expires_at, + ) + snapshot = await service.get_state_snapshot(session.id) + assert snapshot is not None + assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL" + assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1" + assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds" + + @pytest.mark.asyncio + async def test_get_state_snapshot_returns_none_when_empty( + self, service: AgentChatService, session: AgentChatSession + ): + snapshot = await service.get_state_snapshot(session.id) + assert snapshot is None + + @pytest.mark.asyncio + async def test_update_pending_tool_call_status( + self, service: AgentChatService, session: AgentChatSession + ): + expires_at = datetime.now(timezone.utc) + timedelta(hours=1) + await service.set_pending_tool_call( + session_id=session.id, + interrupt_id="int-2", + tool_name="srv.delete_file", + tool_args={"file_id": "f1"}, + expires_at=expires_at, + ) + + await service.update_pending_tool_call_status( + session_id=session.id, + interrupt_id="int-2", + status="APPROVED_EXECUTING", + ) + + snapshot = await service.get_state_snapshot(session.id) + assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"