# eryao/backend/src/v1/agent/schemas.py
# (Repository-viewer metadata at capture time: 247 lines, 6.4 KiB, Python.)

from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Any, Literal, Protocol
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from schemas.agent.runtime_models import ErrorInfo, RunStatus, SignLevel
from schemas.domain.chat_message import AgentChatMessage
from schemas.domain.divination import DerivedDivinationData
class AgentRepositoryLike(Protocol):
    """Structural type for the repository the agent layer depends on.

    Any object exposing these coroutine methods with the same keyword-only
    signatures satisfies the protocol; subclassing is not required.  The notes
    below describe only what the signatures show; persistence details belong
    to the concrete implementation.
    """

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner of ``session_id`` as a string.

        Presumably a user id — confirm against the implementation.
        """

    async def create_session_for_user(
        self,
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> str:
        """Create a session for ``user_id``; ``session_id`` defaults to ``None``.

        Returns a string, presumably the id of the session that was used or
        generated — confirm against the implementation.
        """

    async def commit(self) -> None:
        """Commit pending work (transaction semantics are implementation-defined)."""

    async def rollback(self) -> None:
        """Roll back pending work (transaction semantics are implementation-defined)."""

    async def delete_session(self, *, session_id: str) -> list[dict[str, str]]:
        """Delete ``session_id`` and return a list of str-to-str dicts.

        NOTE(review): the dict keys are not visible here — presumably they
        describe the removed resources; document once confirmed.
        """

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
        visibility_mask: int | None = None,
    ) -> dict[str, object] | None:
        """Presumably fetch one day of history for ``session_id``, or ``None``.

        ``before`` must be passed but may be ``None``; ``visibility_mask`` is
        an optional integer filter.  When ``None`` is returned (presumably no
        matching day) is implementation-defined — confirm.
        """

    async def get_session_messages(
        self,
        *,
        session_id: str,
        visibility_mask: int | None = None,
    ) -> list[AgentChatMessage]:
        """Return the messages of ``session_id``, optionally filtered by mask."""

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the latest session id for ``user_id``, or ``None``.

        Presumably ``None`` means the user has no sessions — confirm.
        """

    async def get_latest_assistant_messages_by_user_sessions(
        self,
        *,
        user_id: str,
        visibility_mask: int | None = None,
        session_limit: int = 50,
    ) -> list[AgentChatMessage]:
        """Return assistant messages drawn from ``user_id``'s sessions.

        ``session_limit`` defaults to 50; presumably it caps how many sessions
        are considered — confirm how the limit is applied.
        """

    async def persist_user_message(
        self,
        *,
        session_id: str,
        content: str,
        metadata: Any,
        visibility_mask: int,
    ) -> None:
        """Store a user message with its ``metadata`` and ``visibility_mask``."""

    async def get_assistant_message_count(self, *, session_id: str) -> int:
        """Return the number of assistant messages in ``session_id``."""

    async def get_system_agent_config(
        self, *, agent_type: str
    ) -> dict[str, object] | None:
        """Return the stored config for ``agent_type``, or ``None``.

        Presumably ``None`` means no config exists — confirm.
        """
class QueueClientLike(Protocol):
    """Structural type for the task queue client used to submit and cancel runs."""

    async def enqueue(
        self,
        *,
        command: dict[str, object],
        dedup_key: str | None,
    ) -> str:
        """Enqueue ``command``; ``dedup_key`` must be passed but may be ``None``.

        Returns a string, presumably the id of the queued entry — confirm.
        """

    async def request_cancel(
        self,
        *,
        thread_id: str,
        run_id: str,
        requested_by: str,
    ) -> None:
        """Ask for ``run_id`` on ``thread_id`` to be cancelled, noting ``requested_by``."""
