fix(agent): polish interrupt-resume flow for merge readiness

This commit is contained in:
qzl
2026-03-03 17:26:04 +08:00
parent 7be8669144
commit 30a4a1af5d
16 changed files with 1179 additions and 85 deletions
+25
View File
@@ -0,0 +1,25 @@
# Deletion Log
## 2026-03-03 feature polish
- Scope: `backend/src/v1/agent/*` and `backend/tests/unit/v1/agent/*` only.
- Candidate review source: scoped `refactor-cleaner` run for the directories above.
### Executed cleanup
1. Merged duplicated newline validation logic in `backend/src/v1/agent/service.py`.
- Before: duplicated checks in `prepare_resume`, `stream_run`, `stream_resume`.
- After: centralized `_validate_no_newlines` helper.
- Behavior impact: none (same validation semantics).
2. Merged duplicated SSE event string formatting in `backend/src/v1/agent/service.py`.
- Before: repeated `f"data: {json.dumps(...)}\n\n"` fragments.
- After: centralized `_sse_data` helper.
- Behavior impact: none (same payload format).
### Candidates not deleted (insufficient evidence)
- `backend/src/v1/agent/crewai_flow.py`
- Reason: candidate report suggested possible dead code, but no deletion was done in this polish pass because cross-module usage certainty was insufficient.
- Legacy `run()` path in `backend/src/v1/agent/service.py`
- Reason: potentially still relied on by non-scope code paths; deletion deferred.
+3 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from v1.agent.dependencies import get_agent_service
@@ -25,7 +25,7 @@ async def create_run(
@router.post("/runs/{run_id}/resume")
async def resume_run(
run_id: str,
run_id: Annotated[str, Path(min_length=1, max_length=255)],
input_data: RunAgentInput,
service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> StreamingResponse:
@@ -34,6 +34,7 @@ async def resume_run(
status_code=409,
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
)
await service.prepare_resume(run_id, input_data)
return StreamingResponse(
service.stream_resume(run_id, input_data),
media_type="text/event-stream",
+47 -1
View File
@@ -1,9 +1,12 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Literal
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
class RunAgentInput(BaseModel):
@@ -40,3 +43,46 @@ class AgentChatRunResponse(BaseModel):
session_id: UUID
output: str
events: list[AgentChatEvent]
class PendingToolStatus(str, Enum):
    """Lifecycle states stored for a pending tool call in a session snapshot.

    Members also subclass ``str``, so each one compares equal to its wire
    value and serializes as that plain string.
    """

    # Initial state: the call is waiting for a resume decision.
    PENDING_APPROVAL = "PENDING_APPROVAL"
    # The resume decision was "approved".
    APPROVED_EXECUTING = "APPROVED_EXECUTING"
    # Follows APPROVED_EXECUTING in the documented state machine.
    EXECUTED = "EXECUTED"
    # The resume decision was anything other than "approved".
    REJECTED = "REJECTED"
    # The decision arrived at or after ``expires_at``.
    EXPIRED = "EXPIRED"
class PendingToolCall(BaseModel):
    """Typed record of a tool call that is awaiting (or has received) a decision.

    Unknown keys are rejected (``extra="forbid"``), and both timestamps must
    be timezone-aware.
    """

    model_config = ConfigDict(extra="forbid")

    # Identifier the resume request must echo back; 1-255 characters.
    interrupt_id: str = Field(min_length=1, max_length=255)
    # Name of the tool being gated; 1-255 characters.
    tool_name: str = Field(min_length=1, max_length=255)
    # Arguments captured for the tool; defaults to an empty dict.
    tool_args: dict[str, Any] = Field(default_factory=dict)
    status: PendingToolStatus
    # Deadline after which a resume decision is no longer accepted.
    expires_at: datetime
    # Raw decision payload recorded when a resume decision is applied.
    decision: dict[str, Any] | None = None
    # NOTE(review): presumably the tool's output once executed - confirm with the writer.
    result: dict[str, Any] | None = None
    # Timestamp of the most recent write to this record.
    updated_at: datetime

    @field_validator("expires_at", "updated_at")
    @classmethod
    def _validate_timezone_aware(cls, value: datetime) -> datetime:
        """Return ``value`` unchanged, or raise ``ValueError`` if it is naive."""
        has_offset = value.tzinfo is not None and value.utcoffset() is not None
        if has_offset:
            return value
        raise ValueError("datetime must be timezone-aware")
class SnapshotRunContext(BaseModel):
    """Thread and run identifiers stored alongside a session snapshot.

    Both identifiers are required, 1-255 characters long, and no extra keys
    are accepted.
    """

    model_config = ConfigDict(extra="forbid")

    thread_id: str = Field(min_length=1, max_length=255)
    run_id: str = Field(min_length=1, max_length=255)
class AgentSessionSnapshot(BaseModel):
    """Version-2 shape of a session's ``state_snapshot``.

    Only ``version == 2`` validates, and unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    # Any other version value fails validation.
    version: Literal[2]
    # No default: the key must be present, although its value may be None.
    pending_tool_call: PendingToolCall | None
    run_context: SnapshotRunContext
+193 -44
View File
@@ -21,10 +21,14 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent.schemas import (
AgentSessionSnapshot,
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
PendingToolCall,
PendingToolStatus,
RunAgentInput,
SnapshotRunContext,
)
if TYPE_CHECKING:
@@ -294,6 +298,41 @@ class AgentChatService(BaseService):
return None
return session.state_snapshot
@staticmethod
def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot:
    """Parse a stored ``state_snapshot`` dict into the typed v2 model.

    Raises:
        ValueError: if validation fails for any reason; the underlying
            exception is chained as the cause.
    """
    try:
        parsed = AgentSessionSnapshot.model_validate(raw_snapshot)
    except Exception as exc:  # noqa: BLE001
        # Collapse every validation failure into one error type for callers.
        raise ValueError("Invalid state_snapshot format") from exc
    return parsed
async def _get_session_for_update(
    self, session_id: UUID
) -> AgentChatSession | None:
    """Fetch one chat session by id using a ``SELECT ... FOR UPDATE`` query.

    Returns:
        The matching ``AgentChatSession``, or ``None`` when no row matches.
    """
    by_id = select(AgentChatSession).where(AgentChatSession.id == session_id)
    # with_for_update() requests a row lock so concurrent resumes serialize.
    locked_query = by_id.with_for_update().limit(1)
    rows = await self._session.execute(locked_query)
    return rows.scalar_one_or_none()
def _assert_session_owner(self, session: AgentChatSession) -> None:
    """Ensure ``session`` belongs to the current user.

    The check is skipped entirely when the service has no current user.

    Raises:
        HTTPException: 404 when the session's ``user_id`` differs from the
            current user's id. A 404 is used, so the response does not reveal
            that the session exists.
    """
    has_user = self._current_user is not None
    if has_user and session.user_id != self.require_user_id():
        raise HTTPException(status_code=404, detail="Session not found")
@staticmethod
def _validate_no_newlines(value: str, *, field_name: str) -> None:
    """Reject values that contain a line feed or a carriage return.

    Args:
        value: The string to inspect.
        field_name: Name used as the prefix of the error message.

    Raises:
        ValueError: if ``value`` contains a line feed or carriage return.
    """
    line_breaks = ("\n", "\r")
    if any(marker in value for marker in line_breaks):
        raise ValueError(f"{field_name} must not contain newlines")
@staticmethod
def _sse_data(payload: dict[str, Any]) -> str:
    """Format ``payload`` as a single Server-Sent Events ``data:`` frame.

    Args:
        payload: JSON-serializable event body.

    Returns:
        The text ``data: <json>`` followed by two real line-feed characters.
    """
    # The frame must end with actual newline characters: the blank line is
    # what terminates an SSE event. Escaped backslashes would emit the
    # characters backslash + "n" and the client would never see an event end.
    return f"data: {json.dumps(payload)}\n\n"
async def set_pending_tool_call(
self,
*,
@@ -302,22 +341,30 @@ class AgentChatService(BaseService):
tool_name: str,
tool_args: dict,
expires_at: datetime,
thread_id: str,
run_id: str,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot or {}
snapshot["pending_tool_call"] = {
"interrupt_id": interrupt_id,
"tool_name": tool_name,
"tool_args": tool_args,
"status": "PENDING_APPROVAL",
"expires_at": expires_at.isoformat(),
"decision": None,
"result": None,
}
session.state_snapshot = snapshot
self._assert_session_owner(session)
snapshot = AgentSessionSnapshot(
version=2,
run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id),
pending_tool_call=PendingToolCall(
interrupt_id=interrupt_id,
tool_name=tool_name,
tool_args=tool_args,
status=PendingToolStatus.PENDING_APPROVAL,
expires_at=expires_at,
decision=None,
result=None,
updated_at=datetime.now(timezone.utc),
),
)
session.state_snapshot = snapshot.model_dump(mode="json")
async def update_pending_tool_call_status(
self,
@@ -330,13 +377,27 @@ class AgentChatService(BaseService):
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
self._assert_session_owner(session)
if session.state_snapshot is None:
raise ValueError("No pending tool call found")
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
raise ValueError("No pending tool call found")
if pending.interrupt_id != interrupt_id:
raise ValueError("Interrupt ID mismatch")
snapshot["pending_tool_call"]["status"] = status
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": PendingToolStatus(status),
"updated_at": datetime.now(timezone.utc),
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
async def apply_resume_decision(
self,
@@ -345,52 +406,140 @@ class AgentChatService(BaseService):
interrupt_id: str,
decision: dict[str, Any],
) -> ResumeDecisionResult:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
session = await self._get_session_for_update(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
self._assert_session_owner(session)
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
if session.state_snapshot is None:
return ResumeDecisionResult(applied=False)
pending = snapshot["pending_tool_call"]
if pending["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
return ResumeDecisionResult(applied=False)
if pending["status"] != "PENDING_APPROVAL":
if pending.interrupt_id != interrupt_id:
return ResumeDecisionResult(applied=False)
if pending.status != PendingToolStatus.PENDING_APPROVAL:
return ResumeDecisionResult(applied=False)
now = datetime.now(timezone.utc)
if pending.expires_at <= now:
expired_pending = pending.model_copy(
update={
"status": PendingToolStatus.EXPIRED,
"updated_at": now,
}
)
expired_snapshot = snapshot.model_copy(
update={"pending_tool_call": expired_pending}
)
session.state_snapshot = expired_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=False)
decision_value = decision.get("decision", "approved")
if decision_value == "approved":
new_status = "APPROVED_EXECUTING"
else:
new_status = "REJECTED"
next_status = (
PendingToolStatus.APPROVED_EXECUTING
if decision_value == "approved"
else PendingToolStatus.REJECTED
)
snapshot["pending_tool_call"]["status"] = new_status
snapshot["pending_tool_call"]["decision"] = decision
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": next_status,
"decision": decision,
"updated_at": now,
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=True)
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
if "\n" in input_data.runId or "\r" in input_data.runId:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(input_data.runId, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId})
async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None:
    """Validate a resume request and persist the decision before streaming.

    Steps, in order: reject newlines in ``run_id``, apply the
    ``agent_resume`` rate limit for the current user, parse ``run_id`` as a
    UUID, load the session with a row lock and check ownership, validate the
    resume payload, then apply the decision and commit.

    Args:
        run_id: Path identifier; must be a UUID string naming the session.
        input_data: Request body; ``input_data.resume`` carries
            ``interruptId`` and an optional ``payload`` object.

    Raises:
        ValueError: if ``run_id`` contains a newline.
        HTTPException: 422 for a non-UUID ``run_id``, a missing or malformed
            resume payload, or an invalid stored snapshot; 404 when the
            session is missing or owned by another user; 410 when the pending
            call has expired (the EXPIRED state is committed first); 409 when
            the decision cannot be applied for any other reason.
    """
    self._validate_no_newlines(run_id, field_name="runId")
    user_id = self.require_user_id()
    # NOTE(review): window_seconds receives DEFAULT_RATE_LIMIT, the same
    # constant passed as limit - confirm a dedicated window constant was not
    # intended.
    await enforce_rate_limit(
        scope="agent_resume",
        identifier=str(user_id),
        limit=DEFAULT_RATE_LIMIT,
        window_seconds=DEFAULT_RATE_LIMIT,
    )

    try:
        session_id = UUID(run_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=422, detail="run_id must be a valid UUID"
        ) from exc

    session = await self._get_session_for_update(session_id)
    if session is None or session.user_id != user_id:
        raise HTTPException(status_code=404, detail="Session not found")

    resume = input_data.resume
    if resume is None:
        raise HTTPException(status_code=422, detail="resume payload is required")

    interrupt_id = resume.get("interruptId")
    if not (isinstance(interrupt_id, str) and interrupt_id):
        raise HTTPException(
            status_code=422, detail="resume.interruptId is required"
        )

    decision_payload = resume.get("payload", {})
    if not isinstance(decision_payload, dict):
        raise HTTPException(
            status_code=422,
            detail="resume.payload must be an object",
        )

    try:
        outcome = await self.apply_resume_decision(
            session_id=session_id,
            interrupt_id=interrupt_id,
            decision=decision_payload,
        )
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    if outcome.applied:
        await self._session.commit()
        return

    # Not applied: distinguish "expired" (410) from every other refusal (409).
    stored = session.state_snapshot
    if stored is not None:
        pending = self._load_snapshot_v2(stored).pending_tool_call
        if pending is not None and pending.status == PendingToolStatus.EXPIRED:
            # Persist the EXPIRED transition before reporting it.
            await self._session.commit()
            raise HTTPException(
                status_code=410,
                detail="Pending tool call expired",
            )
    raise HTTPException(
        status_code=409,
        detail="Resume decision not applicable",
    )
async def stream_resume(
self, run_id: str, input_data: RunAgentInput
) -> AsyncGenerator[str, None]:
if "\n" in run_id or "\r" in run_id:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(run_id, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": run_id})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id})
+8
View File
@@ -17,6 +17,12 @@ ALLOWED_BACKEND_TOOLS = frozenset(
}
)
ALLOWED_FRONTEND_TOOLS = frozenset(
{
"ui.navigate_to",
}
)
class InterruptResult(BaseModel):
interrupt_type: str
@@ -47,6 +53,8 @@ def dispatch_tool_call(
args = tool.get("args", {})
if target == "frontend":
if name not in ALLOWED_FRONTEND_TOOLS:
raise ValueError(f"Frontend tool '{name}' not in allowlist")
return InterruptResult(
interrupt_type="tool_execution",
tool_name=name,
@@ -13,6 +13,9 @@ from v1.users.dependencies import get_current_user
class FakeAgentService:
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
return None
async def stream_run(self, input_data: RunAgentInput):
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from uuid import UUID
import pytest
@@ -13,6 +14,9 @@ from v1.users.dependencies import get_current_user
class FakeAgentServiceWithInterrupt:
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
return None
async def stream_run(self, input_data: RunAgentInput):
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
@@ -30,7 +34,7 @@ class FakeAgentServiceWithInterrupt:
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
yield (
'data: {"type": "TOOL_RESULT", "toolName": "ui.navigate_to", "result": '
+ str(payload.get("result", {}))
+ json.dumps(payload.get("result", {}))
+ "}\n\n"
)
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
@@ -100,10 +104,7 @@ class TestInterruptResumeFlow:
0
]
assert '"toolName": "ui.navigate_to"' in tool_result_event
assert (
"'success': True" in tool_result_event
or '"success": true' in tool_result_event.lower()
)
assert '"success": true' in tool_result_event.lower()
def test_backend_tool_approval_rejected(self, client: TestClient):
payload = {
@@ -1,5 +1,17 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
from v1.agent.tool_registry import validate_tool_spec
@@ -18,5 +30,51 @@ class TestAgentSecurityRules:
else:
raise AssertionError("Should have raised ValueError for unknown namespace")
def test_frontend_result_fails_when_interrupt_mismatch(self):
pass
@pytest.mark.asyncio
async def test_frontend_result_fails_when_interrupt_mismatch(self):
session = AgentChatSession(
id=uuid4(),
user_id=UUID("00000000-0000-0000-0000-000000000001"),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
)
class FakeAsyncSession:
def __init__(self, session_obj: AgentChatSession) -> None:
self._session_obj = session_obj
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(self._session_obj)
async def scalar(self, stmt: object) -> AgentChatSession | None:
return self._session_obj
service = AgentChatService(
session=FakeAsyncSession(session), # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-other",
decision={"decision": "approved"},
)
assert result.applied is False
@@ -17,6 +17,7 @@ class FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self._sessions: dict[UUID, AgentChatSession] = {}
self.last_fetch_with_lock = False
def add(self, obj: object) -> None:
self.added.append(obj)
@@ -35,8 +36,17 @@ class FakeAsyncSession:
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def execute(self, stmt: object):
self.last_fetch_with_lock = "FOR UPDATE" in str(stmt)
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
@@ -69,7 +79,10 @@ def service(fake_db: FakeAsyncSession) -> AgentChatService:
class TestResumeIdempotency:
@pytest.mark.asyncio
async def test_resume_is_idempotent(
self, service: AgentChatService, session: AgentChatSession
self,
service: AgentChatService,
session: AgentChatSession,
fake_db: FakeAsyncSession,
):
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
await service.set_pending_tool_call(
@@ -78,6 +91,8 @@ class TestResumeIdempotency:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
first = await service.apply_resume_decision(
@@ -93,6 +108,7 @@ class TestResumeIdempotency:
assert first.applied is True
assert second.applied is False
assert fake_db.last_fetch_with_lock is True
@pytest.mark.asyncio
async def test_resume_updates_status_to_approved(
@@ -105,6 +121,8 @@ class TestResumeIdempotency:
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
@@ -129,6 +147,8 @@ class TestResumeIdempotency:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
@@ -140,3 +160,28 @@ class TestResumeIdempotency:
assert result.applied is True
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
@pytest.mark.asyncio
async def test_resume_expired_pending_marks_expired_and_not_applied(
self, service: AgentChatService, session: AgentChatSession
):
expires_at = datetime.now(timezone.utc) - timedelta(seconds=1)
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-expired",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-expired",
decision={"decision": "approved"},
)
assert result.applied is False
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
+71 -1
View File
@@ -1,4 +1,8 @@
from v1.agent.schemas import RunAgentInput
from datetime import datetime, timezone
import pytest
from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput
class TestRunAgentInput:
@@ -55,3 +59,69 @@ class TestRunAgentInput:
assert model.state == {"key": "value"}
assert len(model.messages) == 1
assert model.messages[0]["role"] == "user"
class TestAgentSessionSnapshot:
def test_state_snapshot_v2_model_accepts_valid_payload(self):
payload = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00Z",
},
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
model = AgentSessionSnapshot.model_validate(payload)
assert model.version == 2
assert model.pending_tool_call is not None
assert model.pending_tool_call.interrupt_id == "int-1"
assert model.pending_tool_call.updated_at == datetime(
2026, 3, 3, 11, 59, tzinfo=timezone.utc
)
def test_state_snapshot_v2_rejects_wrong_version(self):
payload = {
"version": 1,
"pending_tool_call": None,
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
def test_state_snapshot_v2_requires_pending_tool_call_key(self):
payload = {
"version": 2,
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
def test_state_snapshot_v2_rejects_extra_fields(self):
payload = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00Z",
"unexpected": True,
},
"run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
@@ -35,8 +35,15 @@ class FakeAsyncSession:
async def refresh(self, obj: object) -> None:
pass
async def execute(self, stmt: object) -> None:
pass
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
return _Result(next(iter(self._sessions.values()), None))
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
@@ -78,9 +85,14 @@ class TestPendingToolCall:
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot is not None
assert snapshot["version"] == 2
assert snapshot["run_context"]["thread_id"] == "t1"
assert snapshot["run_context"]["run_id"] == "r1"
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
@@ -103,6 +115,8 @@ class TestPendingToolCall:
tool_name="srv.delete_file",
tool_args={"file_id": "f1"},
expires_at=expires_at,
thread_id="t1",
run_id="r1",
)
await service.update_pending_tool_call_status(
@@ -113,3 +127,42 @@ class TestPendingToolCall:
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
@pytest.mark.asyncio
async def test_invalid_legacy_snapshot_is_rejected(
self, service: AgentChatService, session: AgentChatSession
):
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
with pytest.raises(ValueError):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-legacy",
decision={"decision": "approved"},
)
@pytest.mark.asyncio
async def test_snapshot_rejects_naive_datetime(
self, service: AgentChatService, session: AgentChatSession
):
session.state_snapshot = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-naive",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00",
},
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-naive",
decision={"decision": "approved"},
)
@@ -0,0 +1,126 @@
from __future__ import annotations
from datetime import datetime, timezone
from uuid import UUID, uuid4
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService
class FakeAsyncSession:
def __init__(self, sessions: list[AgentChatSession]) -> None:
self._sessions = {session.id: session for session in sessions}
self.commit_called = False
async def execute(self, stmt: object):
class _Result:
def __init__(self, session_obj: AgentChatSession | None) -> None:
self._session_obj = session_obj
def scalar_one_or_none(self) -> AgentChatSession | None:
return self._session_obj
for session in self._sessions.values():
return _Result(session)
return _Result(None)
async def scalar(self, stmt: object) -> AgentChatSession | None:
for session in self._sessions.values():
return session
return None
async def commit(self) -> None:
self.commit_called = True
def _build_input(run_id: str) -> RunAgentInput:
return RunAgentInput.model_validate(
{
"threadId": "t1",
"runId": run_id,
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
}
)
@pytest.mark.asyncio
async def test_stream_resume_rejects_non_owner_session() -> None:
session = AgentChatSession(
id=uuid4(),
user_id=uuid4(),
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
state_snapshot={
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": datetime.now(timezone.utc).isoformat(),
"decision": None,
"result": None,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
},
)
service = AgentChatService(
session=FakeAsyncSession([session]), # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
with pytest.raises(HTTPException) as exc_info:
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_prepare_resume_commits_expired_state_before_410() -> None:
owner_id = UUID("00000000-0000-0000-0000-000000000001")
session = AgentChatSession(
id=uuid4(),
user_id=owner_id,
session_type=SessionType.CHAT,
status=AgentChatSessionStatus.RUNNING,
state_snapshot={
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2000-01-01T00:00:00+00:00",
"decision": None,
"result": None,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
},
)
fake_db = FakeAsyncSession([session])
service = AgentChatService(
session=fake_db, # type: ignore[arg-type]
current_user=CurrentUser(id=owner_id),
)
with pytest.raises(HTTPException) as exc_info:
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
assert exc_info.value.status_code == 410
assert fake_db.commit_called is True
@@ -1,5 +1,7 @@
from __future__ import annotations
import pytest
from v1.agent.tool_dispatcher import (
BackendExecutionResult,
InterruptResult,
@@ -46,9 +48,18 @@ class TestToolDispatcher:
def test_dispatcher_class_can_dispatch(self):
dispatcher = ToolDispatcher()
tool = {
"name": "ui.show_dialog",
"name": "ui.navigate_to",
"execution_target": "frontend",
"args": {"message": "Hello"},
}
result = dispatcher.dispatch(tool)
assert isinstance(result, InterruptResult)
def test_unknown_frontend_tool_is_rejected(self):
tool = {
"name": "ui.unknown_action",
"execution_target": "frontend",
"args": {},
}
with pytest.raises(ValueError, match="not in allowlist"):
dispatch_tool_call(tool)
@@ -0,0 +1,95 @@
# Agent Interrupt/Resume 遗留问题修复设计
## 1. 目标
本次修复一次性完成以下三项遗留问题:
1. `state_snapshot` 并发一致性问题(并发 resume 竞争)
2. `expires_at` 过期未强校验问题
3. `state_snapshot` 缺少强类型与版本化问题
## 2. 设计决策
采用方案 2(严格重构):
- `state_snapshot` 仅接受新结构,不再兼容旧结构
- 统一快照版本为 `version = 2`
- 使用强类型模型约束快照结构与状态迁移
- resume 入口引入行级锁语义,避免并发双写
## 3. 状态快照模型
`state_snapshot` 顶层结构:
```json
{
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": null,
"result": null,
"updated_at": "2026-03-03T11:59:00Z"
},
"run_context": {
"thread_id": "t1",
"run_id": "r1"
}
}
```
说明:
- `version` 必须为 2,否则拒绝处理
- `pending_tool_call` 字段缺失或类型错误,按无效快照处理
- `run_context` 仅保留 interrupt/resume 必需字段
## 4. 状态机约束
仅允许以下迁移:
- `PENDING_APPROVAL -> APPROVED_EXECUTING -> EXECUTED`
- `PENDING_APPROVAL -> REJECTED`
- `PENDING_APPROVAL -> EXPIRED`
非法状态迁移必须返回错误,不做隐式修复。
## 5. 并发与过期语义
- resume 前先对目标 session 加锁再读取快照
- 同一 `interrupt_id` 并发 resume 只能有一个请求成功
- 若 `expires_at < now(UTC)`,先迁移为 `EXPIRED`,再返回 410
## 6. 错误语义(RFC7807)
- `409 Conflict`: run/interrupt 不匹配,或并发冲突导致状态已消费
- `410 Gone`: 挂起调用已过期
- `422 Unprocessable Entity`: `state_snapshot` 非法或版本不匹配
- `404 Not Found`: 目标 session/run 不存在
## 7. 测试策略
采用 TDD,先写失败测试后实现:
- 快照版本校验(`version != 2`)
- 快照结构校验(必填字段/类型)
- 并发 resume 幂等竞争(仅一个成功)
- 过期校验(返回 410 + 状态置 EXPIRED)
- 合法状态迁移路径覆盖
## 8. 验证命令
- `uv run pytest backend/tests/unit/v1/agent -v`
- `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py -v`
- `uv run pytest backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
- `cd backend && uv run ruff check src/v1/agent`
- `cd backend && uv run basedpyright src/v1/agent`
## 9. 风险与回滚
- 风险:旧快照不再兼容,可能触发运行时拒绝
- 处置:通过明确 422 错误暴露不合规数据,结合日志定位并人工修复数据
- 回滚:回退本次变更并恢复旧快照解析逻辑(仅在紧急故障时)
@@ -0,0 +1,377 @@
# Agent Interrupt/Resume Strict Refactor Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 通过严格重构一次性修复 interrupt/resume 的并发安全、过期校验和 state_snapshot 强类型版本化问题。
**Architecture:** 以 `state_snapshot v2` 为唯一合法结构,服务层使用强类型模型解析与状态迁移,resume 路径在读取会话时加行锁保证并发一致性。路由层维持现有 run/resume 入口,错误通过 HTTPException 输出,测试覆盖版本校验、过期语义、并发幂等和状态机迁移。
**Tech Stack:** FastAPI, SQLAlchemy AsyncSession, Pydantic v2, pytest
---
### Task 1: 新增 state_snapshot v2 强类型模型
**Files:**
- Modify: `backend/src/v1/agent/schemas.py`
- Test: `backend/tests/unit/v1/agent/test_schemas.py`
**Step 1: Write the failing test**
```python
def test_state_snapshot_v2_model_accepts_valid_payload():
payload = {
"version": 2,
"pending_tool_call": {
"interrupt_id": "int-1",
"tool_name": "srv.transfer_funds",
"tool_args": {"to": "u2", "amount": 100},
"status": "PENDING_APPROVAL",
"expires_at": "2026-03-03T12:00:00Z",
"decision": None,
"result": None,
"updated_at": "2026-03-03T11:59:00Z",
},
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
model = AgentSessionSnapshot.model_validate(payload)
assert model.version == 2
def test_state_snapshot_v2_rejects_wrong_version():
payload = {
"version": 1,
"pending_tool_call": None,
"run_context": {"thread_id": "t1", "run_id": "r1"},
}
with pytest.raises(ValueError):
AgentSessionSnapshot.model_validate(payload)
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/v1/agent/test_schemas.py -v`
Expected: FAIL(`AgentSessionSnapshot` 未定义或校验不符合预期)
**Step 3: Write minimal implementation**
```python
class PendingToolStatus(str, Enum):
PENDING_APPROVAL = "PENDING_APPROVAL"
APPROVED_EXECUTING = "APPROVED_EXECUTING"
EXECUTED = "EXECUTED"
REJECTED = "REJECTED"
EXPIRED = "EXPIRED"
class PendingToolCall(BaseModel):
interrupt_id: str
tool_name: str
tool_args: dict[str, Any]
status: PendingToolStatus
expires_at: datetime
decision: dict[str, Any] | None = None
result: dict[str, Any] | None = None
updated_at: datetime
class SnapshotRunContext(BaseModel):
thread_id: str
run_id: str
class AgentSessionSnapshot(BaseModel):
version: Literal[2]
pending_tool_call: PendingToolCall | None = None
run_context: SnapshotRunContext
```
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/v1/agent/test_schemas.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/agent/schemas.py backend/tests/unit/v1/agent/test_schemas.py
git commit -m "refactor(agent): add strict v2 session snapshot schema"
```
---
### Task 2: service 层改为 v2 快照读写(严格拒绝旧结构)
**Files:**
- Modify: `backend/src/v1/agent/service.py`
- Test: `backend/tests/unit/v1/agent/test_service_pending_tool_call.py`
**Step 1: Write the failing test**
```python
@pytest.mark.asyncio
async def test_set_pending_tool_call_writes_v2_snapshot(service, session):
await service.set_pending_tool_call(
session_id=session.id,
interrupt_id="int-1",
tool_name="srv.transfer_funds",
tool_args={"to": "u2", "amount": 100},
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
thread_id="t1",
run_id="r1",
)
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["version"] == 2
assert snapshot["run_context"]["run_id"] == "r1"
@pytest.mark.asyncio
async def test_invalid_legacy_snapshot_is_rejected(service, session):
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
with pytest.raises(ValueError):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-1",
decision={"decision": "approved"},
)
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/v1/agent/test_service_pending_tool_call.py -v`
Expected: FAIL
**Step 3: Write minimal implementation**
```python
def _build_snapshot_v2(...):
return AgentSessionSnapshot(...).model_dump(mode="json")
def _load_snapshot_v2(raw: dict[str, Any] | None) -> AgentSessionSnapshot:
if raw is None:
raise ValueError("state_snapshot missing")
return AgentSessionSnapshot.model_validate(raw)
```
并将 `set_pending_tool_call/get_state_snapshot/update_pending_tool_call_status` 全部改成 v2 模型读写。
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/v1/agent/test_service_pending_tool_call.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_service_pending_tool_call.py
git commit -m "refactor(agent): enforce v2 snapshot read write in service"
```
---
### Task 3: 增加 resume 行锁与并发幂等
**Files:**
- Modify: `backend/src/v1/agent/service.py`
- Test: `backend/tests/unit/v1/agent/test_resume_idempotency.py`
**Step 1: Write the failing test**
```python
@pytest.mark.asyncio
async def test_apply_resume_decision_uses_locked_session_fetch(service, fake_db, session):
await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-1",
decision={"decision": "approved"},
)
assert fake_db.last_fetch_with_lock is True
@pytest.mark.asyncio
async def test_resume_is_idempotent(service, session):
first = await service.apply_resume_decision(...)
second = await service.apply_resume_decision(...)
assert first.applied is True
assert second.applied is False
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
Expected: FAIL
**Step 3: Write minimal implementation**
```python
async def _get_session_for_update(self, session_id: UUID) -> AgentChatSession | None:
stmt = (
select(AgentChatSession)
.where(AgentChatSession.id == session_id)
.with_for_update()
.limit(1)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
```
`apply_resume_decision` 改为锁内读取、校验、状态迁移,保证并发下单次生效。
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_resume_idempotency.py
git commit -m "fix(agent): add row lock for resume state transition"
```
---
### Task 4: 增加 expires_at 过期校验(含 EXPIRED 迁移)
**Files:**
- Modify: `backend/src/v1/agent/service.py`
- Test: `backend/tests/unit/v1/agent/test_resume_idempotency.py`
**Step 1: Write the failing test**
```python
@pytest.mark.asyncio
async def test_resume_expired_pending_returns_not_applied_and_marks_expired(service, session):
await service.set_pending_tool_call(..., expires_at=datetime.now(timezone.utc) - timedelta(seconds=1), thread_id="t1", run_id="r1")
result = await service.apply_resume_decision(
session_id=session.id,
interrupt_id="int-1",
decision={"decision": "approved"},
)
assert result.applied is False
snapshot = await service.get_state_snapshot(session.id)
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
Expected: FAIL
**Step 3: Write minimal implementation**
```python
if pending.expires_at < datetime.now(timezone.utc):
pending.status = PendingToolStatus.EXPIRED
pending.updated_at = datetime.now(timezone.utc)
session.state_snapshot = snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=False, expired=True)
```
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_resume_idempotency.py
git commit -m "fix(agent): enforce expires_at when applying resume decision"
```
---
### Task 5: 路由层补齐 v2 快照与过期/冲突错误映射
**Files:**
- Modify: `backend/src/v1/agent/router.py`
- Modify: `backend/src/v1/agent/service.py`
- Test: `backend/tests/integration/v1/agent/test_chat_routes.py`
- Test: `backend/tests/integration/v1/agent/test_interrupt_resume_flow.py`
**Step 1: Write the failing test**
```python
def test_resume_route_returns_409_on_run_id_mismatch(client):
...
def test_resume_route_returns_410_when_pending_expired(client):
...
def test_resume_route_returns_422_for_legacy_snapshot(client):
...
```
**Step 2: Run test to verify it fails**
Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
Expected: FAIL
**Step 3: Write minimal implementation**
在 `stream_resume` 或路由调用链里将领域错误映射为:
- 过期 -> `HTTPException(410)`
- 旧快照/结构错误 -> `HTTPException(422)`
- 状态冲突/重复消费 -> `HTTPException(409)`
**Step 4: Run test to verify it passes**
Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
Expected: PASS
**Step 5: Commit**
```bash
git add backend/src/v1/agent/router.py backend/src/v1/agent/service.py backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py
git commit -m "fix(agent): map resume snapshot errors to 409 410 422"
```
---
### Task 6: 更新文档并完成验证
**Files:**
- Modify: `docs/plans/2026-03-03-agent-chat-design.md`
- Modify: `docs/runtime/runtime-route.md`
**Step 1: Update docs**
- 明确 `state_snapshot version=2` 为唯一支持结构
- 明确 resume 过期与并发冲突语义(410/409)
- 明确旧快照拒绝策略(422)
**Step 2: Run unit tests**
Run: `uv run pytest backend/tests/unit/v1/agent -v`
Expected: PASS
**Step 3: Run integration tests**
Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
Expected: PASS
**Step 4: Run static checks**
Run: `cd backend && uv run ruff check src/v1/agent`
Expected: PASS
Run: `cd backend && uv run basedpyright src/v1/agent`
Expected: PASS
**Step 5: Commit**
```bash
git add docs/plans/2026-03-03-agent-chat-design.md docs/runtime/runtime-route.md
git commit -m "docs(agent): document strict snapshot v2 and resume error semantics"
```
---
Plan complete and saved to `docs/plans/2026-03-03-interrupt-resume-fixes-implementation-plan.md`.
Execution mode selected by user request: Subagent-Driven (this session), proceed task-by-task immediately.
+50 -24
View File
@@ -788,42 +788,68 @@
## Agent
### POST /agent
### POST /agent/runs
运行 Agent 对话(需要认证)。
创建 Agent 运行(需要认证,SSE 响应)。
**Request:**
**Request (RunAgentInput):**
```json
{
"message": "string (1-8000 chars)",
"session_id": "string? (UUID)"
"threadId": "string",
"runId": "string",
"parentRunId": "string?",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": null
}
```
**Response:** 200 OK
```json
{
"session_id": "string (UUID)",
"output": "string",
"events": [
{
"type": "string",
"run_id": "string?",
"message_id": "string?",
"delta": "string?",
"tool_name": "string?",
"result": "string?",
"output": "string?",
"error": "string?"
}
]
}
```
**Response:** 200 OK (`text/event-stream`)
**Errors:**
- 401: 未认证
- 422: 请求参数无效
### POST /agent/runs/{run_id}/resume
恢复被中断运行(需要认证,SSE 响应)。
**Request (RunAgentInput):**
```json
{
"threadId": "string",
"runId": "string",
"state": {},
"messages": [],
"tools": [],
"context": [],
"forwardedProps": {},
"resume": {
"interruptId": "string",
"payload": {}
}
}
```
**State Snapshot Contract:**
- `state_snapshot` 仅支持 `version = 2`
- 顶层必须包含 `run_context` 与 `pending_tool_call`
- 旧格式或缺失字段会被拒绝
**Resume Semantics:**
- 同一 `interrupt_id` 并发恢复仅允许一个请求成功
- `expires_at` 超时后会标记为 `EXPIRED`,恢复请求不再生效
**Errors:**
- 401: 未认证
- 404: 会话不存在
- 409: `run_id` 或 `interrupt_id` 冲突,或状态已被消费
- 410: 挂起调用已过期
- 422: `state_snapshot` 非法或版本不匹配
---
## Infra