Files
social-app/backend/tests/unit/v1/agent/test_resume_idempotency.py
T

143 lines
4.2 KiB
Python

from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
class FakeAsyncSession:
    """Minimal in-memory stand-in for an async database session.

    Every object handed to ``add`` is recorded in ``added``. Instances of
    ``AgentChatSession`` are additionally indexed by their ``id`` so that
    ``scalar`` can hand one back. The remaining coroutine methods are no-ops
    that return ``None``.
    """

    def __init__(self) -> None:
        self.added: list[object] = []
        self._sessions: dict[UUID, AgentChatSession] = {}

    def add(self, obj: object) -> None:
        """Record *obj*; also index it by ``id`` when it is a chat session."""
        self.added.append(obj)
        if not isinstance(obj, AgentChatSession):
            return
        self._sessions[obj.id] = obj

    async def flush(self) -> None:
        """No-op: nothing is buffered."""

    async def commit(self) -> None:
        """No-op: there is no transaction to commit."""

    async def rollback(self) -> None:
        """No-op: there is no transaction to roll back."""

    async def refresh(self, obj: object) -> None:
        """No-op: *obj* is left untouched."""

    async def execute(self, stmt: object) -> None:
        """No-op: *stmt* is ignored and nothing is returned."""

    async def scalar(self, stmt: object) -> AgentChatSession | None:
        """Return the first stored session (ignoring *stmt*), or ``None`` if empty."""
        # The statement is never inspected; the earliest-added session wins.
        return next(iter(self._sessions.values()), None)
@pytest.fixture
def fake_db() -> FakeAsyncSession:
    """Provide a fresh ``FakeAsyncSession`` for each test that requests it."""
    db = FakeAsyncSession()
    return db
@pytest.fixture
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
    """Create a RUNNING chat-type session with random ids and add it to ``fake_db``."""
    fields = {
        "id": uuid4(),
        "user_id": uuid4(),
        "session_type": SessionType.CHAT,
        "status": AgentChatSessionStatus.RUNNING,
    }
    chat_session = AgentChatSession(**fields)
    # Registering the session with the fake db is what lets its ``scalar`` return it.
    fake_db.add(chat_session)
    return chat_session
@pytest.fixture
def service(fake_db: FakeAsyncSession) -> AgentChatService:
    """Build the ``AgentChatService`` under test on the fake session, with no current user."""
    # NOTE(review): the ignore presumably covers the fake session and/or the
    # ``None`` user not matching the declared parameter types — confirm.
    svc = AgentChatService(fake_db, current_user=None)  # type: ignore[arg-type]
    return svc
class TestResumeIdempotency:
    """Tests for ``AgentChatService.apply_resume_decision``.

    Each test registers a pending tool call with ``set_pending_tool_call``
    and then resumes it, checking the returned ``applied`` flag and, where
    relevant, the ``pending_tool_call`` entry of the state snapshot.
    """

    @pytest.mark.asyncio
    async def test_resume_is_idempotent(
        self, service: AgentChatService, session: AgentChatSession
    ):
        """The same decision sent twice is applied only the first time."""
        deadline = datetime.now(timezone.utc) + timedelta(hours=1)
        await service.set_pending_tool_call(
            session_id=session.id,
            interrupt_id="int-1",
            tool_name="srv.transfer_funds",
            tool_args={"to": "u2", "amount": 100},
            expires_at=deadline,
        )

        outcomes = []
        for _ in range(2):
            # A fresh decision dict is built on every call, as two separate
            # requests would do.
            outcome = await service.apply_resume_decision(
                session_id=session.id,
                interrupt_id="int-1",
                decision={"decision": "approved"},
            )
            outcomes.append(outcome)

        initial, repeated = outcomes
        assert initial.applied is True
        assert repeated.applied is False

    @pytest.mark.asyncio
    async def test_resume_updates_status_to_approved(
        self, service: AgentChatService, session: AgentChatSession
    ):
        """An approved decision is applied and shows up in the state snapshot."""
        deadline = datetime.now(timezone.utc) + timedelta(hours=1)
        pending_call = {
            "session_id": session.id,
            "interrupt_id": "int-2",
            "tool_name": "srv.delete_file",
            "tool_args": {"file_id": "f1"},
            "expires_at": deadline,
        }
        await service.set_pending_tool_call(**pending_call)

        outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-2",
            decision={"decision": "approved"},
        )
        assert outcome.applied is True

        state = await service.get_state_snapshot(session.id)
        recorded = state["pending_tool_call"]
        assert recorded["status"] == "APPROVED_EXECUTING"
        assert recorded["decision"] == {"decision": "approved"}

    @pytest.mark.asyncio
    async def test_resume_updates_status_to_rejected(
        self, service: AgentChatService, session: AgentChatSession
    ):
        """A rejected decision is applied and the snapshot status becomes REJECTED."""
        deadline = datetime.now(timezone.utc) + timedelta(hours=1)
        pending_call = {
            "session_id": session.id,
            "interrupt_id": "int-3",
            "tool_name": "srv.transfer_funds",
            "tool_args": {"to": "u2", "amount": 100},
            "expires_at": deadline,
        }
        await service.set_pending_tool_call(**pending_call)

        outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-3",
            decision={"decision": "rejected"},
        )
        assert outcome.applied is True

        state = await service.get_state_snapshot(session.id)
        recorded = state["pending_tool_call"]
        assert recorded["status"] == "REJECTED"