From dedd23fdf98451e23e943f2b0e55b549de340c4a Mon Sep 17 00:00:00 2001
From: qzl
Date: Tue, 3 Mar 2026 15:43:10 +0800
Subject: [PATCH] fix(agent): enforce idempotent resume transition

---
Review note: this patch was amended after export (a missing decision is now
rejected instead of approved, the snapshot is replaced rather than mutated
in place, docstrings and one test were added). The commit id in the "From"
line and the post-image blob ids in the "index" lines therefore no longer
correspond to this content.

 backend/src/v1/agent/service.py               |  63 ++++++-
 .../unit/v1/agent/test_resume_idempotency.py  | 170 ++++++++++++++++++
 2 files changed, 232 insertions(+), 1 deletion(-)
 create mode 100644 backend/tests/unit/v1/agent/test_resume_idempotency.py

diff --git a/backend/src/v1/agent/service.py b/backend/src/v1/agent/service.py
index a506eb6..31415a3 100644
--- a/backend/src/v1/agent/service.py
+++ b/backend/src/v1/agent/service.py
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 from datetime import datetime, timezone
 from decimal import Decimal
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from uuid import UUID
 
 from fastapi import HTTPException
+from pydantic import BaseModel
 from sqlalchemy import func, select
 from sqlalchemy.exc import SQLAlchemyError
 
@@ -29,6 +30,13 @@ if TYPE_CHECKING:
 logger = get_logger("v1.agent.service")
 
 
+class ResumeDecisionResult(BaseModel):
+    """Outcome of applying a resume decision to a pending tool call."""
+
+    # True only when this call performed the status transition.
+    applied: bool
+
+
 def build_session_title(first_message: str, *, now: datetime) -> str:
     title = first_message.strip().replace("\n", " ")[:24]
     if not title:
@@ -335,3 +343,56 @@ class AgentChatService(BaseService):
             raise ValueError("Interrupt ID mismatch")
         snapshot["pending_tool_call"]["status"] = status
         session.state_snapshot = snapshot
