fix(agent): enforce idempotent resume transition

This commit is contained in:
qzl
2026-03-03 15:43:10 +08:00
parent cff1436bc6
commit dedd23fdf9
2 changed files with 183 additions and 1 deletions
+41 -1
View File
@@ -2,10 +2,11 @@ from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from decimal import Decimal from decimal import Decimal
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from uuid import UUID from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@@ -29,6 +30,10 @@ if TYPE_CHECKING:
logger = get_logger("v1.agent.service") logger = get_logger("v1.agent.service")
class ResumeDecisionResult(BaseModel):
    """Outcome of one attempt to apply a resume decision to a session."""

    # True when this call performed the status transition; False when the
    # call was ignored and the stored snapshot was left untouched.
    applied: bool
def build_session_title(first_message: str, *, now: datetime) -> str: def build_session_title(first_message: str, *, now: datetime) -> str:
title = first_message.strip().replace("\n", " ")[:24] title = first_message.strip().replace("\n", " ")[:24]
if not title: if not title:
@@ -335,3 +340,38 @@ class AgentChatService(BaseService):
raise ValueError("Interrupt ID mismatch") raise ValueError("Interrupt ID mismatch")
snapshot["pending_tool_call"]["status"] = status snapshot["pending_tool_call"]["status"] = status
session.state_snapshot = snapshot session.state_snapshot = snapshot
async def apply_resume_decision(
    self,
    *,
    session_id: UUID,
    interrupt_id: str,
    decision: dict[str, Any],
) -> ResumeDecisionResult:
    """Apply an approve/reject decision to the session's pending tool call.

    The transition only happens when the stored pending tool call carries
    the same ``interrupt_id`` and is still in ``PENDING_APPROVAL``. Any
    other situation (no snapshot, no pending call, different interrupt,
    already decided) is reported as ``applied=False`` and leaves the
    snapshot untouched, which is what makes a repeated resume a no-op.

    Args:
        session_id: Primary key of the chat session to update.
        interrupt_id: Identifier of the interrupt being answered.
        decision: Resume payload. Only ``decision["decision"] == "approved"``
            approves the call; a missing key or any other value rejects it
            (fail closed), so an incomplete payload can never trigger the
            tool. The payload is stored alongside the new status.

    Returns:
        ``ResumeDecisionResult(applied=True)`` when the status was changed
        by this call, ``ResumeDecisionResult(applied=False)`` otherwise.

    Raises:
        ValueError: If no session with ``session_id`` exists.
    """
    stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
    session = await self._session.scalar(stmt)
    if session is None:
        raise ValueError(f"Session {session_id} not found")

    snapshot = session.state_snapshot
    # Covers a missing snapshot, a missing key and an explicit null entry.
    pending = snapshot.get("pending_tool_call") if snapshot else None
    if not pending:
        return ResumeDecisionResult(applied=False)

    if (
        pending.get("interrupt_id") != interrupt_id
        or pending.get("status") != "PENDING_APPROVAL"
    ):
        return ResumeDecisionResult(applied=False)

    # Fail closed: only an explicit "approved" may move to execution.
    is_approved = decision.get("decision") == "approved"
    new_status = "APPROVED_EXECUTING" if is_approved else "REJECTED"

    # Assign a fresh dict instead of mutating the loaded one in place.
    # NOTE(review): if ``state_snapshot`` is a plain JSON column (its
    # definition is not visible here), re-assigning the same mutated object
    # is not detected as a change by the ORM -- confirm against the model.
    session.state_snapshot = {
        **snapshot,
        "pending_tool_call": {
            **pending,
            "status": new_status,
            "decision": dict(decision),
        },
    }
    return ResumeDecisionResult(applied=True)
