# File-viewer metadata: 89 lines, 2.6 KiB, Python.
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Literal
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
|
|
|
|
class RunAgentInput(BaseModel):
    """Input payload for an agent run.

    Field names are camelCase and are accepted verbatim. Any key not
    declared below is rejected, because the model forbids extra fields.
    """

    model_config = ConfigDict(extra="forbid")

    # Identifiers: required, non-empty, at most 255 characters.
    threadId: str = Field(max_length=255, min_length=1)
    runId: str = Field(max_length=255, min_length=1)
    # Optional. Only the length is capped, so an empty string is accepted.
    parentRunId: str | None = Field(None, max_length=255)

    # Free-form JSON-like payloads; each defaults to a fresh empty container.
    state: dict[str, Any] = Field(default_factory=dict)
    messages: list[dict[str, Any]] = Field(default_factory=list)
    tools: list[dict[str, Any]] = Field(default_factory=list)
    context: list[dict[str, Any]] = Field(default_factory=list)
    forwardedProps: dict[str, Any] = Field(default_factory=dict)

    # Optional, defaults to None. Presumably carries data for resuming an
    # interrupted run -- confirm against the run handler.
    resume: dict[str, Any] | None = Field(default=None)
|
|
|
|
|
|
class AgentChatRunRequest(BaseModel):
    """Chat request carrying one message and an optional session id."""

    # Non-empty text of at most 8000 characters.
    message: str = Field(max_length=8000, min_length=1)
    # Optional; None presumably means "no existing session" -- confirm
    # with the endpoint that consumes this model.
    session_id: UUID | None = Field(default=None)
|
|
|
|
|
|
class AgentChatEvent(BaseModel):
    """A single event emitted during an agent chat run.

    Only ``type`` is required. Every other field is optional and defaults
    to None; presumably only the fields relevant to a given ``type`` are
    populated (confirm with the event producer).
    """

    # Event discriminator; a free-form string, not validated here.
    type: str

    # Correlation identifiers.
    run_id: str | None = Field(default=None)
    message_id: str | None = Field(default=None)

    # Incremental text fragment.
    delta: str | None = Field(default=None)

    # Tool-related payload.
    tool_name: str | None = Field(default=None)
    result: str | None = Field(default=None)

    # Final output or failure description.
    output: str | None = Field(default=None)
    error: str | None = Field(default=None)
|
|
|
|
|
|
class AgentChatRunResponse(BaseModel):
    """Result of a chat run: session id, final output and emitted events.

    All three fields are required.
    """

    # Session the run belongs to.
    session_id: UUID
    # Final text output of the run.
    output: str
    # Events produced during the run, in the order given by the caller.
    events: list[AgentChatEvent]
|
|
|
|
|
|
class PendingToolStatus(str, Enum):
    """Lifecycle states of a pending tool call.

    Members subclass ``str``, so each compares equal to its own value
    (e.g. ``PendingToolStatus.EXECUTED == "EXECUTED"``). Every value is
    identical to its member name.
    """

    # Awaiting a decision.
    PENDING_APPROVAL = "PENDING_APPROVAL"
    # Approved, and execution is in progress.
    APPROVED_EXECUTING = "APPROVED_EXECUTING"
    # Terminal states.
    EXECUTED = "EXECUTED"
    REJECTED = "REJECTED"
    EXPIRED = "EXPIRED"
|
|
|
|
|
|
class PendingToolCall(BaseModel):
    """A tool invocation recorded together with its approval status.

    Unknown keys are rejected. Both ``expires_at`` and ``updated_at`` must
    be timezone-aware; naive datetimes fail validation.
    """

    model_config = ConfigDict(extra="forbid")

    # Interrupt and tool identifiers: required, 1-255 characters each.
    interrupt_id: str = Field(max_length=255, min_length=1)
    tool_name: str = Field(max_length=255, min_length=1)
    # Tool arguments; a fresh empty dict when omitted.
    tool_args: dict[str, Any] = Field(default_factory=dict)

    status: PendingToolStatus
    expires_at: datetime
    # Optional free-form payloads, None when absent.
    decision: dict[str, Any] | None = Field(default=None)
    result: dict[str, Any] | None = Field(default=None)
    updated_at: datetime

    @field_validator("expires_at", "updated_at")
    @classmethod
    def _validate_timezone_aware(cls, value: datetime) -> datetime:
        """Return ``value`` unchanged if it is timezone-aware.

        A datetime counts as aware only when it has a ``tzinfo`` *and* that
        tzinfo reports a non-None UTC offset for this value.

        Raises:
            ValueError: if ``value`` is naive.
        """
        # Short-circuit: utcoffset() is only consulted once tzinfo exists.
        has_offset = value.tzinfo is not None and value.utcoffset() is not None
        if not has_offset:
            raise ValueError("datetime must be timezone-aware")
        return value
|
|
|
|
|
|
class SnapshotRunContext(BaseModel):
    """Thread and run identifiers stored alongside a session snapshot.

    Unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    # Both identifiers are required, non-empty, at most 255 characters.
    thread_id: str = Field(max_length=255, min_length=1)
    run_id: str = Field(max_length=255, min_length=1)
|
|
|
|
|
|
class AgentSessionSnapshot(BaseModel):
    """Snapshot of an agent session, schema version 2.

    Unknown keys are rejected, and ``version`` must be exactly ``2``.
    """

    model_config = ConfigDict(extra="forbid")

    # Schema version marker; any other value fails validation.
    version: Literal[2]
    # No default: the key must be supplied, but it may be explicitly null.
    pending_tool_call: PendingToolCall | None
    # Identifiers of the run this snapshot belongs to.
    run_context: SnapshotRunContext
|