feat(agent): persist pending tool call in session snapshot
This commit is contained in:
@@ -14,7 +14,7 @@ from sqlalchemy import (
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
@@ -75,3 +75,7 @@ class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
|
||||
total_cost: Mapped[Decimal] = mapped_column(
|
||||
Numeric(12, 6), nullable=False, server_default=text("0")
|
||||
)
|
||||
state_snapshot: Mapped[dict | None] = mapped_column(
|
||||
JSONB,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
@@ -284,3 +284,54 @@ class AgentChatService(BaseService):
|
||||
"content": message,
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
|
||||
}
|
||||
|
||||
async def get_state_snapshot(self, session_id: UUID) -> dict | None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
return None
|
||||
return session.state_snapshot
|
||||
|
||||
async def set_pending_tool_call(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
tool_name: str,
|
||||
tool_args: dict,
|
||||
expires_at: datetime,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot or {}
|
||||
snapshot["pending_tool_call"] = {
|
||||
"interrupt_id": interrupt_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
}
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
async def update_pending_tool_call_status(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
status: str,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
snapshot = session.state_snapshot
|
||||
if snapshot is None or "pending_tool_call" not in snapshot:
|
||||
raise ValueError("No pending tool call found")
|
||||
if snapshot["pending_tool_call"]["interrupt_id"] != interrupt_id:
|
||||
raise ValueError("Interrupt ID mismatch")
|
||||
snapshot["pending_tool_call"]["status"] = status
|
||||
session.state_snapshot = snapshot
|
||||
|
||||
Reference in New Issue
Block a user