fix(agent): polish interrupt-resume flow for merge readiness

This commit is contained in:
qzl
2026-03-03 17:26:04 +08:00
parent 7be8669144
commit 30a4a1af5d
16 changed files with 1179 additions and 85 deletions
+3 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from v1.agent.dependencies import get_agent_service
@@ -25,7 +25,7 @@ async def create_run(
@router.post("/runs/{run_id}/resume")
async def resume_run(
run_id: str,
run_id: Annotated[str, Path(min_length=1, max_length=255)],
input_data: RunAgentInput,
service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> StreamingResponse:
@@ -34,6 +34,7 @@ async def resume_run(
status_code=409,
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
)
await service.prepare_resume(run_id, input_data)
return StreamingResponse(
service.stream_resume(run_id, input_data),
media_type="text/event-stream",
+47 -1
View File
@@ -1,9 +1,12 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Literal
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
class RunAgentInput(BaseModel):
@@ -40,3 +43,46 @@ class AgentChatRunResponse(BaseModel):
session_id: UUID
output: str
events: list[AgentChatEvent]
class PendingToolStatus(str, Enum):
PENDING_APPROVAL = "PENDING_APPROVAL"
APPROVED_EXECUTING = "APPROVED_EXECUTING"
EXECUTED = "EXECUTED"
REJECTED = "REJECTED"
EXPIRED = "EXPIRED"
class PendingToolCall(BaseModel):
model_config = ConfigDict(extra="forbid")
interrupt_id: str = Field(min_length=1, max_length=255)
tool_name: str = Field(min_length=1, max_length=255)
tool_args: dict[str, Any] = Field(default_factory=dict)
status: PendingToolStatus
expires_at: datetime
decision: dict[str, Any] | None = None
result: dict[str, Any] | None = None
updated_at: datetime
@field_validator("expires_at", "updated_at")
@classmethod
def _validate_timezone_aware(cls, value: datetime) -> datetime:
if value.tzinfo is None or value.utcoffset() is None:
raise ValueError("datetime must be timezone-aware")
return value
class SnapshotRunContext(BaseModel):
model_config = ConfigDict(extra="forbid")
thread_id: str = Field(min_length=1, max_length=255)
run_id: str = Field(min_length=1, max_length=255)
class AgentSessionSnapshot(BaseModel):
model_config = ConfigDict(extra="forbid")
version: Literal[2]
pending_tool_call: PendingToolCall | None
run_context: SnapshotRunContext
+193 -44
View File
@@ -21,10 +21,14 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent.schemas import (
AgentSessionSnapshot,
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
PendingToolCall,
PendingToolStatus,
RunAgentInput,
SnapshotRunContext,
)
if TYPE_CHECKING:
@@ -294,6 +298,41 @@ class AgentChatService(BaseService):
return None
return session.state_snapshot
@staticmethod
def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot:
try:
return AgentSessionSnapshot.model_validate(raw_snapshot)
except Exception as exc: # noqa: BLE001
raise ValueError("Invalid state_snapshot format") from exc
async def _get_session_for_update(
self, session_id: UUID
) -> AgentChatSession | None:
stmt = (
select(AgentChatSession)
.where(AgentChatSession.id == session_id)
.with_for_update()
.limit(1)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
def _assert_session_owner(self, session: AgentChatSession) -> None:
if self._current_user is None:
return
if session.user_id != self.require_user_id():
raise HTTPException(status_code=404, detail="Session not found")
@staticmethod
def _validate_no_newlines(value: str, *, field_name: str) -> None:
if "\n" in value or "\r" in value:
raise ValueError(f"{field_name} must not contain newlines")
@staticmethod
def _sse_data(payload: dict[str, Any]) -> str:
return f"data: {json.dumps(payload)}\\n\\n"
async def set_pending_tool_call(
self,
*,
@@ -302,22 +341,30 @@ class AgentChatService(BaseService):
tool_name: str,
tool_args: dict,
expires_at: datetime,
thread_id: str,
run_id: str,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot or {}
snapshot["pending_tool_call"] = {
"interrupt_id": interrupt_id,
"tool_name": tool_name,
"tool_args": tool_args,
"status": "PENDING_APPROVAL",
"expires_at": expires_at.isoformat(),
"decision": None,
"result": None,
}
session.state_snapshot = snapshot
self._assert_session_owner(session)
snapshot = AgentSessionSnapshot(
version=2,
run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id),
pending_tool_call=PendingToolCall(
interrupt_id=interrupt_id,
tool_name=tool_name,
tool_args=tool_args,
status=PendingToolStatus.PENDING_APPROVAL,
expires_at=expires_at,
decision=None,
result=None,
updated_at=datetime.now(timezone.utc),
),
)
session.state_snapshot = snapshot.model_dump(mode="json")
async def update_pending_tool_call_status(
self,
@@ -330,13 +377,27 @@ class AgentChatService(BaseService):
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
self._assert_session_owner(session)
if session.state_snapshot is None:
raise ValueError("No pending tool call found")
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
raise ValueError("No pending tool call found")
if pending.interrupt_id != interrupt_id:
raise ValueError("Interrupt ID mismatch")
snapshot["pending_tool_call"]["status"] = status
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": PendingToolStatus(status),
"updated_at": datetime.now(timezone.utc),
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
async def apply_resume_decision(
self,
@@ -345,52 +406,140 @@ class AgentChatService(BaseService):
interrupt_id: str,
decision: dict[str, Any],
) -> ResumeDecisionResult:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
session = await self._get_session_for_update(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
self._assert_session_owner(session)
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
if session.state_snapshot is None:
return ResumeDecisionResult(applied=False)
pending = snapshot["pending_tool_call"]
if pending["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
return ResumeDecisionResult(applied=False)
if pending["status"] != "PENDING_APPROVAL":
if pending.interrupt_id != interrupt_id:
return ResumeDecisionResult(applied=False)
if pending.status != PendingToolStatus.PENDING_APPROVAL:
return ResumeDecisionResult(applied=False)
now = datetime.now(timezone.utc)
if pending.expires_at <= now:
expired_pending = pending.model_copy(
update={
"status": PendingToolStatus.EXPIRED,
"updated_at": now,
}
)
expired_snapshot = snapshot.model_copy(
update={"pending_tool_call": expired_pending}
)
session.state_snapshot = expired_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=False)
decision_value = decision.get("decision", "approved")
if decision_value == "approved":
new_status = "APPROVED_EXECUTING"
else:
new_status = "REJECTED"
next_status = (
PendingToolStatus.APPROVED_EXECUTING
if decision_value == "approved"
else PendingToolStatus.REJECTED
)
snapshot["pending_tool_call"]["status"] = new_status
snapshot["pending_tool_call"]["decision"] = decision
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": next_status,
"decision": decision,
"updated_at": now,
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=True)
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
if "\n" in input_data.runId or "\r" in input_data.runId:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(input_data.runId, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId})
async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None:
self._validate_no_newlines(run_id, field_name="runId")
user_id = self.require_user_id()
await enforce_rate_limit(
scope="agent_resume",
identifier=str(user_id),
limit=DEFAULT_RATE_LIMIT,
window_seconds=DEFAULT_RATE_LIMIT,
)
try:
session_id = UUID(run_id)
except ValueError as exc:
raise HTTPException(
status_code=422, detail="run_id must be a valid UUID"
) from exc
session = await self._get_session_for_update(session_id)
if session is None or session.user_id != user_id:
raise HTTPException(status_code=404, detail="Session not found")
if input_data.resume is None:
raise HTTPException(status_code=422, detail="resume payload is required")
interrupt_id = input_data.resume.get("interruptId")
if not isinstance(interrupt_id, str) or not interrupt_id:
raise HTTPException(
status_code=422, detail="resume.interruptId is required"
)
decision_payload = input_data.resume.get("payload", {})
if not isinstance(decision_payload, dict):
raise HTTPException(
status_code=422,
detail="resume.payload must be an object",
)
try:
decision_result = await self.apply_resume_decision(
session_id=session_id,
interrupt_id=interrupt_id,
decision=decision_payload,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
if not decision_result.applied:
if session.state_snapshot is not None:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is not None and pending.status == PendingToolStatus.EXPIRED:
await self._session.commit()
raise HTTPException(
status_code=410,
detail="Pending tool call expired",
)
raise HTTPException(
status_code=409,
detail="Resume decision not applicable",
)
await self._session.commit()
async def stream_resume(
self, run_id: str, input_data: RunAgentInput
) -> AsyncGenerator[str, None]:
if "\n" in run_id or "\r" in run_id:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(run_id, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": run_id})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id})
+8
View File
@@ -17,6 +17,12 @@ ALLOWED_BACKEND_TOOLS = frozenset(
}
)
ALLOWED_FRONTEND_TOOLS = frozenset(
{
"ui.navigate_to",
}
)
class InterruptResult(BaseModel):
interrupt_type: str
@@ -47,6 +53,8 @@ def dispatch_tool_call(
args = tool.get("args", {})
if target == "frontend":
if name not in ALLOWED_FRONTEND_TOOLS:
raise ValueError(f"Frontend tool '{name}' not in allowlist")
return InterruptResult(
interrupt_type="tool_execution",
tool_name=name,