from __future__ import annotations

from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4

import pytest
from fastapi import HTTPException

from core.auth.models import CurrentUser
from models.agent_chat_session import (
    AgentChatSession,
    AgentChatSessionStatus,
    SessionType,
)
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService


class FakeAsyncSession:
    """Minimal async DB-session stand-in passed to ``AgentChatService``.

    The statements given to ``execute``/``scalar`` are ignored: every query
    resolves to the first stored session (or ``None`` when none are stored).
    Tests using this fake therefore do not exercise any filtering the real
    query performs (e.g. by owner) — only checks the service does in Python.
    """

    def __init__(self, sessions: list[AgentChatSession]) -> None:
        self._sessions = {session.id: session for session in sessions}
        # Set by ``commit`` so tests can assert that state was persisted.
        self.commit_called = False

    def _first(self) -> AgentChatSession | None:
        """Return the first stored session, or ``None`` if there are none."""
        return next(iter(self._sessions.values()), None)

    async def execute(self, stmt: object):
        """Return a result object whose ``scalar_one_or_none`` yields the first session.

        ``stmt`` is accepted for signature compatibility and ignored.
        """

        class _Result:
            def __init__(self, session_obj: AgentChatSession | None) -> None:
                self._session_obj = session_obj

            def scalar_one_or_none(self) -> AgentChatSession | None:
                return self._session_obj

        return _Result(self._first())

    async def scalar(self, stmt: object) -> AgentChatSession | None:
        """Return the first stored session; ``stmt`` is ignored."""
        return self._first()

    async def commit(self) -> None:
        """Record that a commit was requested."""
        self.commit_called = True


def _build_input(run_id: str) -> RunAgentInput:
    """Build a resume request approving interrupt ``int-1`` on thread ``t1``."""
    return RunAgentInput.model_validate(
        {
            "threadId": "t1",
            "runId": run_id,
            "state": {},
            "messages": [],
            "tools": [],
            "context": [],
            "forwardedProps": {},
            "resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
        }
    )


def _pending_snapshot(expires_at: str) -> dict[str, object]:
    """Build a version-2 state snapshot with one tool call awaiting approval.

    The pending call's ``interrupt_id`` is ``int-1``, the same id that
    ``_build_input`` resumes, so the request targets this call.

    Args:
        expires_at: ISO-8601 timestamp stored as the approval deadline.
    """
    return {
        "version": 2,
        "pending_tool_call": {
            "interrupt_id": "int-1",
            "tool_name": "srv.transfer_funds",
            "tool_args": {"to": "u2", "amount": 100},
            "status": "PENDING_APPROVAL",
            "expires_at": expires_at,
            "decision": None,
            "result": None,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        },
        "run_context": {"thread_id": "t1", "run_id": str(uuid4())},
    }


@pytest.mark.asyncio
async def test_stream_resume_rejects_non_owner_session() -> None:
    """Resuming a session owned by another user is rejected with 404."""
    # Expiry is comfortably in the future so the pending call is still valid;
    # an expiry of "now" would already be past when the service checks it,
    # letting the expiry path (rather than ownership) decide the outcome.
    session = AgentChatSession(
        id=uuid4(),
        user_id=uuid4(),  # random owner, never equal to the caller below
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot=_pending_snapshot(
            (datetime.now(timezone.utc) + timedelta(minutes=5)).isoformat()
        ),
    )
    service = AgentChatService(
        session=FakeAsyncSession([session]),  # type: ignore[arg-type]
        current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(str(session.id), _build_input(str(session.id)))

    assert exc_info.value.status_code == 404


@pytest.mark.asyncio
async def test_prepare_resume_commits_expired_state_before_410() -> None:
    """An expired pending call yields 410, and the service commits beforehand."""
    owner_id = UUID("00000000-0000-0000-0000-000000000001")
    session = AgentChatSession(
        id=uuid4(),
        user_id=owner_id,
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        # Deadline far in the past so the pending call is unambiguously expired.
        state_snapshot=_pending_snapshot("2000-01-01T00:00:00+00:00"),
    )
    fake_db = FakeAsyncSession([session])
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=owner_id),
    )

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(str(session.id), _build_input(str(session.id)))

    assert exc_info.value.status_code == 410
    assert fake_db.commit_called is True