feat(agent): persist pending tool call in session snapshot
This commit is contained in:
@@ -284,3 +284,54 @@ class AgentChatService(BaseService):
|
||||
"content": message,
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
|
||||
}
|
||||
|
||||
async def get_state_snapshot(self, session_id: UUID) -> dict | None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
return None
|
||||
return session.state_snapshot
|
||||
|
||||
async def set_pending_tool_call(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
tool_name: str,
|
||||
tool_args: dict,
|
||||
expires_at: datetime,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot or {}
|
||||
snapshot["pending_tool_call"] = {
|
||||
"interrupt_id": interrupt_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
}
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
async def update_pending_tool_call_status(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
status: str,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot
|
||||
if snapshot is None or "pending_tool_call" not in snapshot:
|
||||
raise ValueError("No pending tool call found")
|
||||
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
|
||||
raise ValueError("Interrupt ID mismatch")
|
||||
snapshot["pending_tool_call"]["status"] = status
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
Reference in New Issue
Block a user