fix(agent): polish interrupt-resume flow for merge readiness

This commit is contained in:
qzl
2026-03-03 17:26:04 +08:00
parent 7be8669144
commit 30a4a1af5d
16 changed files with 1179 additions and 85 deletions
+193 -44
View File
@@ -21,10 +21,14 @@ from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent.schemas import (
AgentSessionSnapshot,
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
PendingToolCall,
PendingToolStatus,
RunAgentInput,
SnapshotRunContext,
)
if TYPE_CHECKING:
@@ -294,6 +298,41 @@ class AgentChatService(BaseService):
return None
return session.state_snapshot
@staticmethod
def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot:
try:
return AgentSessionSnapshot.model_validate(raw_snapshot)
except Exception as exc: # noqa: BLE001
raise ValueError("Invalid state_snapshot format") from exc
async def _get_session_for_update(
self, session_id: UUID
) -> AgentChatSession | None:
stmt = (
select(AgentChatSession)
.where(AgentChatSession.id == session_id)
.with_for_update()
.limit(1)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
def _assert_session_owner(self, session: AgentChatSession) -> None:
if self._current_user is None:
return
if session.user_id != self.require_user_id():
raise HTTPException(status_code=404, detail="Session not found")
@staticmethod
def _validate_no_newlines(value: str, *, field_name: str) -> None:
if "\n" in value or "\r" in value:
raise ValueError(f"{field_name} must not contain newlines")
@staticmethod
def _sse_data(payload: dict[str, Any]) -> str:
return f"data: {json.dumps(payload)}\\n\\n"
async def set_pending_tool_call(
self,
*,
@@ -302,22 +341,30 @@ class AgentChatService(BaseService):
tool_name: str,
tool_args: dict,
expires_at: datetime,
thread_id: str,
run_id: str,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot or {}
snapshot["pending_tool_call"] = {
"interrupt_id": interrupt_id,
"tool_name": tool_name,
"tool_args": tool_args,
"status": "PENDING_APPROVAL",
"expires_at": expires_at.isoformat(),
"decision": None,
"result": None,
}
session.state_snapshot = snapshot
self._assert_session_owner(session)
snapshot = AgentSessionSnapshot(
version=2,
run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id),
pending_tool_call=PendingToolCall(
interrupt_id=interrupt_id,
tool_name=tool_name,
tool_args=tool_args,
status=PendingToolStatus.PENDING_APPROVAL,
expires_at=expires_at,
decision=None,
result=None,
updated_at=datetime.now(timezone.utc),
),
)
session.state_snapshot = snapshot.model_dump(mode="json")
async def update_pending_tool_call_status(
self,
@@ -330,13 +377,27 @@ class AgentChatService(BaseService):
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
self._assert_session_owner(session)
if session.state_snapshot is None:
raise ValueError("No pending tool call found")
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
raise ValueError("No pending tool call found")
if pending.interrupt_id != interrupt_id:
raise ValueError("Interrupt ID mismatch")
snapshot["pending_tool_call"]["status"] = status
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": PendingToolStatus(status),
"updated_at": datetime.now(timezone.utc),
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
async def apply_resume_decision(
self,
@@ -345,52 +406,140 @@ class AgentChatService(BaseService):
interrupt_id: str,
decision: dict[str, Any],
) -> ResumeDecisionResult:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
session = await self._get_session_for_update(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
self._assert_session_owner(session)
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
if session.state_snapshot is None:
return ResumeDecisionResult(applied=False)
pending = snapshot["pending_tool_call"]
if pending["interrupt_id"] != interrupt_id:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is None:
return ResumeDecisionResult(applied=False)
if pending["status"] != "PENDING_APPROVAL":
if pending.interrupt_id != interrupt_id:
return ResumeDecisionResult(applied=False)
if pending.status != PendingToolStatus.PENDING_APPROVAL:
return ResumeDecisionResult(applied=False)
now = datetime.now(timezone.utc)
if pending.expires_at <= now:
expired_pending = pending.model_copy(
update={
"status": PendingToolStatus.EXPIRED,
"updated_at": now,
}
)
expired_snapshot = snapshot.model_copy(
update={"pending_tool_call": expired_pending}
)
session.state_snapshot = expired_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=False)
decision_value = decision.get("decision", "approved")
if decision_value == "approved":
new_status = "APPROVED_EXECUTING"
else:
new_status = "REJECTED"
next_status = (
PendingToolStatus.APPROVED_EXECUTING
if decision_value == "approved"
else PendingToolStatus.REJECTED
)
snapshot["pending_tool_call"]["status"] = new_status
snapshot["pending_tool_call"]["decision"] = decision
session.state_snapshot = snapshot
updated_pending = pending.model_copy(
update={
"status": next_status,
"decision": decision,
"updated_at": now,
}
)
updated_snapshot = snapshot.model_copy(
update={"pending_tool_call": updated_pending}
)
session.state_snapshot = updated_snapshot.model_dump(mode="json")
return ResumeDecisionResult(applied=True)
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
if "\n" in input_data.runId or "\r" in input_data.runId:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(input_data.runId, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': input_data.runId})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Hello'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm1'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': input_data.runId})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId})
async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None:
self._validate_no_newlines(run_id, field_name="runId")
user_id = self.require_user_id()
await enforce_rate_limit(
scope="agent_resume",
identifier=str(user_id),
limit=DEFAULT_RATE_LIMIT,
window_seconds=DEFAULT_RATE_LIMIT,
)
try:
session_id = UUID(run_id)
except ValueError as exc:
raise HTTPException(
status_code=422, detail="run_id must be a valid UUID"
) from exc
session = await self._get_session_for_update(session_id)
if session is None or session.user_id != user_id:
raise HTTPException(status_code=404, detail="Session not found")
if input_data.resume is None:
raise HTTPException(status_code=422, detail="resume payload is required")
interrupt_id = input_data.resume.get("interruptId")
if not isinstance(interrupt_id, str) or not interrupt_id:
raise HTTPException(
status_code=422, detail="resume.interruptId is required"
)
decision_payload = input_data.resume.get("payload", {})
if not isinstance(decision_payload, dict):
raise HTTPException(
status_code=422,
detail="resume.payload must be an object",
)
try:
decision_result = await self.apply_resume_decision(
session_id=session_id,
interrupt_id=interrupt_id,
decision=decision_payload,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
if not decision_result.applied:
if session.state_snapshot is not None:
snapshot = self._load_snapshot_v2(session.state_snapshot)
pending = snapshot.pending_tool_call
if pending is not None and pending.status == PendingToolStatus.EXPIRED:
await self._session.commit()
raise HTTPException(
status_code=410,
detail="Pending tool call expired",
)
raise HTTPException(
status_code=409,
detail="Resume decision not applicable",
)
await self._session.commit()
async def stream_resume(
self, run_id: str, input_data: RunAgentInput
) -> AsyncGenerator[str, None]:
if "\n" in run_id or "\r" in run_id:
raise ValueError("runId must not contain newlines")
self._validate_no_newlines(run_id, field_name="runId")
yield f"data: {json.dumps({'type': 'RUN_STARTED', 'runId': run_id})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_START', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_CONTENT', 'delta': 'Resumed'})}\n\n"
yield f"data: {json.dumps({'type': 'TEXT_MESSAGE_END', 'messageId': 'm2'})}\n\n"
yield f"data: {json.dumps({'type': 'RUN_FINISHED', 'runId': run_id})}\n\n"
yield self._sse_data({"type": "RUN_STARTED", "runId": run_id})
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"})
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"})
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"})
yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id})