feat(agent): persist pending tool call in session snapshot

This commit is contained in:
qzl
2026-03-03 15:39:56 +08:00
parent e03923e593
commit cff1436bc6
3 changed files with 171 additions and 1 deletions
+5 -1
View File
@@ -14,7 +14,7 @@ from sqlalchemy import (
func,
text,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
@@ -75,3 +75,7 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
total_cost: Mapped[Decimal] = mapped_column(
Numeric(12, 6), nullable=False, server_default=text("0")
)
state_snapshot: Mapped[dict | None] = mapped_column(
JSONB,
nullable=True,
)
+51
View File
@@ -284,3 +284,54 @@ class AgentChatService(BaseService):
"content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
}
async def get_state_snapshot(self, session_id: UUID) -> dict | None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
return None
return session.state_snapshot
async def set_pending_tool_call(
self,
*,
session_id: UUID,
interrupt_id: str,
tool_name: str,
tool_args: dict,
expires_at: datetime,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot or {}
snapshot["pending_tool_call"] = {
"interrupt_id": interrupt_id,
"tool_name": tool_name,
"tool_args": tool_args,
"status": "PENDING_APPROVAL",
"expires_at": expires_at.isoformat(),
"decision": None,
"result": None,
}
session.state_snapshot = snapshot
async def update_pending_tool_call_status(
self,
*,
session_id: UUID,
interrupt_id: str,
status: str,
) -> None:
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
session = await self._session.scalar(stmt)
if session is None:
raise ValueError(f"Session {session_id} not found")
snapshot = session.state_snapshot
if snapshot is None or "pending_tool_call" not in snapshot:
raise ValueError("No pending tool call found")
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
raise ValueError("Interrupt ID mismatch")
snapshot["pending_tool_call"]["status"] = status
session.state_snapshot = snapshot