fix(agent): polish interrupt-resume flow for merge readiness

This commit is contained in:
qzl
2026-03-03 17:26:04 +08:00
parent 7be8669144
commit 30a4a1af5d
16 changed files with 1179 additions and 85 deletions
+3 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from v1.agent.dependencies import get_agent_service
@@ -25,7 +25,7 @@ async def create_run(
@router.post("/runs/{run_id}/resume")
async def resume_run(
run_id: str,
run_id: Annotated[str, Path(min_length=1, max_length=255)],
input_data: RunAgentInput,
service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> StreamingResponse:
@@ -34,6 +34,7 @@ async def resume_run(
status_code=409,
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
)
await service.prepare_resume(run_id, input_data)
return StreamingResponse(
service.stream_resume(run_id, input_data),
media_type="text/event-stream",
+47 -1
View File
@@ -1,9 +1,12 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Literal
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
class RunAgentInput(BaseModel):
@@ -40,3 +43,46 @@ class AgentChatRunResponse(BaseModel):
session_id: UUID
output: str
events: list[AgentChatEvent]
class PendingToolStatus(str, Enum):
    """Status values stored in ``PendingToolCall.status``.

    The enum mixes in ``str``, so every member compares equal to its literal
    string value.
    """

    # Written when a pending tool call is first recorded in the snapshot.
    PENDING_APPROVAL = "PENDING_APPROVAL"
    # Written when a resume decision of "approved" is applied.
    APPROVED_EXECUTING = "APPROVED_EXECUTING"
    # NOTE(review): not assigned anywhere in the code visible here; presumably
    # set once the approved tool has finished running -- confirm.
    EXECUTED = "EXECUTED"
    # Written when a resume decision other than "approved" is applied.
    REJECTED = "REJECTED"
    # Written when a resume decision arrives at or after ``expires_at``.
    EXPIRED = "EXPIRED"
class PendingToolCall(BaseModel):
    """Persisted record of a single tool call and its approval state.

    Unknown keys are rejected (``extra="forbid"``), and both timestamp fields
    must be timezone-aware.
    """

    model_config = ConfigDict(extra="forbid")

    interrupt_id: str = Field(min_length=1, max_length=255)
    tool_name: str = Field(min_length=1, max_length=255)
    tool_args: dict[str, Any] = Field(default_factory=dict)
    status: PendingToolStatus
    expires_at: datetime
    decision: dict[str, Any] | None = None
    result: dict[str, Any] | None = None
    updated_at: datetime

    @field_validator("expires_at", "updated_at")
    @classmethod
    def _validate_timezone_aware(cls, value: datetime) -> datetime:
        """Return ``value`` unchanged, or raise ``ValueError`` if it is naive."""
        # A datetime is aware only when it has a tzinfo AND that tzinfo
        # actually yields an offset.
        is_aware = value.tzinfo is not None and value.utcoffset() is not None
        if is_aware:
            return value
        raise ValueError("datetime must be timezone-aware")
class SnapshotRunContext(BaseModel):
    """Thread and run identifiers stored alongside a session snapshot.

    Both identifiers must be 1-255 characters long; unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    thread_id: str = Field(min_length=1, max_length=255)
    run_id: str = Field(min_length=1, max_length=255)
class AgentSessionSnapshot(BaseModel):
    """Version-2 shape of the ``state_snapshot`` payload.

    Unknown keys are rejected, so payloads that do not match this exact
    layout fail validation.
    """

    model_config = ConfigDict(extra="forbid")

    # Only the literal value 2 validates; any other version is rejected.
    version: Literal[2]
    # No default: the key must be present, but its value may be None.
    pending_tool_call: PendingToolCall | None
    run_context: SnapshotRunContext
+193 -44
View File
@@ -21,10 +21,14 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent.schemas import (
AgentSessionSnapshot,
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
PendingToolCall,
PendingToolStatus,
RunAgentInput,
SnapshotRunContext,
)
if TYPE_CHECKING:
@@ -294,6 +298,41 @@ class AgentChatService(BaseService):
return None
return session.state_snapshot
@staticmethod
def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot:
    """Parse a stored snapshot dict into an ``AgentSessionSnapshot``.

    Raises:
        ValueError: if validation fails for any reason; the original error
            is chained as the cause.
    """
    try:
        parsed = AgentSessionSnapshot.model_validate(raw_snapshot)
    except Exception as err:  # noqa: BLE001
        # Collapse every validation failure into one error type for callers.
        raise ValueError("Invalid state_snapshot format") from err
    return parsed
async def _get_session_for_update(
    self, session_id: UUID
) -> AgentChatSession | None:
    """Fetch the chat session with ``session_id`` using a ``FOR UPDATE`` select.

    Returns the matching ``AgentChatSession``, or ``None`` when no row matches.
    """
    by_id = select(AgentChatSession).where(AgentChatSession.id == session_id)
    locked_query = by_id.with_for_update().limit(1)
    rows = await self._session.execute(locked_query)
    return rows.scalar_one_or_none()
def _assert_session_owner(self, session: AgentChatSession) -> None:
if self._current_user is None:
return
if session.user_id != self.require_user_id():
raise HTTPException(status_code=404, detail="Session not found")
@staticmethod
def _validate_no_newlines(value: str, *, field_name: str) -> None:
if "\n" in value or "\r" in value:
raise ValueError(f"{field_name} must not contain newlines")
@staticmethod
def _sse_data(payload: dict[str, Any]) -> str:
return f"data: {json.dumps(payload)}\\n\\n"
async def set_pending_tool_call(
self,
*,
@@ -302,22 +341,30 @@ class AgentChatService(BaseService):
tool_name: str,
tool_args: dict,
expires_at: datetime,
thread_id: str,
run_id: str,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot or {}
snapshot["pending_tool_call"] = {
"interrupt_id": interrupt_id,
"tool_name": tool_name,
"tool_args": tool_args,
"status": "PENDING_APPROVAL",
"expires_at": expires_at.isoformat(),
"decision": None,
"result": None,
}
session.state_snapshot = snapshot
self._assert_session_owner(session)
snapshot = AgentSessionSnapshot(
version=2,
run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id),
pending_tool_call=PendingToolCall(
interrupt_id=interrupt_id,
tool_name=tool_name,
tool_args=tool_args,
status=PendingToolStatus.PENDING_APPROVAL,
expires_at=expires_at,
decision=None,
result=None,
updated_at=datetime.now(timezone.utc),
),
)
session.state_snapshot = snapshot.model_dump(mode="json")
async def update_pending_tool_call_status(
self,
@@ -330,13 +377,27 @@ class AgentChatService(BaseService):
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
self._assert_session_owner(session)
if session.state_snapshot is None:
raise ValueError("No pending tool call found")
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
raise ValueError("No pending tool call found")
if pending.interrupt_id != interrupt_id:
raise ValueError("Interrupt ID mismatch")
snapshot["pending_tool_call"]["status"] = status
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": PendingToolStatus(status),
"updated_at": datetime.now(timezone.utc),
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
async def apply_resume_decision(
self,
@@ -345,52 +406,140 @@ class AgentChatService(BaseService):
interrupt_id: str,
decision: dict[str, Any],
) -> ResumeDecisionResult:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
session = await self._get_session_for_update(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
self._assert_session_owner(session)
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
if session.state_snapshot is None:
return ResumeDecisionResult(applied=False)
pending = snapshot["pending_tool_call"]
if pending["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
return ResumeDecisionResult(applied=False)
if pending["status"] != "PENDING_APPROVAL":
if pending.interrupt_id != interrupt_id:
return ResumeDecisionResult(applied=False)
if pending.status != PendingToolStatus.PENDING_APPROVAL:
return ResumeDecisionResult(applied=False)
now = datetime.now(timezone.utc)
if pending.expires_at <= now:
expired_pending = pending.model_copy(
update={
"status": PendingToolStatus.EXPIRED,
"updated_at": now,
}
)
expired_snapshot = snapshot.model_copy(
update={"pending_tool_call": expired_pending}
)
session.state_snapshot = expired_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=False)
decision_value = decision.get("decision", "approved")
if decision_value == "approved":
new_status = "APPROVED_EXECUTING"
else:
new_status = "REJECTED"
next_status = (
PendingToolStatus.APPROVED_EXECUTING
if decision_value == "approved"
else PendingToolStatus.REJECTED
)
snapshot["pending_tool_call"]["status"] = new_status
snapshot["pending_tool_call"]["decision"] = decision
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": next_status,
"decision": decision,
"updated_at": now,
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=True)
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
if "\n" in input_data.runId or "\r" in input_data.runId:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(input_data.runId, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId})
async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None:
    """Validate and persist a resume decision before streaming starts.

    Doing this work up front lets failures surface as ordinary HTTP errors
    instead of occurring after a streaming response has begun.

    Args:
        run_id: Run identifier from the URL path; parsed as the session UUID.
        input_data: Request body; its ``resume`` mapping must contain an
            ``interruptId`` string and may contain a ``payload`` object.

    Raises:
        HTTPException: 422 for a malformed ``run_id`` or resume payload, or
            an invalid stored snapshot; 404 when the session is missing or
            owned by another user; 410 when the addressed pending tool call
            has expired; 409 when the decision cannot be applied otherwise.
    """
    try:
        self._validate_no_newlines(run_id, field_name="runId")
    except ValueError as exc:
        # Bad client input must be a 4xx, not an unhandled error (HTTP 500).
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    user_id = self.require_user_id()
    # NOTE(review): window_seconds receives the same constant as limit; this
    # looks like a copy/paste slip -- confirm whether a dedicated window
    # constant was intended.
    await enforce_rate_limit(
        scope="agent_resume",
        identifier=str(user_id),
        limit=DEFAULT_RATE_LIMIT,
        window_seconds=DEFAULT_RATE_LIMIT,
    )

    try:
        session_id = UUID(run_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=422, detail="run_id must be a valid UUID"
        ) from exc

    session = await self._get_session_for_update(session_id)
    # Same 404 for "missing" and "not yours" so existence is not leaked.
    if session is None or session.user_id != user_id:
        raise HTTPException(status_code=404, detail="Session not found")

    if input_data.resume is None:
        raise HTTPException(status_code=422, detail="resume payload is required")

    interrupt_id = input_data.resume.get("interruptId")
    if not isinstance(interrupt_id, str) or not interrupt_id:
        raise HTTPException(
            status_code=422, detail="resume.interruptId is required"
        )

    # NOTE(review): a missing payload becomes {} here, and
    # apply_resume_decision treats a missing "decision" key as "approved";
    # confirm that approve-by-default is intended.
    decision_payload = input_data.resume.get("payload", {})
    if not isinstance(decision_payload, dict):
        raise HTTPException(
            status_code=422,
            detail="resume.payload must be an object",
        )

    try:
        decision_result = await self.apply_resume_decision(
            session_id=session_id,
            interrupt_id=interrupt_id,
            decision=decision_payload,
        )
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    if not decision_result.applied:
        if session.state_snapshot is not None:
            snapshot = self._load_snapshot_v2(session.state_snapshot)
            pending = snapshot.pending_tool_call
            # Report 410 only when the expired call is the one the client
            # addressed; a mismatched interruptId falls through to the 409.
            if (
                pending is not None
                and pending.interrupt_id == interrupt_id
                and pending.status == PendingToolStatus.EXPIRED
            ):
                # Persist the EXPIRED transition before reporting it.
                await self._session.commit()
                raise HTTPException(
                    status_code=410,
                    detail="Pending tool call expired",
                )
        raise HTTPException(
            status_code=409,
            detail="Resume decision not applicable",
        )

    await self._session.commit()
async def stream_resume(
self, run_id: str, input_data: RunAgentInput
) -> AsyncGenerator[str, None]:
if "\n" in run_id or "\r" in run_id:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(run_id, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": run_id})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id})
+8
View File
@@ -17,6 +17,12 @@ ALLOWED_BACKEND_TOOLS = frozenset(
}
)
# Names of tools that may be dispatched with execution_target "frontend".
# dispatch_tool_call raises ValueError for any frontend tool not listed here.
ALLOWED_FRONTEND_TOOLS = frozenset(
    {
        "ui.navigate_to",
    }
)
class InterruptResult(BaseModel):
interrupt_type: str
@@ -47,6 +53,8 @@ def dispatch_tool_call(
args = tool.get("args", {})
if target == "frontend":
if name not in ALLOWED_FRONTEND_TOOLS:
raise ValueError(f"Frontend tool '{name}' not in allowlist")
return InterruptResult(
interrupt_type="tool_execution",
tool_name=name,
@@ -13,6 +13,9 @@ from v1.users.dependencies import get_current_user
class FakeAgentService:
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
return None
async def stream_run(self, input_data: RunAgentInput):
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from uuid import UUID
import pytest
@@ -13,6 +14,9 @@ from v1.users.dependencies import get_current_user
class FakeAgentServiceWithInterrupt:
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
return None
async def stream_run(self, input_data: RunAgentInput):
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
@@ -30,7 +34,7 @@ class FakeAgentServiceWithInterrupt:
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
yield (
'data: {"type": "TOOL_RESULT", "toolName": "ui.navigate_to", "result": '
+ str(payload.get("result", {}))
+ json.dumps(payload.get("result", {}))
+ "}\n\n"
)
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
@@ -100,10 +104,7 @@ class TestInterruptResumeFlow:
0
]
assert '"toolName": "ui.navigate_to"' in tool_result_event
assert (
"'success': True" in tool_result_event
or '"success": true' in tool_result_event.lower()
)
assert '"success": true' in tool_result_event.lower()
def test_backend_tool_approval_rejected(self, client: TestClient):
payload = {
@@ -1,5 +1,17 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
from v1.agent.tool_registry import validate_tool_spec
@@ -18,5 +30,51 @@ class TestAgentSecurityRules:
else:
raise AssertionError("Should have raised ValueError for unknown namespace")
def test_frontend_result_fails_when_interrupt_mismatch(self):
pass
@pytest.mark.asyncio
async def test_frontend_result_fails_when_interrupt_mismatch(self):
    """A decision for a different interrupt id must not be applied."""
    owner_id = UUID("00000000-0000-0000-0000-000000000001")
    session = AgentChatSession(
        id=uuid4(),
        user_id=owner_id,
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
    )

    class _Result:
        """Minimal stand-in for the result object returned by execute()."""

        def __init__(self, session_obj: AgentChatSession | None) -> None:
            self._session_obj = session_obj

        def scalar_one_or_none(self) -> AgentChatSession | None:
            return self._session_obj

    class FakeAsyncSession:
        """Database double that always resolves to the one session above."""

        def __init__(self, session_obj: AgentChatSession) -> None:
            self._session_obj = session_obj

        async def execute(self, stmt: object):
            return _Result(self._session_obj)

        async def scalar(self, stmt: object) -> AgentChatSession | None:
            return self._session_obj

    service = AgentChatService(
        session=FakeAsyncSession(session),  # type: ignore[arg-type]
        current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
    )

    # Record a pending call under "int-1" ...
    await service.set_pending_tool_call(
        session_id=session.id,
        interrupt_id="int-1",
        tool_name="srv.transfer_funds",
        tool_args={"to": "u2", "amount": 100},
        expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
        thread_id="t1",
        run_id="r1",
    )

    # ... then try to resolve it with a different interrupt id.
    outcome = await service.apply_resume_decision(
        session_id=session.id,
        interrupt_id="int-other",
        decision={"decision": "approved"},
    )
    assert outcome.applied is False
@@ -17,6 +17,7 @@ class FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self._sessions: dict[UUID, AgentChatSession] = {}
self.last_fetch_with_lock = False
def add(self, obj: object) -> None:
self.added.append(obj)
@@ -35,8 +36,17 @@ class FakeAsyncSession:
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def execute(self, stmt: object):
self.last_fetch_with_lock = "FOR UPDATE" in str(stmt)
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
@@ -69,7 +79,10 @@ def service(fake_db: FakeAsyncSession) -> AgentChatService:
class TestResumeIdempotency:
@pytest.mark.asyncio
async def test_resume_is_idempotent(
self, service: AgentChatService, session: AgentChatSession
self,
service: AgentChatService,
session: AgentChatSession,
fake_db: FakeAsyncSession,
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
@@ -78,6 +91,8 @@ class TestResumeIdempotency:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
first = await service.apply_resume_decision(
@@ -93,6 +108,7 @@ class TestResumeIdempotency:
assert first.applied is True
assert second.applied is False
assert fake_db.last_fetch_with_lock is True
@pytest.mark.asyncio
async def test_resume_updates_status_to_approved(
@@ -105,6 +121,8 @@ class TestResumeIdempotency:
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
@@ -129,6 +147,8 @@ class TestResumeIdempotency:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
@@ -140,3 +160,28 @@ class TestResumeIdempotency:
assert result.applied is True
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
@pytest.mark.asyncio
async def test_resume_expired_pending_marks_expired_and_not_applied(
    self, service: AgentChatService, session: AgentChatSession
):
    """Resuming after ``expires_at`` is refused and stores status EXPIRED."""
    already_past = datetime.now(timezone.utc) - timedelta(seconds=1)
    pending_kwargs = {
        "session_id": session.id,
        "interrupt_id": "int-expired",
        "tool_name": "srv.transfer_funds",
        "tool_args": {"to": "u2", "amount": 100},
        "expires_at": already_past,
        "thread_id": "t1",
        "run_id": "r1",
    }
    await service.set_pending_tool_call(**pending_kwargs)

    outcome = await service.apply_resume_decision(
        session_id=session.id,
        interrupt_id="int-expired",
        decision={"decision": "approved"},
    )
    assert outcome.applied is False

    stored = await service.get_state_snapshot(session.id)
    assert stored["pending_tool_call"]["status"] == "EXPIRED"
+71 -1
View File
@@ -1,4 +1,8 @@
from v1.agent.schemas import RunAgentInput
from datetime import datetime, timezone
import pytest
from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput
class TestRunAgentInput:
@@ -55,3 +59,69 @@ class TestRunAgentInput:
assert model.state == {"key": "value"}
assert len(model.messages) == 1
assert model.messages[0]["role"] == "user"
class TestAgentSessionSnapshot:
    """Schema-level checks for the version-2 ``AgentSessionSnapshot`` model."""

    @staticmethod
    def _pending_call(**extra):
        """Return a valid ``pending_tool_call`` dict, plus any ``extra`` keys."""
        pending = {
            "interrupt_id": "int-1",
            "tool_name": "srv.transfer_funds",
            "tool_args": {"to": "u2", "amount": 100},
            "status": "PENDING_APPROVAL",
            "expires_at": "2026-03-03T12:00:00Z",
            "decision": None,
            "result": None,
            "updated_at": "2026-03-03T11:59:00Z",
        }
        pending.update(extra)
        return pending

    def test_state_snapshot_v2_model_accepts_valid_payload(self):
        """A well-formed payload validates and its timestamps parse as UTC."""
        payload = {
            "version": 2,
            "pending_tool_call": self._pending_call(),
            "run_context": {"thread_id": "t1", "run_id": "r1"},
        }

        model = AgentSessionSnapshot.model_validate(payload)

        assert model.version == 2
        assert model.pending_tool_call is not None
        assert model.pending_tool_call.interrupt_id == "int-1"
        expected_updated_at = datetime(2026, 3, 3, 11, 59, tzinfo=timezone.utc)
        assert model.pending_tool_call.updated_at == expected_updated_at

    def test_state_snapshot_v2_rejects_wrong_version(self):
        """Any version other than 2 fails validation."""
        payload = {
            "version": 1,
            "pending_tool_call": None,
            "run_context": {"thread_id": "t1", "run_id": "r1"},
        }

        with pytest.raises(ValueError):
            AgentSessionSnapshot.model_validate(payload)

    def test_state_snapshot_v2_requires_pending_tool_call_key(self):
        """The ``pending_tool_call`` key must be present, even if null."""
        payload = {
            "version": 2,
            "run_context": {"thread_id": "t1", "run_id": "r1"},
        }

        with pytest.raises(ValueError):
            AgentSessionSnapshot.model_validate(payload)

    def test_state_snapshot_v2_rejects_extra_fields(self):
        """Unknown keys in nested models are rejected."""
        payload = {
            "version": 2,
            "pending_tool_call": self._pending_call(unexpected=True),
            "run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"},
        }

        with pytest.raises(ValueError):
            AgentSessionSnapshot.model_validate(payload)
@@ -35,8 +35,15 @@ class FakeAsyncSession:
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
@@ -78,9 +85,14 @@ class TestPendingToolCall:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is not None
assert snapshot["version"] == 2
assert snapshot["run_context"]["thread_id"] == "t1"
assert snapshot["run_context"]["run_id"] == "r1"
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
@@ -103,6 +115,8 @@ class TestPendingToolCall:
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
await service.update_pending_tool_call_status(
@@ -113,3 +127,42 @@ class TestPendingToolCall:
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
@pytest.mark.asyncio
async def test_invalid_legacy_snapshot_is_rejected(
    self, service: AgentChatService, session: AgentChatSession
):
    """A snapshot without the version-2 envelope raises ``ValueError``."""
    legacy_shape = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
    session.state_snapshot = legacy_shape

    resume_kwargs = {
        "session_id": session.id,
        "interrupt_id": "int-legacy",
        "decision": {"decision": "approved"},
    }
    with pytest.raises(ValueError):
        await service.apply_resume_decision(**resume_kwargs)
@pytest.mark.asyncio
async def test_snapshot_rejects_naive_datetime(
    self, service: AgentChatService, session: AgentChatSession
):
    """Stored timestamps without a UTC offset make the snapshot invalid."""
    naive_pending = {
        "interrupt_id": "int-naive",
        "tool_name": "srv.transfer_funds",
        "tool_args": {"to": "u2", "amount": 100},
        "status": "PENDING_APPROVAL",
        # No "Z" / offset suffix: these parse as naive datetimes.
        "expires_at": "2026-03-03T12:00:00",
        "decision": None,
        "result": None,
        "updated_at": "2026-03-03T11:59:00",
    }
    session.state_snapshot = {
        "version": 2,
        "pending_tool_call": naive_pending,
        "run_context": {"thread_id": "t1", "run_id": "r1"},
    }

    resume_kwargs = {
        "session_id": session.id,
        "interrupt_id": "int-naive",
        "decision": {"decision": "approved"},
    }
    with pytest.raises(ValueError):
        await service.apply_resume_decision(**resume_kwargs)
@@ -0,0 +1,126 @@
from __future__ import annotations
from datetime import datetime, timezone
from uuid import UUID, uuid4
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService
class FakeAsyncSession:
    """In-memory database-session double for ``prepare_resume`` tests.

    Every query resolves to the first stored session regardless of the
    statement passed in; ``commit_called`` records whether ``commit`` ran.
    """

    class _Result:
        """Minimal stand-in for the result object returned by ``execute``."""

        def __init__(self, session_obj: AgentChatSession | None) -> None:
            self._session_obj = session_obj

        def scalar_one_or_none(self) -> AgentChatSession | None:
            return self._session_obj

    def __init__(self, sessions: list[AgentChatSession]) -> None:
        self._sessions = {}
        for stored in sessions:
            self._sessions[stored.id] = stored
        self.commit_called = False

    def _first_session(self) -> AgentChatSession | None:
        """Return the first stored session, or ``None`` when there are none."""
        return next(iter(self._sessions.values()), None)

    async def execute(self, stmt: object):
        return self._Result(self._first_session())

    async def scalar(self, stmt: object) -> AgentChatSession | None:
        return self._first_session()

    async def commit(self) -> None:
        self.commit_called = True
def _build_input(run_id: str) -> RunAgentInput:
    """Build a ``RunAgentInput`` that approves interrupt ``int-1`` for ``run_id``."""
    resume_section = {
        "interruptId": "int-1",
        "payload": {"decision": "approved"},
    }
    raw_request = {
        "threadId": "t1",
        "runId": run_id,
        "state": {},
        "messages": [],
        "tools": [],
        "context": [],
        "forwardedProps": {},
        "resume": resume_section,
    }
    return RunAgentInput.model_validate(raw_request)
@pytest.mark.asyncio
async def test_stream_resume_rejects_non_owner_session() -> None:
    """``prepare_resume`` answers 404 when the session belongs to someone else."""
    pending_call = {
        "interrupt_id": "int-1",
        "tool_name": "srv.transfer_funds",
        "tool_args": {"to": "u2", "amount": 100},
        "status": "PENDING_APPROVAL",
        "expires_at": datetime.now(timezone.utc).isoformat(),
        "decision": None,
        "result": None,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    # The session owner is a random user, not the caller configured below.
    session = AgentChatSession(
        id=uuid4(),
        user_id=uuid4(),
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot={
            "version": 2,
            "pending_tool_call": pending_call,
            "run_context": {"thread_id": "t1", "run_id": str(uuid4())},
        },
    )
    caller = CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
    service = AgentChatService(
        session=FakeAsyncSession([session]),  # type: ignore[arg-type]
        current_user=caller,
    )
    run_id = str(session.id)

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(run_id, _build_input(run_id))

    assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_prepare_resume_commits_expired_state_before_410() -> None:
    """An expired pending call yields 410 and the EXPIRED state is committed."""
    owner_id = UUID("00000000-0000-0000-0000-000000000001")
    pending_call = {
        "interrupt_id": "int-1",
        "tool_name": "srv.transfer_funds",
        "tool_args": {"to": "u2", "amount": 100},
        "status": "PENDING_APPROVAL",
        # Far in the past, so the call has expired by the time it is resumed.
        "expires_at": "2000-01-01T00:00:00+00:00",
        "decision": None,
        "result": None,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    session = AgentChatSession(
        id=uuid4(),
        user_id=owner_id,
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
        state_snapshot={
            "version": 2,
            "pending_tool_call": pending_call,
            "run_context": {"thread_id": "t1", "run_id": str(uuid4())},
        },
    )
    fake_db = FakeAsyncSession([session])
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=owner_id),
    )
    run_id = str(session.id)

    with pytest.raises(HTTPException) as exc_info:
        await service.prepare_resume(run_id, _build_input(run_id))

    assert exc_info.value.status_code == 410
    assert fake_db.commit_called is True
@@ -1,5 +1,7 @@
from __future__ import annotations
import pytest
from v1.agent.tool_dispatcher import (
BackendExecutionResult,
InterruptResult,
@@ -46,9 +48,18 @@ class TestToolDispatcher:
def test_dispatcher_class_can_dispatch(self):
dispatcher = ToolDispatcher()
tool = {
"name": "ui.show_dialog",
"name": "ui.navigate_to",
"execution_target": "frontend",
"args": {"message": "Hello"},
}
result = dispatcher.dispatch(tool)
assert isinstance(result, InterruptResult)
def test_unknown_frontend_tool_is_rejected(self):
    """A frontend tool outside the allowlist makes dispatch raise ``ValueError``."""
    unlisted_tool = dict(
        name="ui.unknown_action",
        execution_target="frontend",
        args={},
    )

    with pytest.raises(ValueError, match="not in allowlist"):
        dispatch_tool_call(unlisted_tool)