from __future__ import annotations

from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4

import pytest

from models.agent_chat_session import (
    AgentChatSession,
    AgentChatSessionStatus,
    SessionType,
)
from v1.agent.service import AgentChatService


class FakeAsyncSession:
    """In-memory stand-in for an async database session.

    Records every object passed to ``add`` and additionally indexes
    ``AgentChatSession`` instances by their ``id``. All other session
    operations are no-ops.
    """

    def __init__(self) -> None:
        self.added: list[object] = []
        self._sessions: dict[UUID, AgentChatSession] = {}

    def add(self, obj: object) -> None:
        """Record *obj*; also index it by id when it is a chat session."""
        self.added.append(obj)
        if not isinstance(obj, AgentChatSession):
            return
        self._sessions[obj.id] = obj

    async def flush(self) -> None:
        """No-op."""
        return None

    async def commit(self) -> None:
        """No-op."""
        return None

    async def rollback(self) -> None:
        """No-op."""
        return None

    async def refresh(self, obj: object) -> None:
        """No-op; *obj* is left untouched."""
        return None

    async def execute(self, stmt: object) -> None:
        """No-op; *stmt* is ignored."""
        return None

    async def scalar(self, stmt: object) -> AgentChatSession | None:
        """Return the first stored chat session, or ``None`` if there is none.

        *stmt* is ignored: the fake does not interpret queries.
        """
        return next(iter(self._sessions.values()), None)


def _one_hour_from_now() -> datetime:
    """Return a timezone-aware UTC timestamp one hour in the future."""
    return datetime.now(timezone.utc) + timedelta(hours=1)


@pytest.fixture
def fake_db() -> FakeAsyncSession:
    """Provide a fresh, empty fake database session."""
    return FakeAsyncSession()


@pytest.fixture
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
    """Provide a RUNNING chat session that is already added to ``fake_db``."""
    chat_session = AgentChatSession(
        id=uuid4(),
        user_id=uuid4(),
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
    )
    fake_db.add(chat_session)
    return chat_session


@pytest.fixture
def service(fake_db: FakeAsyncSession) -> AgentChatService:
    """Provide an ``AgentChatService`` backed by the fake session, with no user."""
    return AgentChatService(fake_db, current_user=None)  # type: ignore[arg-type]


class TestResumeIdempotency:
    """Tests for ``apply_resume_decision`` on a pending tool call."""

    @pytest.mark.asyncio
    async def test_resume_is_idempotent(
        self, service: AgentChatService, session: AgentChatSession
    ) -> None:
        """The same decision applied twice is only applied the first time."""
        deadline = _one_hour_from_now()
        await service.set_pending_tool_call(
            session_id=session.id,
            interrupt_id="int-1",
            tool_name="srv.transfer_funds",
            tool_args={"to": "u2", "amount": 100},
            expires_at=deadline,
        )

        initial_outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-1",
            decision={"decision": "approved"},
        )
        repeat_outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-1",
            decision={"decision": "approved"},
        )

        assert initial_outcome.applied is True
        assert repeat_outcome.applied is False

    @pytest.mark.asyncio
    async def test_resume_updates_status_to_approved(
        self, service: AgentChatService, session: AgentChatSession
    ) -> None:
        """An approval is applied and shows up in the state snapshot."""
        deadline = _one_hour_from_now()
        await service.set_pending_tool_call(
            session_id=session.id,
            interrupt_id="int-2",
            tool_name="srv.delete_file",
            tool_args={"file_id": "f1"},
            expires_at=deadline,
        )

        outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-2",
            decision={"decision": "approved"},
        )
        assert outcome.applied is True

        state = await service.get_state_snapshot(session.id)
        assert state["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
        assert state["pending_tool_call"]["decision"] == {"decision": "approved"}

    @pytest.mark.asyncio
    async def test_resume_updates_status_to_rejected(
        self, service: AgentChatService, session: AgentChatSession
    ) -> None:
        """A rejection is applied and the snapshot status becomes REJECTED."""
        deadline = _one_hour_from_now()
        await service.set_pending_tool_call(
            session_id=session.id,
            interrupt_id="int-3",
            tool_name="srv.transfer_funds",
            tool_args={"to": "u2", "amount": 100},
            expires_at=deadline,
        )

        outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-3",
            decision={"decision": "rejected"},
        )
        assert outcome.applied is True

        state = await service.get_state_snapshot(session.id)
        assert state["pending_tool_call"]["status"] == "REJECTED"