class EventStreamLike(Protocol):
    """Structural type for reading a session's event stream."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]:
        """Return events for ``session_id`` as a list of dicts.

        ``last_event_id`` must be passed but may be ``None``; presumably it is
        a resume cursor and ``None`` reads from the start — confirm.
        """
class PointsServiceLike(Protocol):
    """Structural type for the points service consulted around agent runs."""

    async def ensure_run_points_available(
        self,
        *,
        user_id: UUID,
    ) -> int:
        """Check that ``user_id`` has points for a run and return an int.

        Presumably the int is the available balance and a shortfall raises —
        confirm against the implementation.
        """

    async def consume_successful_run_points(
        self,
        *,
        user_id: UUID,
        session_id: UUID,
        run_id: str,
        operator_id: UUID | None,
    ) -> Any:
        """Consume points for the successful run ``run_id`` in ``session_id``.

        ``operator_id`` must be passed but may be ``None``.  The return type is
        unconstrained (``Any``).
        """
class AttachmentStorageLike(Protocol):
    """Structural type for object storage addressed by ``bucket`` and ``path``.

    All operations are coroutines except ``parse_signed_url``, which is a
    plain synchronous method.
    """

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        """Store ``content`` at ``bucket``/``path`` with ``content_type``.

        Returns a string, presumably a URL or object key — confirm.
        """

    async def download_bytes(self, *, bucket: str, path: str) -> bytes:
        """Return the raw bytes stored at ``bucket``/``path``."""

    async def create_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        expires_in_seconds: int,
    ) -> str:
        """Return a signed URL for ``bucket``/``path`` valid for ``expires_in_seconds``."""

    async def delete_prefix(self, *, bucket: str, prefix: str) -> int:
        """Delete objects under ``prefix`` in ``bucket``.

        Returns an int, presumably the number of objects removed — confirm.
        """

    def parse_signed_url(self, url: str) -> tuple[str, str]:
        """Split ``url`` into a pair of strings, presumably ``(bucket, path)``."""
@dataclass(eq=True, frozen=True)
class TaskAccepted:
    """Immutable record describing an accepted task.

    Carries the task, thread and run identifiers plus the ``created`` flag.
    Being a frozen dataclass, instances are read-only and hashable.
    """

    task_id: str
    thread_id: str
    run_id: str
    # Presumably False when an existing task was reused rather than newly
    # created (e.g. via a dedup key) — confirm against the producer.
    created: bool
@dataclass(eq=True, frozen=True)
class CancelRequested:
    """Immutable record of a cancellation request for ``run_id`` on ``thread_id``.

    ``accepted`` states whether the request was accepted.  Frozen, so
    instances are read-only and hashable.
    """

    thread_id: str
    run_id: str
    accepted: bool
class TaskAcceptedResponse(BaseModel):
    # Pydantic counterpart of the TaskAccepted dataclass (same four fields).
    # populate_by_name: input may use either the snake_case name or the alias.
    # serialize_by_alias: output uses the camelCase aliases.
    model_config = ConfigDict(
        populate_by_name=True,
        serialize_by_alias=True,
    )

    task_id: str = Field(alias="taskId")  # wire key "taskId"
    thread_id: str = Field(alias="threadId")  # wire key "threadId"
    run_id: str = Field(alias="runId")  # wire key "runId"
    created: bool  # no alias; same key on the wire
class CancelRunResponse(BaseModel):
    # Pydantic counterpart of the CancelRequested dataclass (same three fields).
    # Input accepts snake_case names or camelCase aliases; output uses aliases.
    model_config = ConfigDict(
        populate_by_name=True,
        serialize_by_alias=True,
    )

    thread_id: str = Field(alias="threadId")  # wire key "threadId"
    run_id: str = Field(alias="runId")  # wire key "runId"
    accepted: bool  # no alias; same key on the wire
class AsrTranscribeResponse(BaseModel):
    # Speech-to-text response carrying a single required text field.
    # Default model config: no aliases, so the wire key is "transcript".
    transcript: str = Field(
        description="Transcribed text from audio",
    )
class AttachmentReference(BaseModel):
    # Reference to a stored attachment: its storage bucket and path, MIME type
    # and a URL (presumably a signed URL — confirm with the producer).
    # Input accepts snake_case names or aliases; output uses aliases.
    model_config = ConfigDict(
        populate_by_name=True,
        serialize_by_alias=True,
    )

    bucket: str
    path: str
    mime_type: str = Field(alias="mimeType")  # only aliased field
    url: str
class AttachmentUploadResponse(BaseModel):
    # Upload result: wraps the stored attachment's reference under a single
    # "attachment" key (the nested model applies its own alias settings).
    attachment: AttachmentReference
class AttachmentSignedUrlResponse(BaseModel):
    # Location (bucket + path) of a stored object together with a URL for it.
    # Uses the default model config: all keys are plain snake_case.
    bucket: str
    path: str
    url: str
class HistoryMessageAttachment(BaseModel):
    # Attachment entry inside a history message: MIME type plus a URL
    # (HistoryMessage describes these as temporary signed URLs).
    # Input accepts snake_case names or aliases; output uses aliases.
    model_config = ConfigDict(
        populate_by_name=True,
        serialize_by_alias=True,
    )

    mime_type: str = Field(alias="mimeType")  # wire key "mimeType"
    url: str
class HistoryAgentOutput(BaseModel):
    # Structured assistant output attached to a history message for replay.
    # extra="forbid" rejects unknown keys.  No aliases are configured, so
    # these fields are read and written in snake_case.
    #
    # Defined BEFORE HistoryMessage, which names it in a field annotation.
    # With postponed annotations pydantic resolves that name when
    # HistoryMessage is created.  If this class came later, HistoryMessage would be
    # left "not fully defined" and depend on a deferred rebuild at first use.
    model_config = ConfigDict(extra="forbid")

    status: RunStatus | None = None
    sign_level: SignLevel | None = None
    # List sections: default_factory gives each instance its own empty list.
    conclusion: list[str] = Field(default_factory=list)
    focus_points: list[str] = Field(default_factory=list)
    advice: list[str] = Field(default_factory=list)
    keywords: list[str] = Field(default_factory=list)
    answer: str | None = None
    error: ErrorInfo | None = None
    divination_derived: DerivedDivinationData | None = None


class HistoryMessage(BaseModel):
    """History message schema for /history endpoint response."""

    # Input accepts snake_case names or aliases; output uses aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)

    id: str = Field(description="Message UUID")
    thread_id: str = Field(alias="threadId", description="Owning session UUID")
    seq: int = Field(description="Message sequence number")
    role: Literal["user", "assistant"] = Field(
        description="Message role: user | assistant"
    )
    content: str = Field(description="Message text content")
    attachments: list[HistoryMessageAttachment] = Field(
        default_factory=list,
        description="Temporary signed URLs for user-attached images",
    )
    # NOTE(review): unlike thread_id, this field has no camelCase alias, so it
    # serializes as "agent_output".  It is left unchanged because clients may
    # already rely on that key.
    agent_output: HistoryAgentOutput | None = Field(
        default=None,
        description="Structured assistant output for history replay",
    )
    timestamp: str = Field(description="Message creation timestamp in ISO-8601 format")
class HistorySnapshotResponse(BaseModel):
    """Response schema for GET /api/v1/agent/history"""

    # Input accepts snake_case names or aliases; output uses aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)

    # Documented values are "history_session_full" (the default) and
    # "history_sessions_latest_assistant".  The type is a plain str, so other
    # values are not rejected.
    scope: str = Field(
        default="history_session_full",
        description="history_session_full | history_sessions_latest_assistant",
    )
    # Wire key "threadId"; optional, presumably unset for cross-session scopes.
    thread_id: str | None = Field(default=None, alias="threadId")
    # Optional string; presumably the calendar day covered — confirm format.
    day: str | None = None
    # Wire key "hasMore"; defaults to False.
    has_more: bool = Field(default=False, alias="hasMore")
    # Messages in this snapshot; a fresh empty list by default.
    messages: list[HistoryMessage] = Field(default_factory=list)