fix(agent): polish interrupt-resume flow for merge readiness
This commit is contained in:
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, sessions: list[AgentChatSession]) -> None:
|
||||
self._sessions = {session.id: session for session in sessions}
|
||||
self.commit_called = False
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
for session in self._sessions.values():
|
||||
return _Result(session)
|
||||
return _Result(None)
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_called = True
|
||||
|
||||
|
||||
def _build_input(run_id: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "t1",
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_resume_rejects_non_owner_session() -> None:
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": datetime.now(timezone.utc).isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
service = AgentChatService(
|
||||
session=FakeAsyncSession([session]), # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_resume_commits_expired_state_before_410() -> None:
|
||||
owner_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2000-01-01T00:00:00+00:00",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
fake_db = FakeAsyncSession([session])
|
||||
service = AgentChatService(
|
||||
session=fake_db, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 410
|
||||
assert fake_db.commit_called is True
|
||||
Reference in New Issue
Block a user