+
+    async def apply_resume_decision(
+        self,
+        *,
+        session_id: UUID,
+        interrupt_id: str,
+        decision: dict[str, Any],
+    ) -> ResumeDecisionResult:
+        """Apply a human resume decision to the pending tool call at most once.
+
+        The transition is applied only when the session has a pending tool
+        call whose interrupt ID matches and whose status is still
+        ``PENDING_APPROVAL``; otherwise nothing is changed and
+        ``applied=False`` is returned, so a repeated resume is a no-op.
+
+        Raises:
+            ValueError: If no session with ``session_id`` exists.
+        """
+        stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
+        session = await self._session.scalar(stmt)
+        if session is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        snapshot = session.state_snapshot
+        if snapshot is None or "pending_tool_call" not in snapshot:
+            return ResumeDecisionResult(applied=False)
+
+        pending = snapshot["pending_tool_call"]
+        if pending["interrupt_id"] != interrupt_id:
+            return ResumeDecisionResult(applied=False)
+
+        if pending["status"] != "PENDING_APPROVAL":
+            return ResumeDecisionResult(applied=False)
+
+        # Fail closed: only an explicit "approved" approves the tool call; a
+        # missing or unrecognised decision is treated as a rejection.
+        if decision.get("decision") == "approved":
+            new_status = "APPROVED_EXECUTING"
+        else:
+            new_status = "REJECTED"
+
+        # Assign fresh dicts instead of mutating in place so the ORM sees a
+        # new value even if the column has no in-place mutation tracking.
+        session.state_snapshot = {
+            **snapshot,
+            "pending_tool_call": {
+                **pending,
+                "status": new_status,
+                "decision": decision,
+            },
+        }
+
+        return ResumeDecisionResult(applied=True)
diff --git a/backend/tests/unit/v1/agent/test_resume_idempotency.py b/backend/tests/unit/v1/agent/test_resume_idempotency.py
new file mode 100644
index 0000000..660beea
--- /dev/null
+++ b/backend/tests/unit/v1/agent/test_resume_idempotency.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+from datetime import datetime, timedelta, timezone
+from uuid import UUID, uuid4
+
+import pytest
+
+from models.agent_chat_session import (
+    AgentChatSession,
+    AgentChatSessionStatus,
+    SessionType,
+)
+from v1.agent.service import AgentChatService
+
+
+class FakeAsyncSession:
+    """In-memory stand-in for the async DB session; holds added sessions."""
+
+    def __init__(self) -> None:
+        self.added: list[object] = []
+        self._sessions: dict[UUID, AgentChatSession] = {}
+
+    def add(self, obj: object) -> None:
+        self.added.append(obj)
+        if isinstance(obj, AgentChatSession):
+            self._sessions[obj.id] = obj
+
+    async def flush(self) -> None:
+        return None
+
+    async def commit(self) -> None:
+        pass
+
+    async def rollback(self) -> None:
+        pass
+
+    async def refresh(self, obj: object) -> None:
+        pass
+
+    async def execute(self, stmt: object) -> None:
+        pass
+
+    async def scalar(self, stmt: object) -> AgentChatSession | None:
+        # The statement is ignored: the first registered session is returned.
+        for session in self._sessions.values():
+            return session
+        return None
+
+
+@pytest.fixture
+def fake_db() -> FakeAsyncSession:
+    return FakeAsyncSession()
+
+
+@pytest.fixture
+def session(fake_db: FakeAsyncSession) -> AgentChatSession:
+    sess = AgentChatSession(
+        id=uuid4(),
+        user_id=uuid4(),
+        session_type=SessionType.CHAT,
+        status=AgentChatSessionStatus.RUNNING,
+    )
+    fake_db.add(sess)
+    return sess
+
+
+@pytest.fixture
+def service(fake_db: FakeAsyncSession) -> AgentChatService:
+    return AgentChatService(fake_db, current_user=None)  # type: ignore[arg-type]
+
+
+class TestResumeIdempotency:
+    """apply_resume_decision applies a decision once and records the outcome."""
+
+    @pytest.mark.asyncio
+    async def test_resume_is_idempotent(
+        self, service: AgentChatService, session: AgentChatSession
+    ):
+        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
+        await service.set_pending_tool_call(
+            session_id=session.id,
+            interrupt_id="int-1",
+            tool_name="srv.transfer_funds",
+            tool_args={"to": "u2", "amount": 100},
+            expires_at=expires_at,
+        )
+
+        first = await service.apply_resume_decision(
+            session_id=session.id,
+            interrupt_id="int-1",
+            decision={"decision": "approved"},
+        )
+        second = await service.apply_resume_decision(
+            session_id=session.id,
+            interrupt_id="int-1",
+            decision={"decision": "approved"},
+        )
+
+        assert first.applied is True
+        assert second.applied is False
+
+    @pytest.mark.asyncio
+    async def test_resume_updates_status_to_approved(
+        self, service: AgentChatService, session: AgentChatSession
+    ):
+        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
+        await service.set_pending_tool_call(
+            session_id=session.id,
+            interrupt_id="int-2",
+            tool_name="srv.delete_file",
+            tool_args={"file_id": "f1"},
+            expires_at=expires_at,
+        )
+
+        result = await service.apply_resume_decision(
+            session_id=session.id,
+            interrupt_id="int-2",
+            decision={"decision": "approved"},
+        )
+
+        assert result.applied is True
+        snapshot = await service.get_state_snapshot(session.id)
+        assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
+        assert snapshot["pending_tool_call"]["decision"] == {"decision": "approved"}
+
+    @pytest.mark.asyncio
+    async def test_resume_updates_status_to_rejected(
+        self, service: AgentChatService, session: AgentChatSession
+    ):
+        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
+        await service.set_pending_tool_call(
+            session_id=session.id,
+            interrupt_id="int-3",
+            tool_name="srv.transfer_funds",
+            tool_args={"to": "u2", "amount": 100},
+            expires_at=expires_at,
+        )
+
+        result = await service.apply_resume_decision(
+            session_id=session.id,
+            interrupt_id="int-3",
+            decision={"decision": "rejected"},
+        )
+
+        assert result.applied is True
+        snapshot = await service.get_state_snapshot(session.id)
+        assert snapshot["pending_tool_call"]["status"] == "REJECTED"
+
+    @pytest.mark.asyncio
+    async def test_resume_without_explicit_decision_is_rejected(
+        self, service: AgentChatService, session: AgentChatSession
+    ):
+        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
+        await service.set_pending_tool_call(
+            session_id=session.id,
+            interrupt_id="int-4",
+            tool_name="srv.transfer_funds",
+            tool_args={"to": "u2", "amount": 100},
+            expires_at=expires_at,
+        )
+
+        result = await service.apply_resume_decision(
+            session_id=session.id,
+            interrupt_id="int-4",
+            decision={},
+        )
+
+        assert result.applied is True
+        snapshot = await service.get_state_snapshot(session.id)
+        assert snapshot["pending_tool_call"]["status"] == "REJECTED"