fix(agent): enforce idempotent resume transition
This commit is contained in:
@@ -2,10 +2,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
@@ -29,6 +30,10 @@ if TYPE_CHECKING:
|
|||||||
logger = get_logger("v1.agent.service")
|
logger = get_logger("v1.agent.service")
|
||||||
|
|
||||||
|
|
||||||
|
class ResumeDecisionResult(BaseModel):
|
||||||
|
applied: bool
|
||||||
|
|
||||||
|
|
||||||
def build_session_title(first_message: str, *, now: datetime) -> str:
|
def build_session_title(first_message: str, *, now: datetime) -> str:
|
||||||
title = first_message.strip().replace("\n", " ")[:24]
|
title = first_message.strip().replace("\n", " ")[:24]
|
||||||
if not title:
|
if not title:
|
||||||
@@ -335,3 +340,38 @@ class AgentChatService(BaseService):
|
|||||||
raise ValueError("Interrupt ID mismatch")
|
raise ValueError("Interrupt ID mismatch")
|
||||||
snapshot["pending_tool_call"]["status"] = status
|
snapshot["pending_tool_call"]["status"] = status
|
||||||
session.state_snapshot = snapshot
|
session.state_snapshot = snapshot
|
||||||
|
|
||||||
|
async def apply_resume_decision(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: UUID,
|
||||||
|
interrupt_id: str,
|
||||||
|
decision: dict[str, Any],
|
||||||
|
) -> ResumeDecisionResult:
|
||||||
|
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||||
|
session = await self._session.scalar(stmt)
|
||||||
|
if session is None:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
snapshot = session.state_snapshot
|
||||||
|
if snapshot is None or "pending_tool_call" not in snapshot:
|
||||||
|
return ResumeDecisionResult(applied=False)
|
||||||
|
|
||||||
|
pending = snapshot["pending_tool_call"]
|
||||||
|
if pending["interrupt_id"] != interrupt_id:
|
||||||
|
return ResumeDecisionResult(applied=False)
|
||||||
|
|
||||||
|
if pending["status"] != "PENDING_APPROVAL":
|
||||||
|
return ResumeDecisionResult(applied=False)
|
||||||
|
|
||||||
|
decision_value = decision.get("decision", "approved")
|
||||||
|
if decision_value == "approved":
|
||||||
|
new_status = "APPROVED_EXECUTING"
|
||||||
|
else:
|
||||||
|
new_status = "REJECTED"
|
||||||
|
|
||||||
|
snapshot["pending_tool_call"]["status"] = new_status
|
||||||
|
snapshot["pending_tool_call"]["decision"] = decision
|
||||||
|
session.state_snapshot = snapshot
|
||||||
|
|
||||||
|
return ResumeDecisionResult(applied=True)
|
||||||
|
|||||||
@@ -0,0 +1,142 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from models.agent_chat_session import (
|
||||||
|
AgentChatSession,
|
||||||
|
AgentChatSessionStatus,
|
||||||
|
SessionType,
|
||||||
|
)
|
||||||
|
from v1.agent.service import AgentChatService
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAsyncSession:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.added: list[object] = []
|
||||||
|
self._sessions: dict[UUID, AgentChatSession] = {}
|
||||||
|
|
||||||
|
def add(self, obj: object) -> None:
|
||||||
|
self.added.append(obj)
|
||||||
|
if isinstance(obj, AgentChatSession):
|
||||||
|
self._sessions[obj.id] = obj
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def refresh(self, obj: object) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute(self, stmt: object) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||||
|
for session in self._sessions.values():
|
||||||
|
return session
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_db() -> FakeAsyncSession:
|
||||||
|
return FakeAsyncSession()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
|
||||||
|
sess = AgentChatSession(
|
||||||
|
id=uuid4(),
|
||||||
|
user_id=uuid4(),
|
||||||
|
session_type=SessionType.CHAT,
|
||||||
|
status=AgentChatSessionStatus.RUNNING,
|
||||||
|
)
|
||||||
|
fake_db.add(sess)
|
||||||
|
return sess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def service(fake_db: FakeAsyncSession) -> AgentChatService:
|
||||||
|
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
class TestResumeIdempotency:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_is_idempotent(
|
||||||
|
self, service: AgentChatService, session: AgentChatSession
|
||||||
|
):
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||||
|
await service.set_pending_tool_call(
|
||||||
|
session_id=session.id,
|
||||||
|
interrupt_id="int-1",
|
||||||
|
tool_name="srv.transfer_funds",
|
||||||
|
tool_args={"to": "u2", "amount": 100},
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
first = await service.apply_resume_decision(
|
||||||
|
session_id=session.id,
|
||||||
|
interrupt_id="int-1",
|
||||||
|
decision={"decision": "approved"},
|
||||||
|
)
|
||||||
|
second = await service.apply_resume_decision(
|
||||||
|
session_id=session.id,
|
||||||
|
interrupt_id="int-1",
|
||||||
|
decision={"decision": "approved"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert first.applied is True
|
||||||
|
assert second.applied is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_updates_status_to_approved(
|
||||||
|
self, service: AgentChatService, session: AgentChatSession
|
||||||
|
):
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||||
|
await service.set_pending_tool_call(
|
||||||
|
session_id=session.id,
|
||||||
|
interrupt_id="int-2",
|
||||||
|
tool_name="srv.delete_file",
|
||||||
|
tool_args={"file_id": "f1"},
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.apply_resume_decision(
|
||||||
|
session_id=session.id,
|
||||||
|
interrupt_id="int-2",
|
||||||
|
decision={"decision": "approved"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.applied is True
|
||||||
|
snapshot = await service.get_state_snapshot(session.id)
|
||||||
|
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
|
||||||
|
assert snapshot["pending_tool_call"]["decision"] == {"decision": "approved"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_updates_status_to_rejected(
|
||||||
|
self, service: AgentChatService, session: AgentChatSession
|
||||||
|
):
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||||
|
await service.set_pending_tool_call(
|
||||||
|
session_id=session.id,
|
||||||
|
interrupt_id="int-3",
|
||||||
|
tool_name="srv.transfer_funds",
|
||||||
|
tool_args={"to": "u2", "amount": 100},
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.apply_resume_decision(
|
||||||
|
session_id=session.id,
|
||||||
|
interrupt_id="int-3",
|
||||||
|
decision={"decision": "rejected"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.applied is True
|
||||||
|
snapshot = await service.get_state_snapshot(session.id)
|
||||||
|
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
|
||||||
Reference in New Issue
Block a user