fix(agent): polish interrupt-resume flow for merge readiness

This commit is contained in:
qzl
2026-03-03 17:26:04 +08:00
parent 7be8669144
commit 30a4a1af5d
16 changed files with 1179 additions and 85 deletions
@@ -1,5 +1,17 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
from v1.agent.tool_registry import validate_tool_spec
@@ -18,5 +30,51 @@ class TestAgentSecurityRules:
else:
raise AssertionError("Should have raised ValueError for unknown namespace")
def test_frontend_result_fails_when_interrupt_mismatch(self):
pass
@pytest.mark.asyncio
async def test_frontend_result_fails_when_interrupt_mismatch(self):
session = AgentChatSession(
id=uuid4(),
user_id=UUID("00000000-0000-0000-0000-000000000001"),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
)
class FakeAsyncSession:
def __init__(self, session_obj: AgentChatSession) -> None:
self._session_obj = session_obj
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(self._session_obj)
async def scalar(self, stmt: object) -> AgentChatSession | None:
return self._session_obj
service = AgentChatService(
session=FakeAsyncSession(session), # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-other",
decision={"decision": "approved"},
)
assert result.applied is False
@@ -17,6 +17,7 @@ class FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self._sessions: dict[UUID, AgentChatSession] = {}
self.last_fetch_with_lock = False
def add(self, obj: object) -> None:
self.added.append(obj)
@@ -35,8 +36,17 @@ class FakeAsyncSession:
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def execute(self, stmt: object):
self.last_fetch_with_lock = "FOR UPDATE" in str(stmt)
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
@@ -69,7 +79,10 @@ def service(fake_db: FakeAsyncSession) -> AgentChatService:
class TestResumeIdempotency:
@pytest.mark.asyncio
async def test_resume_is_idempotent(
self, service: AgentChatService, session: AgentChatSession
self,
service: AgentChatService,
session: AgentChatSession,
fake_db: FakeAsyncSession,
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
@@ -78,6 +91,8 @@ class TestResumeIdempotency:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
first = await service.apply_resume_decision(
@@ -93,6 +108,7 @@ class TestResumeIdempotency:
assert first.applied is True
assert second.applied is False
assert fake_db.last_fetch_with_lock is True
@pytest.mark.asyncio
async def test_resume_updates_status_to_approved(
@@ -105,6 +121,8 @@ class TestResumeIdempotency:
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
@@ -129,6 +147,8 @@ class TestResumeIdempotency:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
@@ -140,3 +160,28 @@ class TestResumeIdempotency:
assert result.applied is True
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
@pytest.mark.asyncio
async def test_resume_expired_pending_marks_expired_and_not_applied(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) - timedelta(seconds=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-expired",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-expired",
decision={"decision": "approved"},
)
assert result.applied is False
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
+71 -1
View File
@@ -1,4 +1,8 @@
from v1.agent.schemas import RunAgentInput
from datetime import datetime, timezone
import pytest
from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput
class TestRunAgentInput:
@@ -55,3 +59,69 @@ class TestRunAgentInput:
assert model.state == {"key": "value"}
assert len(model.messages) == 1
assert model.messages[0]["role"] == "user"
class TestAgentSessionSnapshot:
def test_state_snapshot_v2_model_accepts_valid_payload(self):
payload = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00Z",
},
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
model = AgentSessionSnapshot.model_validate(payload)
assert model.version == 2
assert model.pending_tool_call is not None
assert model.pending_tool_call.interrupt_id == "int-1"
assert model.pending_tool_call.updated_at == datetime(
2026, 3, 3, 11, 59, tzinfo=timezone.utc
)
def test_state_snapshot_v2_rejects_wrong_version(self):
payload = {
"version": 1,
"pending_tool_call": None,
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
def test_state_snapshot_v2_requires_pending_tool_call_key(self):
payload = {
"version": 2,
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
def test_state_snapshot_v2_rejects_extra_fields(self):
payload = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00Z",
"unexpected": True,
},
"run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
@@ -35,8 +35,15 @@ class FakeAsyncSession:
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
@@ -78,9 +85,14 @@ class TestPendingToolCall:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is not None
assert snapshot["version"] == 2
assert snapshot["run_context"]["thread_id"] == "t1"
assert snapshot["run_context"]["run_id"] == "r1"
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
@@ -103,6 +115,8 @@ class TestPendingToolCall:
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
await service.update_pending_tool_call_status(
@@ -113,3 +127,42 @@ class TestPendingToolCall:
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
@pytest.mark.asyncio
async def test_invalid_legacy_snapshot_is_rejected(
self, service: AgentChatService, session: AgentChatSession
):
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
with pytest.raises(ValueError):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-legacy",
decision={"decision": "approved"},
)
@pytest.mark.asyncio
async def test_snapshot_rejects_naive_datetime(
self, service: AgentChatService, session: AgentChatSession
):
session.state_snapshot = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-naive",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00",
},
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-naive",
decision={"decision": "approved"},
)
@@ -0,0 +1,126 @@
from __future__ import annotations
from datetime import datetime, timezone
from uuid import UUID, uuid4
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService
class FakeAsyncSession:
def __init__(self, sessions: list[AgentChatSession]) -> None:
self._sessions = {session.id: session for session in sessions}
self.commit_called = False
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
for session in self._sessions.values():
return _Result(session)
return _Result(None)
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
return session
return None
async def commit(self) -> None:
self.commit_called = True
def _build_input(run_id: str) -> RunAgentInput:
return RunAgentInput.model_validate(
{
"threadId": "t1",
"runId": run_id,
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
}
)
@pytest.mark.asyncio
async def test_stream_resume_rejects_non_owner_session() -> None:
session = AgentChatSession(
id=uuid4(),
user_id=uuid4(),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
state_snapshot={
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": datetime.now(timezone.utc).isoformat(),
"decision": None,
"result": None,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
},
)
service = AgentChatService(
session=FakeAsyncSession([session]), # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
with pytest.raises(HTTPException) as exc_info:
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_prepare_resume_commits_expired_state_before_410() -> None:
owner_id = UUID("00000000-0000-0000-0000-000000000001")
session = AgentChatSession(
id=uuid4(),
user_id=owner_id,
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
state_snapshot={
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2000-01-01T00:00:00+00:00",
"decision": None,
"result": None,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
},
)
fake_db = FakeAsyncSession([session])
service = AgentChatService(
session=fake_db, # type: ignore[arg-type]
current_user=CurrentUser(id=owner_id),
)
with pytest.raises(HTTPException) as exc_info:
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
assert exc_info.value.status_code == 410
assert fake_db.commit_called is True
@@ -1,5 +1,7 @@
from __future__ import annotations
import pytest
from v1.agent.tool_dispatcher import (
BackendExecutionResult,
InterruptResult,
@@ -46,9 +48,18 @@ class TestToolDispatcher:
def test_dispatcher_class_can_dispatch(self):
dispatcher = ToolDispatcher()
tool = {
"name": "ui.show_dialog",
"name": "ui.navigate_to",
"execution_target": "frontend",
"args": {"message": "Hello"},
}
result = dispatcher.dispatch(tool)
assert isinstance(result, InterruptResult)
def test_unknown_frontend_tool_is_rejected(self):
tool = {
"name": "ui.unknown_action",
"execution_target": "frontend",
"args": {},
}
with pytest.raises(ValueError, match="not in allowlist"):
dispatch_tool_call(tool)