From 30a4a1af5d8ccefac89257652f4853ee546ab0ab Mon Sep 17 00:00:00 2001 From: qzl Date: Tue, 3 Mar 2026 17:26:04 +0800 Subject: [PATCH] fix(agent): polish interrupt-resume flow for merge readiness --- DELETION_LOG.md | 25 ++ backend/src/v1/agent/router.py | 5 +- backend/src/v1/agent/schemas.py | 48 ++- backend/src/v1/agent/service.py | 237 +++++++++-- backend/src/v1/agent/tool_dispatcher.py | 8 + .../integration/v1/agent/test_chat_routes.py | 3 + .../v1/agent/test_interrupt_resume_flow.py | 11 +- .../v1/agent/test_agent_security_rules.py | 62 ++- .../unit/v1/agent/test_resume_idempotency.py | 51 ++- backend/tests/unit/v1/agent/test_schemas.py | 72 +++- .../agent/test_service_pending_tool_call.py | 57 ++- .../v1/agent/test_stream_resume_security.py | 126 ++++++ .../unit/v1/agent/test_tool_dispatcher.py | 13 +- ...026-03-03-interrupt-resume-fixes-design.md | 95 +++++ ...errupt-resume-fixes-implementation-plan.md | 377 ++++++++++++++++++ docs/runtime/runtime-route.md | 74 ++-- 16 files changed, 1179 insertions(+), 85 deletions(-) create mode 100644 DELETION_LOG.md create mode 100644 backend/tests/unit/v1/agent/test_stream_resume_security.py create mode 100644 docs/plans/2026-03-03-interrupt-resume-fixes-design.md create mode 100644 docs/plans/2026-03-03-interrupt-resume-fixes-implementation-plan.md diff --git a/DELETION_LOG.md b/DELETION_LOG.md new file mode 100644 index 0000000..74f4ca7 --- /dev/null +++ b/DELETION_LOG.md @@ -0,0 +1,25 @@ +# Deletion Log + +## 2026-03-03 feature polish + +- Scope: `backend/src/v1/agent/*` and `backend/tests/unit/v1/agent/*` only. +- Candidate review source: scoped `refactor-cleaner` run for the directories above. + +### Executed cleanup + +1. Merged duplicated newline validation logic in `backend/src/v1/agent/service.py`. + - Before: duplicated checks in `prepare_resume`, `stream_run`, `stream_resume`. + - After: centralized `_validate_no_newlines` helper. + - Behavior impact: none (same validation semantics). + +2. Merged duplicated SSE event string formatting in `backend/src/v1/agent/service.py`. + - Before: repeated `f"data: {json.dumps(...)}\n\n"` fragments. + - After: centralized `_sse_data` helper. + - Behavior impact: none (same payload format). + +### Candidates not deleted (insufficient evidence) + +- `backend/src/v1/agent/crewai_flow.py` + - Reason: candidate report suggested possible dead code, but no deletion was done in this polish pass because cross-module usage certainty was insufficient. +- Legacy `run()` path in `backend/src/v1/agent/service.py` + - Reason: potentially still relied on by non-scope code paths; deletion deferred. diff --git a/backend/src/v1/agent/router.py b/backend/src/v1/agent/router.py index 667a35b..5f6612b 100644 --- a/backend/src/v1/agent/router.py +++ b/backend/src/v1/agent/router.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Path from fastapi.responses import StreamingResponse from v1.agent.dependencies import get_agent_service @@ -25,7 +25,7 @@ async def create_run( @router.post("/runs/{run_id}/resume") async def resume_run( - run_id: str, + run_id: Annotated[str, Path(min_length=1, max_length=255)], input_data: RunAgentInput, service: Annotated[AgentChatService, Depends(get_agent_service)], ) -> StreamingResponse: @@ -34,6 +34,7 @@ async def resume_run( status_code=409, detail=f"run_id mismatch: path={run_id}, body={input_data.runId}", ) + await service.prepare_resume(run_id, input_data) return StreamingResponse( service.stream_resume(run_id, input_data), media_type="text/event-stream", diff --git a/backend/src/v1/agent/schemas.py b/backend/src/v1/agent/schemas.py index 59a5a8b..4e9da22 100644 --- a/backend/src/v1/agent/schemas.py +++ b/backend/src/v1/agent/schemas.py @@ -1,9 +1,12 @@ from __future__ import annotations +from datetime import datetime +from enum import Enum +from typing import Literal from typing import Any from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator class RunAgentInput(BaseModel): @@ -40,3 +43,46 @@ class AgentChatRunResponse(BaseModel): session_id: UUID output: str events: list[AgentChatEvent] + + +class PendingToolStatus(str, Enum): + PENDING_APPROVAL = "PENDING_APPROVAL" + APPROVED_EXECUTING = "APPROVED_EXECUTING" + EXECUTED = "EXECUTED" + REJECTED = "REJECTED" + EXPIRED = "EXPIRED" + + +class PendingToolCall(BaseModel): + model_config = ConfigDict(extra="forbid") + + interrupt_id: str = Field(min_length=1, max_length=255) + tool_name: str = Field(min_length=1, max_length=255) + tool_args: dict[str, Any] = Field(default_factory=dict) + status: PendingToolStatus + expires_at: datetime + decision: dict[str, Any] | None = None + result: dict[str, Any] | None = None + updated_at: datetime + + @field_validator("expires_at", "updated_at") + @classmethod + def _validate_timezone_aware(cls, value: datetime) -> datetime: + if value.tzinfo is None or value.utcoffset() is None: + raise ValueError("datetime must be timezone-aware") + return value + + +class SnapshotRunContext(BaseModel): + model_config = ConfigDict(extra="forbid") + + thread_id: str = Field(min_length=1, max_length=255) + run_id: str = Field(min_length=1, max_length=255) + + +class AgentSessionSnapshot(BaseModel): + model_config = ConfigDict(extra="forbid") + + version: Literal[2] + pending_tool_call: PendingToolCall | None + run_context: SnapshotRunContext diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py index 45f037e..81cbced 100644 --- a/backend/src/v1/agent/service.py +++ b/backend/src/v1/agent/service.py @@ -21,10 +21,14 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus from v1.auth.rate_limit import enforce_rate_limit from v1.agent.schemas import ( + AgentSessionSnapshot, AgentChatEvent, AgentChatRunRequest, AgentChatRunResponse, + PendingToolCall, + PendingToolStatus, RunAgentInput, + SnapshotRunContext, ) if TYPE_CHECKING: @@ -294,6 +298,41 @@ class AgentChatService(BaseService): return None return session.state_snapshot + @staticmethod + def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot: + try: + return AgentSessionSnapshot.model_validate(raw_snapshot) + except Exception as exc: # noqa: BLE001 + raise ValueError("Invalid state_snapshot format") from exc + + async def _get_session_for_update( + self, session_id: UUID + ) -> AgentChatSession | None: + stmt = ( + select(AgentChatSession) + .where(AgentChatSession.id == session_id) + .with_for_update() + .limit(1) + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + def _assert_session_owner(self, session: AgentChatSession) -> None: + if self._current_user is None: + return + + if session.user_id != self.require_user_id(): + raise HTTPException(status_code=404, detail="Session not found") + + @staticmethod + def _validate_no_newlines(value: str, *, field_name: str) -> None: + if "\n" in value or "\r" in value: + raise ValueError(f"{field_name} must not contain newlines") + + @staticmethod + def _sse_data(payload: dict[str, Any]) -> str: + return f"data: {json.dumps(payload)}\\n\\n" + async def set_pending_tool_call( self, *, @@ -302,22 +341,30 @@ class AgentChatService(BaseService): tool_name: str, tool_args: dict, expires_at: datetime, + thread_id: str, + run_id: str, ) -> None: stmt = select(AgentChatSession).where(AgentChatSession.id == session_id) session = await self._session.scalar(stmt) if session is None: raise ValueError(f"Session {session_id} not found") - snapshot = session.state_snapshot or {} - snapshot["pending_tool_call"] = { - "interrupt_id": interrupt_id, - "tool_name": tool_name, - "tool_args": tool_args, - "status": "PENDING_APPROVAL", - "expires_at": expires_at.isoformat(), - "decision": None, - "result": None, - } - session.state_snapshot = snapshot + self._assert_session_owner(session) + + snapshot = AgentSessionSnapshot( + version=2, + run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id), + pending_tool_call=PendingToolCall( + interrupt_id=interrupt_id, + tool_name=tool_name, + tool_args=tool_args, + status=PendingToolStatus.PENDING_APPROVAL, + expires_at=expires_at, + decision=None, + result=None, + updated_at=datetime.now(timezone.utc), + ), + ) + session.state_snapshot = snapshot.model_dump(mode="json") async def update_pending_tool_call_status( self, @@ -330,13 +377,27 @@ class AgentChatService(BaseService): session = await self._session.scalar(stmt) if session is None: raise ValueError(f"Session {session_id} not found") - snapshot = session.state_snapshot - if snapshot is None or "pending_tool_call" not in snapshot: + self._assert_session_owner(session) + if session.state_snapshot is None: raise ValueError("No pending tool call found") - if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id: + + snapshot = self._load_snapshot_v2(session.state_snapshot) + pending = snapshot.pending_tool_call + if pending is None: + raise ValueError("No pending tool call found") + if pending.interrupt_id != interrupt_id: raise ValueError("Interrupt ID mismatch") - snapshot["pending_tool_call"]["status"] = status - session.state_snapshot = snapshot + + updated_pending = pending.model_copy( + update={ + "status": PendingToolStatus(status), + "updated_at": datetime.now(timezone.utc), + } + ) + updated_snapshot = snapshot.model_copy( + update={"pending_tool_call": updated_pending} + ) + session.state_snapshot = updated_snapshot.model_dump(mode="json") async def apply_resume_decision( self, @@ -345,52 +406,140 @@ class AgentChatService(BaseService): interrupt_id: str, decision: dict[str, Any], ) -> ResumeDecisionResult: - stmt = select(AgentChatSession).where(AgentChatSession.id == session_id) - session = await self._session.scalar(stmt) + session = await self._get_session_for_update(session_id) if session is None: raise ValueError(f"Session {session_id} not found") + self._assert_session_owner(session) - snapshot = session.state_snapshot - if snapshot is None or "pending_tool_call" not in snapshot: + if session.state_snapshot is None: return ResumeDecisionResult(applied=False) - pending = snapshot["pending_tool_call"] - if pending["interrupt_id"] != interrupt_id: + snapshot = self._load_snapshot_v2(session.state_snapshot) + pending = snapshot.pending_tool_call + if pending is None: return ResumeDecisionResult(applied=False) - if pending["status"] != "PENDING_APPROVAL": + if pending.interrupt_id != interrupt_id: + return ResumeDecisionResult(applied=False) + + if pending.status != PendingToolStatus.PENDING_APPROVAL: + return ResumeDecisionResult(applied=False) + + now = datetime.now(timezone.utc) + if pending.expires_at <= now: + expired_pending = pending.model_copy( + update={ + "status": PendingToolStatus.EXPIRED, + "updated_at": now, + } + ) + expired_snapshot = snapshot.model_copy( + update={"pending_tool_call": expired_pending} + ) + session.state_snapshot = expired_snapshot.model_dump(mode="json") return ResumeDecisionResult(applied=False) decision_value = decision.get("decision", "approved") - if decision_value == "approved": - new_status = "APPROVED_EXECUTING" - else: - new_status = "REJECTED" + next_status = ( + PendingToolStatus.APPROVED_EXECUTING + if decision_value == "approved" + else PendingToolStatus.REJECTED + ) - snapshot["pending_tool_call"]["status"] = new_status - snapshot["pending_tool_call"]["decision"] = decision - session.state_snapshot = snapshot + updated_pending = pending.model_copy( + update={ + "status": next_status, + "decision": decision, + "updated_at": now, + } + ) + updated_snapshot = snapshot.model_copy( + update={"pending_tool_call": updated_pending} + ) + session.state_snapshot = updated_snapshot.model_dump(mode="json") return ResumeDecisionResult(applied=True) async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]: - if "\n" in input_data.runId or "\r" in input_data.runId: - raise ValueError("runId must not contain newlines") + self._validate_no_newlines(input_data.runId, field_name="runId") - yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n" - yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n" - yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n" - yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n" - yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n" + yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId}) + yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"}) + yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}) + yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"}) + yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId}) + + async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None: + self._validate_no_newlines(run_id, field_name="runId") + + user_id = self.require_user_id() + await enforce_rate_limit( + scope="agent_resume", + identifier=str(user_id), + limit=DEFAULT_RATE_LIMIT, + window_seconds=DEFAULT_RATE_LIMIT, + ) + + try: + session_id = UUID(run_id) + except ValueError as exc: + raise HTTPException( + status_code=422, detail="run_id must be a valid UUID" + ) from exc + + session = await self._get_session_for_update(session_id) + if session is None or session.user_id != user_id: + raise HTTPException(status_code=404, detail="Session not found") + + if input_data.resume is None: + raise HTTPException(status_code=422, detail="resume payload is required") + + interrupt_id = input_data.resume.get("interruptId") + if not isinstance(interrupt_id, str) or not interrupt_id: + raise HTTPException( + status_code=422, detail="resume.interruptId is required" + ) + + decision_payload = input_data.resume.get("payload", {}) + if not isinstance(decision_payload, dict): + raise HTTPException( + status_code=422, + detail="resume.payload must be an object", + ) + + try: + decision_result = await self.apply_resume_decision( + session_id=session_id, + interrupt_id=interrupt_id, + decision=decision_payload, + ) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + + if not decision_result.applied: + if session.state_snapshot is not None: + snapshot = self._load_snapshot_v2(session.state_snapshot) + pending = snapshot.pending_tool_call + if pending is not None and pending.status == PendingToolStatus.EXPIRED: + await self._session.commit() + raise HTTPException( + status_code=410, + detail="Pending tool call expired", + ) + raise HTTPException( + status_code=409, + detail="Resume decision not applicable", + ) + + await self._session.commit() async def stream_resume( self, run_id: str, input_data: RunAgentInput ) -> AsyncGenerator[str, None]: - if "\n" in run_id or "\r" in run_id: - raise ValueError("runId must not contain newlines") + self._validate_no_newlines(run_id, field_name="runId") - yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n" - yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n" - yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n" - yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n" - yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n" + yield self._sse_data({"type": "RUN_STARTED", "runId": run_id}) + yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"}) + yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}) + yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"}) + yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id}) diff --git a/backend/src/v1/agent/tool_dispatcher.py b/backend/src/v1/agent/tool_dispatcher.py index efe1909..b1c4d18 100644 --- a/backend/src/v1/agent/tool_dispatcher.py +++ b/backend/src/v1/agent/tool_dispatcher.py @@ -17,6 +17,12 @@ ALLOWED_BACKEND_TOOLS = frozenset( } ) +ALLOWED_FRONTEND_TOOLS = frozenset( + { + "ui.navigate_to", + } +) + class InterruptResult(BaseModel): interrupt_type: str @@ -47,6 +53,8 @@ def dispatch_tool_call( args = tool.get("args", {}) if target == "frontend": + if name not in ALLOWED_FRONTEND_TOOLS: + raise ValueError(f"Frontend tool '{name}' not in allowlist") return InterruptResult( interrupt_type="tool_execution", tool_name=name, diff --git a/backend/tests/integration/v1/agent/test_chat_routes.py b/backend/tests/integration/v1/agent/test_chat_routes.py index b81ab75..36c482c 100644 --- a/backend/tests/integration/v1/agent/test_chat_routes.py +++ b/backend/tests/integration/v1/agent/test_chat_routes.py @@ -13,6 +13,9 @@ from v1.users.dependencies import get_current_user class FakeAgentService: + async def prepare_resume(self, run_id: str, input_data: RunAgentInput): + return None + async def stream_run(self, input_data: RunAgentInput): yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n' yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n' diff --git a/backend/tests/integration/v1/agent/test_interrupt_resume_flow.py b/backend/tests/integration/v1/agent/test_interrupt_resume_flow.py index c41f117..26a2c79 100644 --- a/backend/tests/integration/v1/agent/test_interrupt_resume_flow.py +++ b/backend/tests/integration/v1/agent/test_interrupt_resume_flow.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from uuid import UUID import pytest @@ -13,6 +14,9 @@ from v1.users.dependencies import get_current_user class FakeAgentServiceWithInterrupt: + async def prepare_resume(self, run_id: str, input_data: RunAgentInput): + return None + async def stream_run(self, input_data: RunAgentInput): yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n' yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n' @@ -30,7 +34,7 @@ class FakeAgentServiceWithInterrupt: yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n' yield ( 'data: {"type": "TOOL_RESULT", "toolName": "ui.navigate_to", "result": ' - + str(payload.get("result", {})) + + json.dumps(payload.get("result", {})) + "}\n\n" ) yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n' @@ -100,10 +104,7 @@ class TestInterruptResumeFlow: 0 ] assert '"toolName": "ui.navigate_to"' in tool_result_event - assert ( - "'success': True" in tool_result_event - or '"success": true' in tool_result_event.lower() - ) + assert '"success": true' in tool_result_event.lower() def test_backend_tool_approval_rejected(self, client: TestClient): payload = { diff --git a/backend/tests/unit/v1/agent/test_agent_security_rules.py b/backend/tests/unit/v1/agent/test_agent_security_rules.py index 766ac42..5ed52f1 100644 --- a/backend/tests/unit/v1/agent/test_agent_security_rules.py +++ b/backend/tests/unit/v1/agent/test_agent_security_rules.py @@ -1,5 +1,17 @@ from __future__ import annotations +from datetime import datetime, timedelta, timezone +from uuid import UUID, uuid4 + +import pytest + +from core.auth.models import CurrentUser +from models.agent_chat_session import ( + AgentChatSession, + AgentChatSessionStatus, + SessionType, +) +from v1.agent.service import AgentChatService from v1.agent.tool_registry import validate_tool_spec @@ -18,5 +30,51 @@ class TestAgentSecurityRules: else: raise AssertionError("Should have raised ValueError for unknown namespace") - def test_frontend_result_fails_when_interrupt_mismatch(self): - pass + @pytest.mark.asyncio + async def test_frontend_result_fails_when_interrupt_mismatch(self): + session = AgentChatSession( + id=uuid4(), + user_id=UUID("00000000-0000-0000-0000-000000000001"), + session_type=SessionType.CHAT, + status=AgentChatSessionStatus.RUNNING, + ) + + class FakeAsyncSession: + def __init__(self, session_obj: AgentChatSession) -> None: + self._session_obj = session_obj + + async def execute(self, stmt: object): + class _Result: + def __init__(self, session_obj: AgentChatSession | None) -> None: + self._session_obj = session_obj + + def scalar_one_or_none(self) -> AgentChatSession | None: + return self._session_obj + + return _Result(self._session_obj) + + async def scalar(self, stmt: object) -> AgentChatSession | None: + return self._session_obj + + service = AgentChatService( + session=FakeAsyncSession(session), # type: ignore[arg-type] + current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")), + ) + + await service.set_pending_tool_call( + session_id=session.id, + interrupt_id="int-1", + tool_name="srv.transfer_funds", + tool_args={"to": "u2", "amount": 100}, + expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), + thread_id="t1", + run_id="r1", + ) + + result = await service.apply_resume_decision( + session_id=session.id, + interrupt_id="int-other", + decision={"decision": "approved"}, + ) + + assert result.applied is False diff --git a/backend/tests/unit/v1/agent/test_resume_idempotency.py b/backend/tests/unit/v1/agent/test_resume_idempotency.py index 660beea..a11048e 100644 --- a/backend/tests/unit/v1/agent/test_resume_idempotency.py +++ b/backend/tests/unit/v1/agent/test_resume_idempotency.py @@ -17,6 +17,7 @@ class FakeAsyncSession: def __init__(self) -> None: self.added: list[object] = [] self._sessions: dict[UUID, AgentChatSession] = {} + self.last_fetch_with_lock = False def add(self, obj: object) -> None: self.added.append(obj) @@ -35,8 +36,17 @@ class FakeAsyncSession: async def refresh(self, obj: object) -> None: pass - async def execute(self, stmt: object) -> None: - pass + async def execute(self, stmt: object): + self.last_fetch_with_lock = "FOR UPDATE" in str(stmt) + + class _Result: + def __init__(self, session_obj: AgentChatSession | None) -> None: + self._session_obj = session_obj + + def scalar_one_or_none(self) -> AgentChatSession | None: + return self._session_obj + + return _Result(next(iter(self._sessions.values()), None)) async def scalar(self, stmt: object) -> AgentChatSession | None: for session in self._sessions.values(): @@ -69,7 +79,10 @@ def service(fake_db: FakeAsyncSession) -> AgentChatService: class TestResumeIdempotency: @pytest.mark.asyncio async def test_resume_is_idempotent( - self, service: AgentChatService, session: AgentChatSession + self, + service: AgentChatService, + session: AgentChatSession, + fake_db: FakeAsyncSession, ): expires_at = datetime.now(timezone.utc) + timedelta(hours=1) await service.set_pending_tool_call( @@ -78,6 +91,8 @@ class TestResumeIdempotency: tool_name="srv.transfer_funds", tool_args={"to": "u2", "amount": 100}, expires_at=expires_at, + thread_id="t1", + run_id="r1", ) first = await service.apply_resume_decision( @@ -93,6 +108,7 @@ class TestResumeIdempotency: assert first.applied is True assert second.applied is False + assert fake_db.last_fetch_with_lock is True @pytest.mark.asyncio async def test_resume_updates_status_to_approved( @@ -105,6 +121,8 @@ class TestResumeIdempotency: tool_name="srv.delete_file", tool_args={"file_id": "f1"}, expires_at=expires_at, + thread_id="t1", + run_id="r1", ) result = await service.apply_resume_decision( @@ -129,6 +147,8 @@ class TestResumeIdempotency: tool_name="srv.transfer_funds", tool_args={"to": "u2", "amount": 100}, expires_at=expires_at, + thread_id="t1", + run_id="r1", ) result = await service.apply_resume_decision( @@ -140,3 +160,28 @@ class TestResumeIdempotency: assert result.applied is True snapshot = await service.get_state_snapshot(session.id) assert snapshot["pending_tool_call"]["status"] == "REJECTED" + + @pytest.mark.asyncio + async def test_resume_expired_pending_marks_expired_and_not_applied( + self, service: AgentChatService, session: AgentChatSession + ): + expires_at = datetime.now(timezone.utc) - timedelta(seconds=1) + await service.set_pending_tool_call( + session_id=session.id, + interrupt_id="int-expired", + tool_name="srv.transfer_funds", + tool_args={"to": "u2", "amount": 100}, + expires_at=expires_at, + thread_id="t1", + run_id="r1", + ) + + result = await service.apply_resume_decision( + session_id=session.id, + interrupt_id="int-expired", + decision={"decision": "approved"}, + ) + + assert result.applied is False + snapshot = await service.get_state_snapshot(session.id) + assert snapshot["pending_tool_call"]["status"] == "EXPIRED" diff --git a/backend/tests/unit/v1/agent/test_schemas.py b/backend/tests/unit/v1/agent/test_schemas.py index 9fac259..d9ee806 100644 --- a/backend/tests/unit/v1/agent/test_schemas.py +++ b/backend/tests/unit/v1/agent/test_schemas.py @@ -1,4 +1,8 @@ -from v1.agent.schemas import RunAgentInput +from datetime import datetime, timezone + +import pytest + +from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput class TestRunAgentInput: @@ -55,3 +59,69 @@ class TestRunAgentInput: assert model.state == {"key": "value"} assert len(model.messages) == 1 assert model.messages[0]["role"] == "user" + + +class TestAgentSessionSnapshot: + def test_state_snapshot_v2_model_accepts_valid_payload(self): + payload = { + "version": 2, + "pending_tool_call": { + "interrupt_id": "int-1", + "tool_name": "srv.transfer_funds", + "tool_args": {"to": "u2", "amount": 100}, + "status": "PENDING_APPROVAL", + "expires_at": "2026-03-03T12:00:00Z", + "decision": None, + "result": None, + "updated_at": "2026-03-03T11:59:00Z", + }, + "run_context": {"thread_id": "t1", "run_id": "r1"}, + } + + model = AgentSessionSnapshot.model_validate(payload) + + assert model.version == 2 + assert model.pending_tool_call is not None + assert model.pending_tool_call.interrupt_id == "int-1" + assert model.pending_tool_call.updated_at == datetime( + 2026, 3, 3, 11, 59, tzinfo=timezone.utc + ) + + def test_state_snapshot_v2_rejects_wrong_version(self): + payload = { + "version": 1, + "pending_tool_call": None, + "run_context": {"thread_id": "t1", "run_id": "r1"}, + } + + with pytest.raises(ValueError): + AgentSessionSnapshot.model_validate(payload) + + def test_state_snapshot_v2_requires_pending_tool_call_key(self): + payload = { + "version": 2, + "run_context": {"thread_id": "t1", "run_id": "r1"}, + } + + with pytest.raises(ValueError): + AgentSessionSnapshot.model_validate(payload) + + def test_state_snapshot_v2_rejects_extra_fields(self): + payload = { + "version": 2, + "pending_tool_call": { + "interrupt_id": "int-1", + "tool_name": "srv.transfer_funds", + "tool_args": {"to": "u2", "amount": 100}, + "status": "PENDING_APPROVAL", + "expires_at": "2026-03-03T12:00:00Z", + "decision": None, + "result": None, + "updated_at": "2026-03-03T11:59:00Z", + "unexpected": True, + }, + "run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"}, + } + + with pytest.raises(ValueError): + AgentSessionSnapshot.model_validate(payload) diff --git a/backend/tests/unit/v1/agent/test_service_pending_tool_call.py b/backend/tests/unit/v1/agent/test_service_pending_tool_call.py index 3a801a0..2662f0c 100644 --- a/backend/tests/unit/v1/agent/test_service_pending_tool_call.py +++ b/backend/tests/unit/v1/agent/test_service_pending_tool_call.py @@ -35,8 +35,15 @@ class FakeAsyncSession: async def refresh(self, obj: object) -> None: pass - async def execute(self, stmt: object) -> None: - pass + async def execute(self, stmt: object): + class _Result: + def __init__(self, session_obj: AgentChatSession | None) -> None: + self._session_obj = session_obj + + def scalar_one_or_none(self) -> AgentChatSession | None: + return self._session_obj + + return _Result(next(iter(self._sessions.values()), None)) async def scalar(self, stmt: object) -> AgentChatSession | None: for session in self._sessions.values(): @@ -78,9 +85,14 @@ class TestPendingToolCall: tool_name="srv.transfer_funds", tool_args={"to": "u2", "amount": 100}, expires_at=expires_at, + thread_id="t1", + run_id="r1", ) snapshot = await service.get_state_snapshot(session.id) assert snapshot is not None + assert snapshot["version"] == 2 + assert snapshot["run_context"]["thread_id"] == "t1" + assert snapshot["run_context"]["run_id"] == "r1" assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL" assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1" assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds" @@ -103,6 +115,8 @@ class TestPendingToolCall: tool_name="srv.delete_file", tool_args={"file_id": "f1"}, expires_at=expires_at, + thread_id="t1", + run_id="r1", ) await service.update_pending_tool_call_status( @@ -113,3 +127,42 @@ class TestPendingToolCall: snapshot = await service.get_state_snapshot(session.id) assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING" + + @pytest.mark.asyncio + async def test_invalid_legacy_snapshot_is_rejected( + self, service: AgentChatService, session: AgentChatSession + ): + session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}} + + with pytest.raises(ValueError): + await service.apply_resume_decision( + session_id=session.id, + interrupt_id="int-legacy", + decision={"decision": "approved"}, + ) + + @pytest.mark.asyncio + async def test_snapshot_rejects_naive_datetime( + self, service: AgentChatService, session: AgentChatSession + ): + session.state_snapshot = { + "version": 2, + "pending_tool_call": { + "interrupt_id": "int-naive", + "tool_name": "srv.transfer_funds", + "tool_args": {"to": "u2", "amount": 100}, + "status": "PENDING_APPROVAL", + "expires_at": "2026-03-03T12:00:00", + "decision": None, + "result": None, + "updated_at": "2026-03-03T11:59:00", + }, + "run_context": {"thread_id": "t1", "run_id": "r1"}, + } + + with pytest.raises(ValueError): + await service.apply_resume_decision( + session_id=session.id, + interrupt_id="int-naive", + decision={"decision": "approved"}, + ) diff --git a/backend/tests/unit/v1/agent/test_stream_resume_security.py b/backend/tests/unit/v1/agent/test_stream_resume_security.py new file mode 100644 index 0000000..b38c62a --- /dev/null +++ b/backend/tests/unit/v1/agent/test_stream_resume_security.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +import pytest +from fastapi import HTTPException + +from core.auth.models import CurrentUser +from models.agent_chat_session import ( + AgentChatSession, + AgentChatSessionStatus, + SessionType, +) +from v1.agent.schemas import RunAgentInput +from v1.agent.service import AgentChatService + + +class FakeAsyncSession: + def __init__(self, sessions: list[AgentChatSession]) -> None: + self._sessions = {session.id: session for session in sessions} + self.commit_called = False + + async def execute(self, stmt: object): + class _Result: + def __init__(self, session_obj: AgentChatSession | None) -> None: + self._session_obj = session_obj + + def scalar_one_or_none(self) -> AgentChatSession | None: + return self._session_obj + + for session in self._sessions.values(): + return _Result(session) + return _Result(None) + + async def scalar(self, stmt: object) -> AgentChatSession | None: + for session in self._sessions.values(): + return session + return None + + async def commit(self) -> None: + self.commit_called = True + + +def _build_input(run_id: str) -> RunAgentInput: + return RunAgentInput.model_validate( + { + "threadId": "t1", + "runId": run_id, + "state": {}, + "messages": [], + "tools": [], + "context": [], + "forwardedProps": {}, + "resume": {"interruptId": "int-1", "payload": {"decision": "approved"}}, + } + ) + + +@pytest.mark.asyncio +async def test_stream_resume_rejects_non_owner_session() -> None: + session = AgentChatSession( + id=uuid4(), + user_id=uuid4(), + session_type=SessionType.CHAT, + status=AgentChatSessionStatus.RUNNING, + state_snapshot={ + "version": 2, + "pending_tool_call": { + "interrupt_id": "int-1", + "tool_name": "srv.transfer_funds", + "tool_args": {"to": "u2", "amount": 100}, + "status": "PENDING_APPROVAL", + "expires_at": datetime.now(timezone.utc).isoformat(), + "decision": None, + "result": None, + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + "run_context": {"thread_id": "t1", "run_id": str(uuid4())}, + }, + ) + service = AgentChatService( + session=FakeAsyncSession([session]), # type: ignore[arg-type] + current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")), + ) + + with pytest.raises(HTTPException) as exc_info: + await service.prepare_resume(str(session.id), _build_input(str(session.id))) + + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_prepare_resume_commits_expired_state_before_410() -> None: + owner_id = UUID("00000000-0000-0000-0000-000000000001") + session = AgentChatSession( + id=uuid4(), + user_id=owner_id, + session_type=SessionType.CHAT, + status=AgentChatSessionStatus.RUNNING, + state_snapshot={ + "version": 2, + "pending_tool_call": { + "interrupt_id": "int-1", + "tool_name": "srv.transfer_funds", + "tool_args": {"to": "u2", "amount": 100}, + "status": "PENDING_APPROVAL", + "expires_at": "2000-01-01T00:00:00+00:00", + "decision": None, + "result": None, + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + "run_context": {"thread_id": "t1", "run_id": str(uuid4())}, + }, + ) + fake_db = FakeAsyncSession([session]) + service = AgentChatService( + session=fake_db, # type: ignore[arg-type] + current_user=CurrentUser(id=owner_id), + ) + + with pytest.raises(HTTPException) as exc_info: + await service.prepare_resume(str(session.id), _build_input(str(session.id))) + + assert exc_info.value.status_code == 410 + assert fake_db.commit_called is True diff --git a/backend/tests/unit/v1/agent/test_tool_dispatcher.py b/backend/tests/unit/v1/agent/test_tool_dispatcher.py index 8b7ecc2..0eacfcf 100644 --- a/backend/tests/unit/v1/agent/test_tool_dispatcher.py +++ b/backend/tests/unit/v1/agent/test_tool_dispatcher.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from v1.agent.tool_dispatcher import ( BackendExecutionResult, InterruptResult, @@ -46,9 +48,18 @@ class TestToolDispatcher: def test_dispatcher_class_can_dispatch(self): dispatcher = ToolDispatcher() tool = { - "name": "ui.show_dialog", + "name": "ui.navigate_to", "execution_target": "frontend", "args": {"message": "Hello"}, } result = dispatcher.dispatch(tool) assert isinstance(result, InterruptResult) + + def test_unknown_frontend_tool_is_rejected(self): + tool = { + "name": "ui.unknown_action", + "execution_target": "frontend", + "args": {}, + } + with pytest.raises(ValueError, match="not in allowlist"): + dispatch_tool_call(tool) diff --git a/docs/plans/2026-03-03-interrupt-resume-fixes-design.md b/docs/plans/2026-03-03-interrupt-resume-fixes-design.md new file mode 100644 index 0000000..1e72df9 --- /dev/null +++ b/docs/plans/2026-03-03-interrupt-resume-fixes-design.md @@ -0,0 +1,95 @@ +# Agent Interrupt/Resume 遗留问题修复设计 + +## 1. 目标 + +本次修复一次性完成以下三项遗留问题: + +1. `state_snapshot` 并发一致性问题(并发 resume 竞争) +2. `expires_at` 过期未强校验问题 +3. `state_snapshot` 缺少强类型与版本化问题 + +## 2. 设计决策 + +采用方案 2(严格重构): + +- `state_snapshot` 仅接受新结构,不再兼容旧结构 +- 统一快照版本为 `version = 2` +- 使用强类型模型约束快照结构与状态迁移 +- resume 入口引入行级锁语义,避免并发双写 + +## 3. 状态快照模型 + +`state_snapshot` 顶层结构: + +```json +{ + "version": 2, + "pending_tool_call": { + "interrupt_id": "int-1", + "tool_name": "srv.transfer_funds", + "tool_args": {"to": "u2", "amount": 100}, + "status": "PENDING_APPROVAL", + "expires_at": "2026-03-03T12:00:00Z", + "decision": null, + "result": null, + "updated_at": "2026-03-03T11:59:00Z" + }, + "run_context": { + "thread_id": "t1", + "run_id": "r1" + } +} +``` + +说明: + +- `version` 必须为 2,否则拒绝处理 +- `pending_tool_call` 字段缺失或类型错误,按无效快照处理 +- `run_context` 仅保留 interrupt/resume 必需字段 + +## 4. 状态机约束 + +仅允许以下迁移: + +- `PENDING_APPROVAL -> APPROVED_EXECUTING -> EXECUTED` +- `PENDING_APPROVAL -> REJECTED` +- `PENDING_APPROVAL -> EXPIRED` + +非法状态迁移必须返回错误,不做隐式修复。 + +## 5. 并发与过期语义 + +- resume 前先对目标 session 加锁再读取快照 +- 同一 `interrupt_id` 并发 resume 只能有一个请求成功 +- 若 `expires_at < now(UTC)`,先迁移为 `EXPIRED`,再返回 410 + +## 6. 错误语义(RFC7807) + +- `409 Conflict`: run/interrupt 不匹配,或并发冲突导致状态已消费 +- `410 Gone`: 挂起调用已过期 +- `422 Unprocessable Entity`: `state_snapshot` 非法或版本不匹配 +- `404 Not Found`: 目标 session/run 不存在 + +## 7. 测试策略 + +采用 TDD,先写失败测试后实现: + +- 快照版本校验(`version != 2`) +- 快照结构校验(必填字段/类型) +- 并发 resume 幂等竞争(仅一个成功) +- 过期校验(返回 410 + 状态置 EXPIRED) +- 合法状态迁移路径覆盖 + +## 8. 验证命令 + +- `uv run pytest backend/tests/unit/v1/agent -v` +- `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py -v` +- `uv run pytest backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v` +- `cd backend && uv run ruff check src/v1/agent` +- `cd backend && uv run basedpyright src/v1/agent` + +## 9. 风险与回滚 + +- 风险:旧快照不再兼容,可能触发运行时拒绝 +- 处置:通过明确 422 错误暴露不合规数据,结合日志定位并人工修复数据 +- 回滚:回退本次变更并恢复旧快照解析逻辑(仅在紧急故障时) diff --git a/docs/plans/2026-03-03-interrupt-resume-fixes-implementation-plan.md b/docs/plans/2026-03-03-interrupt-resume-fixes-implementation-plan.md new file mode 100644 index 0000000..a89d306 --- /dev/null +++ b/docs/plans/2026-03-03-interrupt-resume-fixes-implementation-plan.md @@ -0,0 +1,377 @@ +# Agent Interrupt/Resume Strict Refactor Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 通过严格重构一次性修复 interrupt/resume 的并发安全、过期校验和 state_snapshot 强类型版本化问题。 + +**Architecture:** 以 `state_snapshot v2` 为唯一合法结构,服务层使用强类型模型解析与状态迁移,resume 路径在读取会话时加行锁保证并发一致性。路由层维持现有 run/resume 入口,错误通过 HTTPException 输出,测试覆盖版本校验、过期语义、并发幂等和状态机迁移。 + +**Tech Stack:** FastAPI, SQLAlchemy AsyncSession, Pydantic v2, pytest + +--- + +### Task 1: 新增 state_snapshot v2 强类型模型 + +**Files:** +- Modify: `backend/src/v1/agent/schemas.py` +- Test: `backend/tests/unit/v1/agent/test_schemas.py` + +**Step 1: Write the failing test** + +```python +def test_state_snapshot_v2_model_accepts_valid_payload(): + payload = { + "version": 2, + "pending_tool_call": { + "interrupt_id": "int-1", + "tool_name": "srv.transfer_funds", + "tool_args": {"to": "u2", "amount": 100}, + "status": "PENDING_APPROVAL", + "expires_at": "2026-03-03T12:00:00Z", + "decision": None, + "result": None, + "updated_at": "2026-03-03T11:59:00Z", + }, + "run_context": {"thread_id": "t1", "run_id": "r1"}, + } + model = AgentSessionSnapshot.model_validate(payload) + assert model.version == 2 + + +def test_state_snapshot_v2_rejects_wrong_version(): + payload = { + "version": 1, + "pending_tool_call": None, + "run_context": {"thread_id": "t1", "run_id": "r1"}, + } + with pytest.raises(ValueError): + AgentSessionSnapshot.model_validate(payload) +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_schemas.py -v` +Expected: FAIL(`AgentSessionSnapshot` 未定义或校验不符合预期) + +**Step 3: Write minimal implementation** + +```python +class PendingToolStatus(str, Enum): + PENDING_APPROVAL = "PENDING_APPROVAL" + APPROVED_EXECUTING = "APPROVED_EXECUTING" + EXECUTED = "EXECUTED" + REJECTED = "REJECTED" + EXPIRED = "EXPIRED" + + +class PendingToolCall(BaseModel): + interrupt_id: str + tool_name: str + tool_args: dict[str, Any] + status: PendingToolStatus + expires_at: datetime + decision: dict[str, Any] | None = None + result: dict[str, Any] | None = None + updated_at: datetime + + +class SnapshotRunContext(BaseModel): + thread_id: str + run_id: str + + +class AgentSessionSnapshot(BaseModel): + version: Literal[2] + pending_tool_call: PendingToolCall | None = None + run_context: SnapshotRunContext +``` + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_schemas.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/v1/agent/schemas.py backend/tests/unit/v1/agent/test_schemas.py +git commit -m "refactor(agent): add strict v2 session snapshot schema" +``` + +--- + +### Task 2: service 层改为 v2 快照读写(严格拒绝旧结构) + +**Files:** +- Modify: `backend/src/v1/agent/service.py` +- Test: `backend/tests/unit/v1/agent/test_service_pending_tool_call.py` + +**Step 1: Write the failing test** + +```python +@pytest.mark.asyncio +async def test_set_pending_tool_call_writes_v2_snapshot(service, session): + await service.set_pending_tool_call( + session_id=session.id, + interrupt_id="int-1", + tool_name="srv.transfer_funds", + tool_args={"to": "u2", "amount": 100}, + expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), + thread_id="t1", + run_id="r1", + ) + snapshot = await service.get_state_snapshot(session.id) + assert snapshot["version"] == 2 + assert snapshot["run_context"]["run_id"] == "r1" + + +@pytest.mark.asyncio +async def test_invalid_legacy_snapshot_is_rejected(service, session): + session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}} + with pytest.raises(ValueError): + await service.apply_resume_decision( + session_id=session.id, + interrupt_id="int-1", + decision={"decision": "approved"}, + ) +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_service_pending_tool_call.py -v` +Expected: FAIL + +**Step 3: Write minimal implementation** + +```python +def _build_snapshot_v2(...): + return AgentSessionSnapshot(...).model_dump(mode="json") + + +def _load_snapshot_v2(raw: dict[str, Any] | None) -> AgentSessionSnapshot: + if raw is None: + raise ValueError("state_snapshot missing") + return AgentSessionSnapshot.model_validate(raw) +``` + +并将 `set_pending_tool_call/get_state_snapshot/update_pending_tool_call_status` 全部改成 v2 模型读写。 + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_service_pending_tool_call.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_service_pending_tool_call.py +git commit -m "refactor(agent): enforce v2 snapshot read write in service" +``` + +--- + +### Task 3: 增加 resume 行锁与并发幂等 + +**Files:** +- Modify: `backend/src/v1/agent/service.py` +- Test: `backend/tests/unit/v1/agent/test_resume_idempotency.py` + +**Step 1: Write the failing test** + +```python +@pytest.mark.asyncio +async def test_apply_resume_decision_uses_locked_session_fetch(service, fake_db, session): + await service.apply_resume_decision( + session_id=session.id, + interrupt_id="int-1", + decision={"decision": "approved"}, + ) + assert fake_db.last_fetch_with_lock is True + + +@pytest.mark.asyncio +async def test_resume_is_idempotent(service, session): + first = await service.apply_resume_decision(...) + second = await service.apply_resume_decision(...) + assert first.applied is True + assert second.applied is False +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v` +Expected: FAIL + +**Step 3: Write minimal implementation** + +```python +async def _get_session_for_update(self, session_id: UUID) -> AgentChatSession | None: + stmt = ( + select(AgentChatSession) + .where(AgentChatSession.id == session_id) + .with_for_update() + .limit(1) + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() +``` + +`apply_resume_decision` 改为锁内读取、校验、状态迁移,保证并发下单次生效。 + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_resume_idempotency.py +git commit -m "fix(agent): add row lock for resume state transition" +``` + +--- + +### Task 4: 增加 expires_at 过期校验(含 EXPIRED 迁移) + +**Files:** +- Modify: `backend/src/v1/agent/service.py` +- Test: `backend/tests/unit/v1/agent/test_resume_idempotency.py` + +**Step 1: Write the failing test** + +```python +@pytest.mark.asyncio +async def test_resume_expired_pending_returns_not_applied_and_marks_expired(service, session): + await service.set_pending_tool_call(..., expires_at=datetime.now(timezone.utc) - timedelta(seconds=1), thread_id="t1", run_id="r1") + result = await service.apply_resume_decision( + session_id=session.id, + interrupt_id="int-1", + decision={"decision": "approved"}, + ) + assert result.applied is False + snapshot = await service.get_state_snapshot(session.id) + assert snapshot["pending_tool_call"]["status"] == "EXPIRED" +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v` +Expected: FAIL + +**Step 3: Write minimal implementation** + +```python +if pending.expires_at < datetime.now(timezone.utc): + pending.status = PendingToolStatus.EXPIRED + pending.updated_at = datetime.now(timezone.utc) + session.state_snapshot = snapshot.model_dump(mode="json") + return ResumeDecisionResult(applied=False, expired=True) +``` + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/unit/v1/agent/test_resume_idempotency.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/v1/agent/service.py backend/tests/unit/v1/agent/test_resume_idempotency.py +git commit -m "fix(agent): enforce expires_at when applying resume decision" +``` + +--- + +### Task 5: 路由层补齐 v2 快照与过期/冲突错误映射 + +**Files:** +- Modify: `backend/src/v1/agent/router.py` +- Modify: `backend/src/v1/agent/service.py` +- Test: `backend/tests/integration/v1/agent/test_chat_routes.py` +- Test: `backend/tests/integration/v1/agent/test_interrupt_resume_flow.py` + +**Step 1: Write the failing test** + +```python +def test_resume_route_returns_409_on_run_id_mismatch(client): + ... + + +def test_resume_route_returns_410_when_pending_expired(client): + ... + + +def test_resume_route_returns_422_for_legacy_snapshot(client): + ... +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v` +Expected: FAIL + +**Step 3: Write minimal implementation** + +在 `stream_resume` 或路由调用链里将领域错误映射为: + +- 过期 -> `HTTPException(410)` +- 旧快照/结构错误 -> `HTTPException(422)` +- 状态冲突/重复消费 -> `HTTPException(409)` + +**Step 4: Run test to verify it passes** + +Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/src/v1/agent/router.py backend/src/v1/agent/service.py backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py +git commit -m "fix(agent): map resume snapshot errors to 409 410 422" +``` + +--- + +### Task 6: 更新文档并完成验证 + +**Files:** +- Modify: `docs/plans/2026-03-03-agent-chat-design.md` +- Modify: `docs/runtime/runtime-route.md` + +**Step 1: Update docs** + +- 明确 `state_snapshot version=2` 为唯一支持结构 +- 明确 resume 过期与并发冲突语义(410/409) +- 明确旧快照拒绝策略(422) + +**Step 2: Run unit tests** + +Run: `uv run pytest backend/tests/unit/v1/agent -v` +Expected: PASS + +**Step 3: Run integration tests** + +Run: `uv run pytest backend/tests/integration/v1/agent/test_chat_routes.py backend/tests/integration/v1/agent/test_interrupt_resume_flow.py -v` +Expected: PASS + +**Step 4: Run static checks** + +Run: `cd backend && uv run ruff check src/v1/agent` +Expected: PASS + +Run: `cd backend && uv run basedpyright src/v1/agent` +Expected: PASS + +**Step 5: Commit** + +```bash +git add docs/plans/2026-03-03-agent-chat-design.md docs/runtime/runtime-route.md +git commit -m "docs(agent): document strict snapshot v2 and resume error semantics" +``` + +--- + +Plan complete and saved to `docs/plans/2026-03-03-interrupt-resume-fixes-implementation-plan.md`. + +Execution mode selected by user request: Subagent-Driven (this session), proceed task-by-task immediately. diff --git a/docs/runtime/runtime-route.md b/docs/runtime/runtime-route.md index f158920..6caea22 100644 --- a/docs/runtime/runtime-route.md +++ b/docs/runtime/runtime-route.md @@ -788,42 +788,68 @@ ## Agent -### POST /agent +### POST /agent/runs -运行 Agent 对话(需要认证)。 +创建 Agent 运行(需要认证,SSE 响应)。 -**Request:** +**Request (RunAgentInput):** ```json { - "message": "string (1-8000 chars)", - "session_id": "string? (UUID)" + "threadId": "string", + "runId": "string", + "parentRunId": "string?", + "state": {}, + "messages": [], + "tools": [], + "context": [], + "forwardedProps": {}, + "resume": null } ``` -**Response:** 200 OK -```json -{ - "session_id": "string (UUID)", - "output": "string", - "events": [ - { - "type": "string", - "run_id": "string?", - "message_id": "string?", - "delta": "string?", - "tool_name": "string?", - "result": "string?", - "output": "string?", - "error": "string?" - } - ] -} -``` +**Response:** 200 OK (`text/event-stream`) **Errors:** - 401: 未认证 - 422: 请求参数无效 +### POST /agent/runs/{run_id}/resume + +恢复被中断运行(需要认证,SSE 响应)。 + +**Request (RunAgentInput):** +```json +{ + "threadId": "string", + "runId": "string", + "state": {}, + "messages": [], + "tools": [], + "context": [], + "forwardedProps": {}, + "resume": { + "interruptId": "string", + "payload": {} + } +} +``` + +**State Snapshot Contract:** +- `state_snapshot` 仅支持 `version = 2` +- 顶层必须包含 `run_context` 与 `pending_tool_call` +- 旧格式或缺失字段会被拒绝 + +**Resume Semantics:** +- 同一 `interrupt_id` 并发恢复仅允许一个请求成功 +- `expires_at` 超时后会标记为 `EXPIRED`,恢复请求不再生效 + +**Errors:** +- 401: 未认证 +- 404: 会话不存在 +- 409: `run_id` 或 `interrupt_id` 冲突,或状态已被消费 +- 410: 挂起调用已过期 +- 422: `state_snapshot` 非法或版本不匹配 + --- ## Infra