@@ -0,0 +1,142 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import pytest
from models.agent_chat_session import (
AgentChatSession,
AgentChatSessionStatus,
SessionType,
)
from v1.agent.service import AgentChatService
class FakeAsyncSession:
    """In-memory stand-in for the async database session used by the service.

    Objects passed to ``add`` are recorded; ``AgentChatSession`` instances
    are additionally indexed by id. All persistence calls are no-ops.
    """

    def __init__(self) -> None:
        self.added: list[object] = []
        self._sessions: dict[UUID, AgentChatSession] = {}

    def add(self, obj: object) -> None:
        """Record ``obj`` and index it when it is a chat session."""
        self.added.append(obj)
        if not isinstance(obj, AgentChatSession):
            return
        self._sessions[obj.id] = obj

    async def flush(self) -> None:
        """Do nothing; there is no backing store to flush to."""

    async def commit(self) -> None:
        """Do nothing; there is no transaction to commit."""

    async def rollback(self) -> None:
        """Do nothing; there is no transaction to roll back."""

    async def refresh(self, obj: object) -> None:
        """Do nothing; stored objects are never stale."""

    async def execute(self, stmt: object) -> None:
        """Ignore ``stmt`` and return ``None``."""

    async def scalar(self, stmt: object) -> AgentChatSession | None:
        """Return the first registered chat session, or ``None`` if empty.

        ``stmt`` is not inspected, so any filter it carries is ignored.
        """
        return next(iter(self._sessions.values()), None)
@pytest.fixture
def fake_db() -> FakeAsyncSession:
    """Provide a fresh, empty in-memory database session for each test."""
    db = FakeAsyncSession()
    return db
@pytest.fixture
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
    """Create a running chat session and register it with the fake database."""
    chat_session = AgentChatSession(
        id=uuid4(),
        user_id=uuid4(),
        session_type=SessionType.CHAT,
        status=AgentChatSessionStatus.RUNNING,
    )
    # Registering makes the session visible to the service's lookups.
    fake_db.add(chat_session)
    return chat_session
@pytest.fixture
def service(fake_db: FakeAsyncSession) -> AgentChatService:
    """Build the service under test on top of the fake database session."""
    # No authenticated user is supplied to the service in these tests.
    chat_service = AgentChatService(fake_db, current_user=None)  # type: ignore[arg-type]
    return chat_service
class TestResumeIdempotency:
    """Tests for ``AgentChatService.apply_resume_decision``."""

    @staticmethod
    async def _register_pending_call(
        service: AgentChatService,
        session: AgentChatSession,
        *,
        interrupt_id: str,
        tool_name: str,
        tool_args: dict,
    ) -> None:
        """Store a pending tool call that expires one hour from now."""
        deadline = datetime.now(timezone.utc) + timedelta(hours=1)
        await service.set_pending_tool_call(
            session_id=session.id,
            interrupt_id=interrupt_id,
            tool_name=tool_name,
            tool_args=tool_args,
            expires_at=deadline,
        )

    @pytest.mark.asyncio
    async def test_resume_is_idempotent(
        self, service: AgentChatService, session: AgentChatSession
    ) -> None:
        """A second identical resume must report that nothing was applied."""
        await self._register_pending_call(
            service,
            session,
            interrupt_id="int-1",
            tool_name="srv.transfer_funds",
            tool_args={"to": "u2", "amount": 100},
        )

        initial_outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-1",
            decision={"decision": "approved"},
        )
        repeat_outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-1",
            decision={"decision": "approved"},
        )

        assert initial_outcome.applied is True
        assert repeat_outcome.applied is False

    @pytest.mark.asyncio
    async def test_resume_updates_status_to_approved(
        self, service: AgentChatService, session: AgentChatSession
    ) -> None:
        """Approving stores the executing status together with the decision."""
        await self._register_pending_call(
            service,
            session,
            interrupt_id="int-2",
            tool_name="srv.delete_file",
            tool_args={"file_id": "f1"},
        )

        outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-2",
            decision={"decision": "approved"},
        )
        assert outcome.applied is True

        stored = await service.get_state_snapshot(session.id)
        tool_call = stored["pending_tool_call"]
        assert tool_call["status"] == "APPROVED_EXECUTING"
        assert tool_call["decision"] == {"decision": "approved"}

    @pytest.mark.asyncio
    async def test_resume_updates_status_to_rejected(
        self, service: AgentChatService, session: AgentChatSession
    ) -> None:
        """Rejecting moves the pending tool call to the rejected status."""
        await self._register_pending_call(
            service,
            session,
            interrupt_id="int-3",
            tool_name="srv.transfer_funds",
            tool_args={"to": "u2", "amount": 100},
        )

        outcome = await service.apply_resume_decision(
            session_id=session.id,
            interrupt_id="int-3",
            decision={"decision": "rejected"},
        )
        assert outcome.applied is True

        stored = await service.get_state_snapshot(session.id)
        assert stored["pending_tool_call"]["status"] == "REJECTED"