fix(agent): polish interrupt-resume flow for merge readiness
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
# Deletion Log
|
||||
|
||||
## 2026-03-03 feature polish
|
||||
|
||||
- Scope: `backend/src/v1/agent/*` and `backend/tests/unit/v1/agent/*` only.
|
||||
- Candidate review source: scoped `refactor-cleaner` run for the directories above.
|
||||
|
||||
### Executed cleanup
|
||||
|
||||
1. Merged duplicated newline validation logic in `backend/src/v1/agent/service.py`.
|
||||
- Before: duplicated checks in `prepare_resume`, `stream_run`, `stream_resume`.
|
||||
- After: centralized `_validate_no_newlines` helper.
|
||||
- Behavior impact: none (same validation semantics).
|
||||
|
||||
2. Merged duplicated SSE event string formatting in `backend/src/v1/agent/service.py`.
|
||||
- Before: repeated `f"data: {json.dumps(...)}\n\n"` fragments.
|
||||
- After: centralized `_sse_data` helper.
|
||||
- Behavior impact: none (same payload format).
|
||||
|
||||
### Candidates not deleted (insufficient evidence)
|
||||
|
||||
- `backend/src/v1/agent/crewai_flow.py`
|
||||
- Reason: candidate report suggested possible dead code, but no deletion was done in this polish pass because cross-module usage certainty was insufficient.
|
||||
- Legacy `run()` path in `backend/src/v1/agent/service.py`
|
||||
- Reason: potentially still relied on by non-scope code paths; deletion deferred.
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
@@ -25,7 +25,7 @@ async def create_run(
|
||||
|
||||
@router.post("/runs/{run_id}/resume")
|
||||
async def resume_run(
|
||||
run_id: str,
|
||||
run_id: Annotated[str, Path(min_length=1, max_length=255)],
|
||||
input_data: RunAgentInput,
|
||||
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
||||
) -> StreamingResponse:
|
||||
@@ -34,6 +34,7 @@ async def resume_run(
|
||||
status_code=409,
|
||||
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
|
||||
)
|
||||
await service.prepare_resume(run_id, input_data)
|
||||
return StreamingResponse(
|
||||
service.stream_resume(run_id, input_data),
|
||||
media_type="text/event-stream",
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class RunAgentInput(BaseModel):
|
||||
@@ -40,3 +43,46 @@ class AgentChatRunResponse(BaseModel):
|
||||
session_id: UUID
|
||||
output: str
|
||||
events: list[AgentChatEvent]
|
||||
|
||||
|
||||
class PendingToolStatus(str, Enum):
|
||||
PENDING_APPROVAL = "PENDING_APPROVAL"
|
||||
APPROVED_EXECUTING = "APPROVED_EXECUTING"
|
||||
EXECUTED = "EXECUTED"
|
||||
REJECTED = "REJECTED"
|
||||
EXPIRED = "EXPIRED"
|
||||
|
||||
|
||||
class PendingToolCall(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
interrupt_id: str = Field(min_length=1, max_length=255)
|
||||
tool_name: str = Field(min_length=1, max_length=255)
|
||||
tool_args: dict[str, Any] = Field(default_factory=dict)
|
||||
status: PendingToolStatus
|
||||
expires_at: datetime
|
||||
decision: dict[str, Any] | None = None
|
||||
result: dict[str, Any] | None = None
|
||||
updated_at: datetime
|
||||
|
||||
@field_validator("expires_at", "updated_at")
|
||||
@classmethod
|
||||
def _validate_timezone_aware(cls, value: datetime) -> datetime:
|
||||
if value.tzinfo is None or value.utcoffset() is None:
|
||||
raise ValueError("datetime must be timezone-aware")
|
||||
return value
|
||||
|
||||
|
||||
class SnapshotRunContext(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
thread_id: str = Field(min_length=1, max_length=255)
|
||||
run_id: str = Field(min_length=1, max_length=255)
|
||||
|
||||
|
||||
class AgentSessionSnapshot(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
version: Literal[2]
|
||||
pending_tool_call: PendingToolCall | None
|
||||
run_context: SnapshotRunContext
|
||||
|
||||
+193
-44
@@ -21,10 +21,14 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.agent.schemas import (
|
||||
AgentSessionSnapshot,
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
PendingToolCall,
|
||||
PendingToolStatus,
|
||||
RunAgentInput,
|
||||
SnapshotRunContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -294,6 +298,41 @@ class AgentChatService(BaseService):
|
||||
return None
|
||||
return session.state_snapshot
|
||||
|
||||
@staticmethod
|
||||
def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot:
|
||||
try:
|
||||
return AgentSessionSnapshot.model_validate(raw_snapshot)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise ValueError("Invalid state_snapshot format") from exc
|
||||
|
||||
async def _get_session_for_update(
|
||||
self, session_id: UUID
|
||||
) -> AgentChatSession | None:
|
||||
stmt = (
|
||||
select(AgentChatSession)
|
||||
.where(AgentChatSession.id == session_id)
|
||||
.with_for_update()
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _assert_session_owner(self, session: AgentChatSession) -> None:
|
||||
if self._current_user is None:
|
||||
return
|
||||
|
||||
if session.user_id != self.require_user_id():
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
@staticmethod
|
||||
def _validate_no_newlines(value: str, *, field_name: str) -> None:
|
||||
if "\n" in value or "\r" in value:
|
||||
raise ValueError(f"{field_name} must not contain newlines")
|
||||
|
||||
@staticmethod
|
||||
def _sse_data(payload: dict[str, Any]) -> str:
|
||||
return f"data: {json.dumps(payload)}\\n\\n"
|
||||
|
||||
async def set_pending_tool_call(
|
||||
self,
|
||||
*,
|
||||
@@ -302,22 +341,30 @@ class AgentChatService(BaseService):
|
||||
tool_name: str,
|
||||
tool_args: dict,
|
||||
expires_at: datetime,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot or {}
|
||||
snapshot["pending_tool_call"] = {
|
||||
"interrupt_id": interrupt_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
}
|
||||
session.state_snapshot = snapshot
|
||||
self._assert_session_owner(session)
|
||||
|
||||
snapshot = AgentSessionSnapshot(
|
||||
version=2,
|
||||
run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id),
|
||||
pending_tool_call=PendingToolCall(
|
||||
interrupt_id=interrupt_id,
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
status=PendingToolStatus.PENDING_APPROVAL,
|
||||
expires_at=expires_at,
|
||||
decision=None,
|
||||
result=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
),
|
||||
)
|
||||
session.state_snapshot = snapshot.model_dump(mode="json")
|
||||
|
||||
async def update_pending_tool_call_status(
|
||||
self,
|
||||
@@ -330,13 +377,27 @@ class AgentChatService(BaseService):
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot
|
||||
if snapshot is None or "pending_tool_call" not in snapshot:
|
||||
self._assert_session_owner(session)
|
||||
if session.state_snapshot is None:
|
||||
raise ValueError("No pending tool call found")
|
||||
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
|
||||
|
||||
snapshot = self._load_snapshot_v2(session.state_snapshot)
|
||||
pending = snapshot.pending_tool_call
|
||||
if pending is None:
|
||||
raise ValueError("No pending tool call found")
|
||||
if pending.interrupt_id != interrupt_id:
|
||||
raise ValueError("Interrupt ID mismatch")
|
||||
snapshot["pending_tool_call"]["status"] = status
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
updated_pending = pending.model_copy(
|
||||
update={
|
||||
"status": PendingToolStatus(status),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
updated_snapshot = snapshot.model_copy(
|
||||
update={"pending_tool_call": updated_pending}
|
||||
)
|
||||
session.state_snapshot = updated_snapshot.model_dump(mode="json")
|
||||
|
||||
async def apply_resume_decision(
|
||||
self,
|
||||
@@ -345,52 +406,140 @@ class AgentChatService(BaseService):
|
||||
interrupt_id: str,
|
||||
decision: dict[str, Any],
|
||||
) -> ResumeDecisionResult:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
session = await self._get_session_for_update(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
self._assert_session_owner(session)
|
||||
|
||||
snapshot = session.state_snapshot
|
||||
if snapshot is None or "pending_tool_call" not in snapshot:
|
||||
if session.state_snapshot is None:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
pending = snapshot["pending_tool_call"]
|
||||
if pending["interrupt_id"] != interrupt_id:
|
||||
snapshot = self._load_snapshot_v2(session.state_snapshot)
|
||||
pending = snapshot.pending_tool_call
|
||||
if pending is None:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
if pending["status"] != "PENDING_APPROVAL":
|
||||
if pending.interrupt_id != interrupt_id:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
if pending.status != PendingToolStatus.PENDING_APPROVAL:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if pending.expires_at <= now:
|
||||
expired_pending = pending.model_copy(
|
||||
update={
|
||||
"status": PendingToolStatus.EXPIRED,
|
||||
"updated_at": now,
|
||||
}
|
||||
)
|
||||
expired_snapshot = snapshot.model_copy(
|
||||
update={"pending_tool_call": expired_pending}
|
||||
)
|
||||
session.state_snapshot = expired_snapshot.model_dump(mode="json")
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
decision_value = decision.get("decision", "approved")
|
||||
if decision_value == "approved":
|
||||
new_status = "APPROVED_EXECUTING"
|
||||
else:
|
||||
new_status = "REJECTED"
|
||||
next_status = (
|
||||
PendingToolStatus.APPROVED_EXECUTING
|
||||
if decision_value == "approved"
|
||||
else PendingToolStatus.REJECTED
|
||||
)
|
||||
|
||||
snapshot["pending_tool_call"]["status"] = new_status
|
||||
snapshot["pending_tool_call"]["decision"] = decision
|
||||
session.state_snapshot = snapshot
|
||||
updated_pending = pending.model_copy(
|
||||
update={
|
||||
"status": next_status,
|
||||
"decision": decision,
|
||||
"updated_at": now,
|
||||
}
|
||||
)
|
||||
updated_snapshot = snapshot.model_copy(
|
||||
update={"pending_tool_call": updated_pending}
|
||||
)
|
||||
session.state_snapshot = updated_snapshot.model_dump(mode="json")
|
||||
|
||||
return ResumeDecisionResult(applied=True)
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
|
||||
if "\n" in input_data.runId or "\r" in input_data.runId:
|
||||
raise ValueError("runId must not contain newlines")
|
||||
self._validate_no_newlines(input_data.runId, field_name="runId")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
|
||||
yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"})
|
||||
yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId})
|
||||
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None:
|
||||
self._validate_no_newlines(run_id, field_name="runId")
|
||||
|
||||
user_id = self.require_user_id()
|
||||
await enforce_rate_limit(
|
||||
scope="agent_resume",
|
||||
identifier=str(user_id),
|
||||
limit=DEFAULT_RATE_LIMIT,
|
||||
window_seconds=DEFAULT_RATE_LIMIT,
|
||||
)
|
||||
|
||||
try:
|
||||
session_id = UUID(run_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="run_id must be a valid UUID"
|
||||
) from exc
|
||||
|
||||
session = await self._get_session_for_update(session_id)
|
||||
if session is None or session.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if input_data.resume is None:
|
||||
raise HTTPException(status_code=422, detail="resume payload is required")
|
||||
|
||||
interrupt_id = input_data.resume.get("interruptId")
|
||||
if not isinstance(interrupt_id, str) or not interrupt_id:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="resume.interruptId is required"
|
||||
)
|
||||
|
||||
decision_payload = input_data.resume.get("payload", {})
|
||||
if not isinstance(decision_payload, dict):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="resume.payload must be an object",
|
||||
)
|
||||
|
||||
try:
|
||||
decision_result = await self.apply_resume_decision(
|
||||
session_id=session_id,
|
||||
interrupt_id=interrupt_id,
|
||||
decision=decision_payload,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
if not decision_result.applied:
|
||||
if session.state_snapshot is not None:
|
||||
snapshot = self._load_snapshot_v2(session.state_snapshot)
|
||||
pending = snapshot.pending_tool_call
|
||||
if pending is not None and pending.status == PendingToolStatus.EXPIRED:
|
||||
await self._session.commit()
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail="Pending tool call expired",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Resume decision not applicable",
|
||||
)
|
||||
|
||||
await self._session.commit()
|
||||
|
||||
async def stream_resume(
|
||||
self, run_id: str, input_data: RunAgentInput
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if "\n" in run_id or "\r" in run_id:
|
||||
raise ValueError("runId must not contain newlines")
|
||||
self._validate_no_newlines(run_id, field_name="runId")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
|
||||
yield self._sse_data({"type": "RUN_STARTED", "runId": run_id})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"})
|
||||
yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id})
|
||||
|
||||
@@ -17,6 +17,12 @@ ALLOWED_BACKEND_TOOLS = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
ALLOWED_FRONTEND_TOOLS = frozenset(
|
||||
{
|
||||
"ui.navigate_to",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class InterruptResult(BaseModel):
|
||||
interrupt_type: str
|
||||
@@ -47,6 +53,8 @@ def dispatch_tool_call(
|
||||
args = tool.get("args", {})
|
||||
|
||||
if target == "frontend":
|
||||
if name not in ALLOWED_FRONTEND_TOOLS:
|
||||
raise ValueError(f"Frontend tool '{name}' not in allowlist")
|
||||
return InterruptResult(
|
||||
interrupt_type="tool_execution",
|
||||
tool_name=name,
|
||||
|
||||
@@ -13,6 +13,9 @@ from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class FakeAgentService:
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
return None
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
@@ -13,6 +14,9 @@ from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class FakeAgentServiceWithInterrupt:
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
return None
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
@@ -30,7 +34,7 @@ class FakeAgentServiceWithInterrupt:
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
|
||||
yield (
|
||||
'data: {"type": "TOOL_RESULT", "toolName": "ui.navigate_to", "result": '
|
||||
+ str(payload.get("result", {}))
|
||||
+ json.dumps(payload.get("result", {}))
|
||||
+ "}\n\n"
|
||||
)
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
@@ -100,10 +104,7 @@ class TestInterruptResumeFlow:
|
||||
0
|
||||
]
|
||||
assert '"toolName": "ui.navigate_to"' in tool_result_event
|
||||
assert (
|
||||
"'success': True" in tool_result_event
|
||||
or '"success": true' in tool_result_event.lower()
|
||||
)
|
||||
assert '"success": true' in tool_result_event.lower()
|
||||
|
||||
def test_backend_tool_approval_rejected(self, client: TestClient):
|
||||
payload = {
|
||||
|
||||
@@ -1,5 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
from v1.agent.tool_registry import validate_tool_spec
|
||||
|
||||
|
||||
@@ -18,5 +30,51 @@ class TestAgentSecurityRules:
|
||||
else:
|
||||
raise AssertionError("Should have raised ValueError for unknown namespace")
|
||||
|
||||
def test_frontend_result_fails_when_interrupt_mismatch(self):
|
||||
pass
|
||||
@pytest.mark.asyncio
|
||||
async def test_frontend_result_fails_when_interrupt_mismatch(self):
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, session_obj: AgentChatSession) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(self._session_obj)
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
service = AgentChatService(
|
||||
session=FakeAsyncSession(session), # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-other",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is False
|
||||
|
||||
@@ -17,6 +17,7 @@ class FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self._sessions: dict[UUID, AgentChatSession] = {}
|
||||
self.last_fetch_with_lock = False
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
@@ -35,8 +36,17 @@ class FakeAsyncSession:
|
||||
async def refresh(self, obj: object) -> None:
|
||||
pass
|
||||
|
||||
async def execute(self, stmt: object) -> None:
|
||||
pass
|
||||
async def execute(self, stmt: object):
|
||||
self.last_fetch_with_lock = "FOR UPDATE" in str(stmt)
|
||||
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(next(iter(self._sessions.values()), None))
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
@@ -69,7 +79,10 @@ def service(fake_db: FakeAsyncSession) -> AgentChatService:
|
||||
class TestResumeIdempotency:
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_is_idempotent(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
self,
|
||||
service: AgentChatService,
|
||||
session: AgentChatSession,
|
||||
fake_db: FakeAsyncSession,
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
@@ -78,6 +91,8 @@ class TestResumeIdempotency:
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
first = await service.apply_resume_decision(
|
||||
@@ -93,6 +108,7 @@ class TestResumeIdempotency:
|
||||
|
||||
assert first.applied is True
|
||||
assert second.applied is False
|
||||
assert fake_db.last_fetch_with_lock is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_updates_status_to_approved(
|
||||
@@ -105,6 +121,8 @@ class TestResumeIdempotency:
|
||||
tool_name="srv.delete_file",
|
||||
tool_args={"file_id": "f1"},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
@@ -129,6 +147,8 @@ class TestResumeIdempotency:
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
@@ -140,3 +160,28 @@ class TestResumeIdempotency:
|
||||
assert result.applied is True
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_expired_pending_marks_expired_and_not_applied(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) - timedelta(seconds=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-expired",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-expired",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is False
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput
|
||||
|
||||
|
||||
class TestRunAgentInput:
|
||||
@@ -55,3 +59,69 @@ class TestRunAgentInput:
|
||||
assert model.state == {"key": "value"}
|
||||
assert len(model.messages) == 1
|
||||
assert model.messages[0]["role"] == "user"
|
||||
|
||||
|
||||
class TestAgentSessionSnapshot:
|
||||
def test_state_snapshot_v2_model_accepts_valid_payload(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00Z",
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
model = AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
assert model.version == 2
|
||||
assert model.pending_tool_call is not None
|
||||
assert model.pending_tool_call.interrupt_id == "int-1"
|
||||
assert model.pending_tool_call.updated_at == datetime(
|
||||
2026, 3, 3, 11, 59, tzinfo=timezone.utc
|
||||
)
|
||||
|
||||
def test_state_snapshot_v2_rejects_wrong_version(self):
|
||||
payload = {
|
||||
"version": 1,
|
||||
"pending_tool_call": None,
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
def test_state_snapshot_v2_requires_pending_tool_call_key(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
def test_state_snapshot_v2_rejects_extra_fields(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00Z",
|
||||
"unexpected": True,
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
@@ -35,8 +35,15 @@ class FakeAsyncSession:
|
||||
async def refresh(self, obj: object) -> None:
|
||||
pass
|
||||
|
||||
async def execute(self, stmt: object) -> None:
|
||||
pass
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(next(iter(self._sessions.values()), None))
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
@@ -78,9 +85,14 @@ class TestPendingToolCall:
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot is not None
|
||||
assert snapshot["version"] == 2
|
||||
assert snapshot["run_context"]["thread_id"] == "t1"
|
||||
assert snapshot["run_context"]["run_id"] == "r1"
|
||||
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
|
||||
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
|
||||
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
|
||||
@@ -103,6 +115,8 @@ class TestPendingToolCall:
|
||||
tool_name="srv.delete_file",
|
||||
tool_args={"file_id": "f1"},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
await service.update_pending_tool_call_status(
|
||||
@@ -113,3 +127,42 @@ class TestPendingToolCall:
|
||||
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_legacy_snapshot_is_rejected(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-legacy",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_snapshot_rejects_naive_datetime(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
session.state_snapshot = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-naive",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00",
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-naive",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, sessions: list[AgentChatSession]) -> None:
|
||||
self._sessions = {session.id: session for session in sessions}
|
||||
self.commit_called = False
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
for session in self._sessions.values():
|
||||
return _Result(session)
|
||||
return _Result(None)
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_called = True
|
||||
|
||||
|
||||
def _build_input(run_id: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "t1",
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_resume_rejects_non_owner_session() -> None:
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": datetime.now(timezone.utc).isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
service = AgentChatService(
|
||||
session=FakeAsyncSession([session]), # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_resume_commits_expired_state_before_410() -> None:
|
||||
owner_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2000-01-01T00:00:00+00:00",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
fake_db = FakeAsyncSession([session])
|
||||
service = AgentChatService(
|
||||
session=fake_db, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 410
|
||||
assert fake_db.commit_called is True
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.tool_dispatcher import (
|
||||
BackendExecutionResult,
|
||||
InterruptResult,
|
||||
@@ -46,9 +48,18 @@ class TestToolDispatcher:
|
||||
def test_dispatcher_class_can_dispatch(self):
|
||||
dispatcher = ToolDispatcher()
|
||||
tool = {
|
||||
"name": "ui.show_dialog",
|
||||
"name": "ui.navigate_to",
|
||||
"execution_target": "frontend",
|
||||
"args": {"message": "Hello"},
|
||||
}
|
||||
result = dispatcher.dispatch(tool)
|
||||
assert isinstance(result, InterruptResult)
|
||||
|
||||
def test_unknown_frontend_tool_is_rejected(self):
|
||||
tool = {
|
||||
"name": "ui.unknown_action",
|
||||
"execution_target": "frontend",
|
||||
"args": {},
|
||||
}
|
||||
with pytest.raises(ValueError, match="not in allowlist"):
|
||||
dispatch_tool_call(tool)
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
# Agent Interrupt/Resume 遗留问题修复设计
|
||||
|
||||
## 1. 目标
|
||||
|
||||
本次修复一次性完成以下三项遗留问题:
|
||||
|
||||
1. `state_snapshot` 并发一致性问题(并发 resume 竞争)
|
||||
2. `expires_at` 过期未强校验问题
|
||||
3. `state_snapshot` 缺少强类型与版本化问题
|
||||
|
||||
## 2. 设计决策
|
||||
|
||||
采用方案 2(严格重构):
|
||||
|
||||
- `state_snapshot` 仅接受新结构,不再兼容旧结构
|
||||
- 统一快照版本为 `version = 2`
|
||||
- 使用强类型模型约束快照结构与状态迁移
|
||||
- resume 入口引入行级锁语义,避免并发双写
|
||||
|
||||
## 3. 状态快照模型
|
||||
|
||||
`state_snapshot` 顶层结构:
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": null,
|
||||
"result": null,
|
||||
"updated_at": "2026-03-03T11:59:00Z"
|
||||
},
|
||||
"run_context": {
|
||||
"thread_id": "t1",
|
||||
"run_id": "r1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `version` 必须为 2,否则拒绝处理
|
||||
- `pending_tool_call` 字段缺失或类型错误,按无效快照处理
|
||||
- `run_context` 仅保留 interrupt/resume 必需字段
|
||||
|
||||
## 4. 状态机约束
|
||||
|
||||
仅允许以下迁移:
|
||||
|
||||
- `PENDING_APPROVAL -> APPROVED_EXECUTING -> EXECUTED`
|
||||
- `PENDING_APPROVAL -> REJECTED`
|
||||
- `PENDING_APPROVAL -> EXPIRED`
|
||||
|
||||
非法状态迁移必须返回错误,不做隐式修复。
|
||||
|
||||
## 5. 并发与过期语义
|
||||
|
||||
- resume 前先对目标 session 加锁再读取快照
|
||||
- 同一 `interrupt_id` 并发 resume 只能有一个请求成功
|
||||
- 若 `expires_at < now(UTC)`,先迁移为 `EXPIRED`,再返回 410
|
||||
|
||||
## 6. 错误语义(RFC7807)
|
||||
|
||||
- `409 Conflict`: run/interrupt 不匹配,或并发冲突导致状态已消费
|
||||
- `410 Gone`: 挂起调用已过期
|
||||
- `422 Unprocessable Entity`: `state_snapshot` 非法或版本不匹配
|
||||
- `404 Not Found`: 目标 session/run 不存在
|
||||
|
||||
## 7. 测试策略
|
||||
|
||||
采用 TDD,先写失败测试后实现:
|
||||
|
||||
- 快照版本校验(`version != 2`)
|
||||
- 快照结构校验(必填字段/类型)
|
||||
- 并发 resume 幂等竞争(仅一个成功)
|
||||
- 过期校验(返回 410 + 状态置 EXPIRED)
|
||||
- 合法状态迁移路径覆盖
|
||||
|
||||
## 8. 验证命令
|
||||
|
||||
- `uv run pytest backend/tests/unit/v1/agent -v`
|
||||
- `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py -v`
|
||||
- `uv run pytest backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
|
||||
- `cd backend && uv run ruff check src/v1/agent`
|
||||
- `cd backend && uv run basedpyright src/v1/agent`
|
||||
|
||||
## 9. 风险与回滚
|
||||
|
||||
- 风险:旧快照不再兼容,可能触发运行时拒绝
|
||||
- 处置:通过明确 422 错误暴露不合规数据,结合日志定位并人工修复数据
|
||||
- 回滚:回退本次变更并恢复旧快照解析逻辑(仅在紧急故障时)
|
||||
@@ -0,0 +1,377 @@
|
||||
# Agent Interrupt/Resume Strict Refactor Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 通过严格重构一次性修复 interrupt/resume 的并发安全、过期校验和 state_snapshot 强类型版本化问题。
|
||||
|
||||
**Architecture:** 以 `state_snapshot v2` 为唯一合法结构,服务层使用强类型模型解析与状态迁移,resume 路径在读取会话时加行锁保证并发一致性。路由层维持现有 run/resume 入口,错误通过 HTTPException 输出,测试覆盖版本校验、过期语义、并发幂等和状态机迁移。
|
||||
|
||||
**Tech Stack:** FastAPI, SQLAlchemy AsyncSession, Pydantic v2, pytest
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 新增 state_snapshot v2 强类型模型
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/schemas.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_schemas.py`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
def test_state_snapshot_v2_model_accepts_valid_payload():
|
||||
payload = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00Z",
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
model = AgentSessionSnapshot.model_validate(payload)
|
||||
assert model.version == 2
|
||||
|
||||
|
||||
def test_state_snapshot_v2_rejects_wrong_version():
|
||||
payload = {
|
||||
"version": 1,
|
||||
"pending_tool_call": None,
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_schemas.py -v`
|
||||
Expected: FAIL(`AgentSessionSnapshot` 未定义或校验不符合预期)
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
class PendingToolStatus(str, Enum):
|
||||
PENDING_APPROVAL = "PENDING_APPROVAL"
|
||||
APPROVED_EXECUTING = "APPROVED_EXECUTING"
|
||||
EXECUTED = "EXECUTED"
|
||||
REJECTED = "REJECTED"
|
||||
EXPIRED = "EXPIRED"
|
||||
|
||||
|
||||
class PendingToolCall(BaseModel):
|
||||
interrupt_id: str
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
status: PendingToolStatus
|
||||
expires_at: datetime
|
||||
decision: dict[str, Any] | None = None
|
||||
result: dict[str, Any] | None = None
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SnapshotRunContext(BaseModel):
|
||||
thread_id: str
|
||||
run_id: str
|
||||
|
||||
|
||||
class AgentSessionSnapshot(BaseModel):
|
||||
version: Literal[2]
|
||||
pending_tool_call: PendingToolCall | None = None
|
||||
run_context: SnapshotRunContext
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_schemas.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/schemas.py backend/tests/unit/v1/agent/test_schemas.py
|
||||
git commit -m "refactor(agent): add strict v2 session snapshot schema"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: service 层改为 v2 快照读写(严格拒绝旧结构)
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/service.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_service_pending_tool_call.py`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_pending_tool_call_writes_v2_snapshot(service, session):
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["version"] == 2
|
||||
assert snapshot["run_context"]["run_id"] == "r1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_legacy_snapshot_is_rejected(service, session):
|
||||
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
|
||||
with pytest.raises(ValueError):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_service_pending_tool_call.py -v`
|
||||
Expected: FAIL
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
def _build_snapshot_v2(...):
|
||||
return AgentSessionSnapshot(...).model_dump(mode="json")
|
||||
|
||||
|
||||
def _load_snapshot_v2(raw: dict[str, Any] | None) -> AgentSessionSnapshot:
|
||||
if raw is None:
|
||||
raise ValueError("state_snapshot missing")
|
||||
return AgentSessionSnapshot.model_validate(raw)
|
||||
```
|
||||
|
||||
并将 `set_pending_tool_call/get_state_snapshot/update_pending_tool_call_status` 全部改成 v2 模型读写。
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_service_pending_tool_call.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_service_pending_tool_call.py
|
||||
git commit -m "refactor(agent): enforce v2 snapshot read write in service"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: 增加 resume 行锁与并发幂等
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/service.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_resume_idempotency.py`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_resume_decision_uses_locked_session_fetch(service, fake_db, session):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
assert fake_db.last_fetch_with_lock is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_is_idempotent(service, session):
|
||||
first = await service.apply_resume_decision(...)
|
||||
second = await service.apply_resume_decision(...)
|
||||
assert first.applied is True
|
||||
assert second.applied is False
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
|
||||
Expected: FAIL
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
async def _get_session_for_update(self, session_id: UUID) -> AgentChatSession | None:
|
||||
stmt = (
|
||||
select(AgentChatSession)
|
||||
.where(AgentChatSession.id == session_id)
|
||||
.with_for_update()
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
```
|
||||
|
||||
`apply_resume_decision` 改为锁内读取、校验、状态迁移,保证并发下单次生效。
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_resume_idempotency.py
|
||||
git commit -m "fix(agent): add row lock for resume state transition"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: 增加 expires_at 过期校验(含 EXPIRED 迁移)
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/service.py`
|
||||
- Test: `backend/tests/unit/v1/agent/test_resume_idempotency.py`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_expired_pending_returns_not_applied_and_marks_expired(service, session):
|
||||
await service.set_pending_tool_call(..., expires_at=datetime.now(timezone.utc) - timedelta(seconds=1), thread_id="t1", run_id="r1")
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
assert result.applied is False
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
|
||||
Expected: FAIL
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
```python
|
||||
if pending.expires_at < datetime.now(timezone.utc):
|
||||
pending.status = PendingToolStatus.EXPIRED
|
||||
pending.updated_at = datetime.now(timezone.utc)
|
||||
session.state_snapshot = snapshot.model_dump(mode="json")
|
||||
return ResumeDecisionResult(applied=False, expired=True)
|
||||
```
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_resume_idempotency.py
|
||||
git commit -m "fix(agent): enforce expires_at when applying resume decision"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: 路由层补齐 v2 快照与过期/冲突错误映射
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent/router.py`
|
||||
- Modify: `backend/src/v1/agent/service.py`
|
||||
- Test: `backend/tests/integration/v1/agent/test_chat_routes.py`
|
||||
- Test: `backend/tests/integration/v1/agent/test_interrupt_resume_flow.py`
|
||||
|
||||
**Step 1: Write the failing test**
|
||||
|
||||
```python
|
||||
def test_resume_route_returns_409_on_run_id_mismatch(client):
|
||||
...
|
||||
|
||||
|
||||
def test_resume_route_returns_410_when_pending_expired(client):
|
||||
...
|
||||
|
||||
|
||||
def test_resume_route_returns_422_for_legacy_snapshot(client):
|
||||
...
|
||||
```
|
||||
|
||||
**Step 2: Run test to verify it fails**
|
||||
|
||||
Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
|
||||
Expected: FAIL
|
||||
|
||||
**Step 3: Write minimal implementation**
|
||||
|
||||
在 `stream_resume` 或路由调用链里将领域错误映射为:
|
||||
|
||||
- 过期 -> `HTTPException(410)`
|
||||
- 旧快照/结构错误 -> `HTTPException(422)`
|
||||
- 状态冲突/重复消费 -> `HTTPException(409)`
|
||||
|
||||
**Step 4: Run test to verify it passes**
|
||||
|
||||
Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/src/v1/agent/router.py backend/src/v1/agent/service.py backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py
|
||||
git commit -m "fix(agent): map resume snapshot errors to 409 410 422"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: 更新文档并完成验证
|
||||
|
||||
**Files:**
|
||||
- Modify: `docs/plans/2026-03-03-agent-chat-design.md`
|
||||
- Modify: `docs/runtime/runtime-route.md`
|
||||
|
||||
**Step 1: Update docs**
|
||||
|
||||
- 明确 `state_snapshot version=2` 为唯一支持结构
|
||||
- 明确 resume 过期与并发冲突语义(410/409)
|
||||
- 明确旧快照拒绝策略(422)
|
||||
|
||||
**Step 2: Run unit tests**
|
||||
|
||||
Run: `uv run pytest backend/tests/unit/v1/agent -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 3: Run integration tests**
|
||||
|
||||
Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v`
|
||||
Expected: PASS
|
||||
|
||||
**Step 4: Run static checks**
|
||||
|
||||
Run: `cd backend && uv run ruff check src/v1/agent`
|
||||
Expected: PASS
|
||||
|
||||
Run: `cd backend && uv run basedpyright src/v1/agent`
|
||||
Expected: PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add docs/plans/2026-03-03-agent-chat-design.md docs/runtime/runtime-route.md
|
||||
git commit -m "docs(agent): document strict snapshot v2 and resume error semantics"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Plan complete and saved to `docs/plans/2026-03-03-interrupt-resume-fixes-implementation-plan.md`.
|
||||
|
||||
Execution mode selected by user request: Subagent-Driven (this session), proceed task-by-task immediately.
|
||||
@@ -788,42 +788,68 @@
|
||||
|
||||
## Agent
|
||||
|
||||
### POST /agent
|
||||
### POST /agent/runs
|
||||
|
||||
运行 Agent 对话(需要认证)。
|
||||
创建 Agent 运行(需要认证,SSE 响应)。
|
||||
|
||||
**Request:**
|
||||
**Request (RunAgentInput):**
|
||||
```json
|
||||
{
|
||||
"message": "string (1-8000 chars)",
|
||||
"session_id": "string? (UUID)"
|
||||
"threadId": "string",
|
||||
"runId": "string",
|
||||
"parentRunId": "string?",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": null
|
||||
}
|
||||
```
|
||||
|
||||
**Response:** 200 OK
|
||||
```json
|
||||
{
|
||||
"session_id": "string (UUID)",
|
||||
"output": "string",
|
||||
"events": [
|
||||
{
|
||||
"type": "string",
|
||||
"run_id": "string?",
|
||||
"message_id": "string?",
|
||||
"delta": "string?",
|
||||
"tool_name": "string?",
|
||||
"result": "string?",
|
||||
"output": "string?",
|
||||
"error": "string?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
**Response:** 200 OK (`text/event-stream`)
|
||||
|
||||
**Errors:**
|
||||
- 401: 未认证
|
||||
- 422: 请求参数无效
|
||||
|
||||
### POST /agent/runs/{run_id}/resume
|
||||
|
||||
恢复被中断运行(需要认证,SSE 响应)。
|
||||
|
||||
**Request (RunAgentInput):**
|
||||
```json
|
||||
{
|
||||
"threadId": "string",
|
||||
"runId": "string",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {
|
||||
"interruptId": "string",
|
||||
"payload": {}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**State Snapshot Contract:**
|
||||
- `state_snapshot` 仅支持 `version = 2`
|
||||
- 顶层必须包含 `run_context` 与 `pending_tool_call`
|
||||
- 旧格式或缺失字段会被拒绝
|
||||
|
||||
**Resume Semantics:**
|
||||
- 同一 `interrupt_id` 并发恢复仅允许一个请求成功
|
||||
- `expires_at` 超时后会标记为 `EXPIRED`,恢复请求不再生效
|
||||
|
||||
**Errors:**
|
||||
- 401: 未认证
|
||||
- 404: 会话不存在
|
||||
- 409: `run_id` 或 `interrupt_id` 冲突,或状态已被消费
|
||||
- 410: 挂起调用已过期
|
||||
- 422: `state_snapshot` 非法或版本不匹配
|
||||
|
||||
---
|
||||
|
||||
## Infra
|
||||
|
||||
Reference in New Issue
Block a user