2026-03-05 15:34:37 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-03-07 17:30:20 +08:00
|
|
|
from datetime import date, datetime, time, timedelta, timezone
|
|
|
|
|
import json
|
2026-03-08 17:07:09 +08:00
|
|
|
from typing import Protocol
|
2026-03-05 15:34:37 +08:00
|
|
|
from uuid import UUID
|
|
|
|
|
|
|
|
|
|
from fastapi import HTTPException
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
2026-03-12 09:29:57 +08:00
|
|
|
from core.config.settings import config
|
2026-03-07 17:30:20 +08:00
|
|
|
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
2026-03-05 15:34:37 +08:00
|
|
|
from models.agent_chat_session import AgentChatSession
|
|
|
|
|
|
|
|
|
|
|
2026-03-08 17:07:09 +08:00
|
|
|
class ToolResultPayloadStorage(Protocol):
    """Structural interface for a backend that can load tool-result payloads."""

    async def read_json(
        self, *, bucket: str, path: str
    ) -> dict[str, object] | None:
        """Load the JSON object stored at ``path`` inside ``bucket``.

        Args:
            bucket: Name of the storage bucket to read from.
            path: Object path inside the bucket.

        Returns:
            The decoded JSON object, or ``None``; when ``None`` is returned is
            left to the implementation.
        """
|
|
|
|
|
|
|
|
|
|
|
2026-03-05 15:34:37 +08:00
|
|
|
class AgentRepository:
    """Data-access helper for agent chat sessions and their messages.

    All database work goes through the ``AsyncSession`` supplied at
    construction time; an optional storage backend is used to load
    tool-result payloads that are kept outside the message row.
    """

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        """Bind the repository to a database session.

        Args:
            session: Async SQLAlchemy session used for every query and write.
            tool_result_storage: Optional reader for externally stored
                tool-result payloads; ``None`` disables that lookup.
        """
        self._tool_result_storage = tool_result_storage
        self._session = session
|
2026-03-05 15:34:37 +08:00
|
|
|
|
|
|
|
|
async def get_session_owner(self, *, session_id: str) -> str:
    """Return the id of the user who owns the given chat session.

    Args:
        session_id: Session identifier as a UUID string.

    Returns:
        The owning user's id converted to ``str``.

    Raises:
        HTTPException: 422 when ``session_id`` is not a valid UUID string,
            404 when no session row has that id.
    """
    try:
        parsed_id = UUID(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid session_id") from exc

    owner_query = select(AgentChatSession.user_id).where(
        AgentChatSession.id == parsed_id
    )
    result = await self._session.execute(owner_query)
    owner = result.scalar_one_or_none()
    if owner is not None:
        return str(owner)
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
2026-03-07 17:30:20 +08:00
|
|
|
async def create_session_for_user(
    self, *, user_id: str, session_id: str | None = None
) -> str:
    """Create a chat session row for a user and return its id.

    Args:
        user_id: Owner of the new session, as a UUID string.
        session_id: Optional UUID string to use as the session id. When
            omitted, ``None`` is passed as the id to the model.

    Returns:
        The id of the created session as a string, read back after the row
        has been flushed and refreshed.

    Raises:
        HTTPException: 422 when ``user_id`` or ``session_id`` is not a valid
            UUID string.
    """
    try:
        owner_uuid = UUID(user_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid user_id") from exc

    requested_uuid = None
    if session_id is not None:
        try:
            requested_uuid = UUID(session_id)
        except ValueError as exc:
            raise HTTPException(
                status_code=422, detail="Invalid session_id"
            ) from exc

    new_session = AgentChatSession(id=requested_uuid, user_id=owner_uuid)
    db = self._session
    db.add(new_session)
    # Flush, then refresh, before reading the id back from the row.
    await db.flush()
    await db.refresh(new_session)
    return str(new_session.id)
|
|
|
|
|
|
|
|
|
|
async def commit(self) -> None:
    """Commit the transaction on the underlying database session."""
    db = self._session
    await db.commit()
|
|
|
|
|
|
|
|
|
|
async def rollback(self) -> None:
    """Roll back the transaction on the underlying database session."""
    db = self._session
    await db.rollback()
|
|
|
|
|
|
|
|
|
|
async def delete_session(self, *, session_id: str) -> None:
    """Delete the chat session with the given id, if it exists.

    A missing session is not an error; the call simply does nothing.

    Args:
        session_id: Session identifier as a UUID string.

    Raises:
        HTTPException: 422 when ``session_id`` is not a valid UUID string.
    """
    try:
        target_id = UUID(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid session_id") from exc

    db = self._session
    existing = await db.get(AgentChatSession, target_id)
    if existing is None:
        return
    await db.delete(existing)
    await db.flush()
|
2026-03-07 17:30:20 +08:00
|
|
|
|
2026-03-11 21:06:02 +08:00
|
|
|
async def persist_user_message(
    self,
    *,
    session_id: str,
    run_id: str,
    content_text: str,
    metadata: dict[str, object] | None,
) -> None:
    """Append a user message to a chat session and update the session row.

    The session row is selected with ``with_for_update()`` before its
    message counter is read and incremented (presumably so concurrent
    writers cannot hand out the same sequence number - confirm against the
    database in use).

    Args:
        session_id: Target session as a UUID string.
        run_id: Run identifier; stored in the message metadata under
            ``"run_id"``, replacing any value the caller supplied there.
        content_text: Message text. Also used to derive a session title
            when the session does not have one yet.
        metadata: Extra metadata to store with the message; copied, never
            mutated.

    Raises:
        HTTPException: 422 when ``session_id`` is not a valid UUID string,
            404 when the session does not exist.
    """
    try:
        chat_session_id = UUID(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid session_id") from exc

    locked_query = (
        select(AgentChatSession)
        .where(AgentChatSession.id == chat_session_id)
        .with_for_update()
    )
    query_result = await self._session.execute(locked_query)
    chat_session = query_result.scalar_one_or_none()
    if chat_session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # A missing/zero counter means this is the first message (seq 1).
    seq = int(chat_session.message_count or 0) + 1

    # Only fill in a title when the session has none and the text yields one.
    if not _has_title(chat_session.title):
        derived_title = _derive_session_title(content_text)
        if derived_title is not None:
            chat_session.title = derived_title

    message_metadata = {**(metadata or {}), "run_id": run_id}
    self._session.add(
        AgentChatMessage(
            session_id=chat_session_id,
            seq=seq,
            role=AgentChatMessageRole.USER,
            content=content_text,
            metadata_json=message_metadata,
        )
    )
    chat_session.message_count = seq
    chat_session.last_activity_at = datetime.now(timezone.utc)
    await self._session.flush()
|
|
|
|
|
|
2026-03-07 17:30:20 +08:00
|
|
|
async def get_history_day(
    self, *, session_id: str, before: date | None
) -> dict[str, object] | None:
    """Return one UTC calendar day of a session's message history.

    Days are derived from each non-deleted message's ``created_at``
    converted to UTC. With ``before=None`` the most recent day is returned;
    otherwise the most recent day strictly earlier than ``before``.

    Args:
        session_id: Session identifier as a UUID string.
        before: Exclusive upper bound for the day to return, or ``None``
            for the latest day.

    Returns:
        ``None`` when the session has no matching day; otherwise a dict with
        ``"day"`` (ISO date string), ``"hasMore"`` (whether an earlier day
        with messages exists) and ``"messages"`` (snapshot dicts ordered by
        ``seq`` ascending).

    Raises:
        HTTPException: 422 when ``session_id`` is not a valid UUID string.
    """
    try:
        session_uuid = UUID(session_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid session_id") from exc

    timestamp_stmt = (
        select(AgentChatMessage.created_at)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .order_by(AgentChatMessage.created_at.desc())
    )
    rows = (await self._session.execute(timestamp_stmt)).scalars().all()

    # De-duplicate while keeping the newest-first order of the query.
    # dict.fromkeys gives O(1) membership checks instead of rescanning a
    # list for every row.
    unique_days: list[date] = list(
        dict.fromkeys(
            created_at.astimezone(timezone.utc).date()
            for created_at in rows
            if created_at is not None
        )
    )
    if not unique_days:
        return None

    if before is None:
        target_day = unique_days[0]
    else:
        # unique_days is newest-first, so the first match is the closest
        # day before the cursor.
        target_day = next((day for day in unique_days if day < before), None)
        if target_day is None:
            return None

    # Half-open UTC window [start, end) covering exactly the target day.
    start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
    end = start + timedelta(days=1)
    message_stmt = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .where(AgentChatMessage.created_at >= start)
        .where(AgentChatMessage.created_at < end)
        .order_by(AgentChatMessage.seq.asc())
    )
    messages = (await self._session.execute(message_stmt)).scalars().all()
    has_more = any(day < target_day for day in unique_days)

    # Conversions are awaited one at a time, in seq order.
    snapshot_messages: list[dict[str, object]] = [
        await self._to_snapshot_message(message) for message in messages
    ]
    return {
        "day": target_day.isoformat(),
        "hasMore": has_more,
        "messages": snapshot_messages,
    }
|
|
|
|
|
|
|
|
|
|
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
    """Return the id of the user's most recently active, non-deleted session.

    Args:
        user_id: User identifier as a UUID string.

    Returns:
        The session id as a string, or ``None`` when the user has no
        non-deleted session.

    Raises:
        HTTPException: 422 when ``user_id`` is not a valid UUID string.
    """
    try:
        owner_uuid = UUID(user_id)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail="Invalid user_id") from exc

    newest_first = (
        select(AgentChatSession.id)
        .where(AgentChatSession.user_id == owner_uuid)
        .where(AgentChatSession.deleted_at.is_(None))
        .order_by(AgentChatSession.last_activity_at.desc())
        .limit(1)
    )
    result = await self._session.execute(newest_first)
    newest_id = result.scalar_one_or_none()
    return None if newest_id is None else str(newest_id)
|
|
|
|
|
|
2026-03-12 09:29:57 +08:00
|
|
|
async def get_message_attachment_reference(
    self,
    *,
    session_id: str,
    message_id: str,
    attachment_index: int,
) -> dict[str, str] | None:
    """Look up the storage reference of one attachment on a message.

    The message must belong to the given session and not be deleted. The
    attachment is read from the ``"attachments"`` list in the message
    metadata at position ``attachment_index``.

    Args:
        session_id: Session identifier as a UUID string.
        message_id: Message identifier as a UUID string.
        attachment_index: Zero-based position in the attachments list.

    Returns:
        A dict with ``"bucket"``, ``"path"`` and ``"mimeType"`` (all
        non-empty strings), or ``None`` when the message, the list, the
        index or any of those three fields is missing or malformed.

    Raises:
        HTTPException: 422 when either id is not a valid UUID string.
    """
    try:
        session_uuid = UUID(session_id)
        message_uuid = UUID(message_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=422, detail="Invalid message/session id"
        ) from exc

    lookup = (
        select(AgentChatMessage)
        .where(AgentChatMessage.id == message_uuid)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
    )
    found = (await self._session.execute(lookup)).scalar_one_or_none()
    if found is None:
        return None

    raw_metadata = found.metadata_json
    entries = (
        raw_metadata.get("attachments") if isinstance(raw_metadata, dict) else None
    )
    if not isinstance(entries, list):
        return None
    if not 0 <= attachment_index < len(entries):
        return None

    entry = entries[attachment_index]
    if not isinstance(entry, dict):
        return None

    # Every field must be present as a non-empty string.
    reference = {key: entry.get(key) for key in ("bucket", "path", "mimeType")}
    for value in reference.values():
        if not isinstance(value, str) or not value:
            return None
    return reference
|
|
|
|
|
|
2026-03-08 17:07:09 +08:00
|
|
|
async def _to_snapshot_message(
    self, message: AgentChatMessage
) -> dict[str, object]:
    """Convert a stored chat message into its history-snapshot dict.

    Every result carries ``"id"``, ``"role"``, ``"timestamp"`` (UTC ISO
    string) and ``"content"``.

    Non-tool messages additionally get ``"attachments"`` when the metadata
    lists at least one attachment with a non-empty ``mimeType``.

    Tool messages additionally get ``"toolCallId"`` (when present in the
    metadata) and ``"ui"`` (from the ``ui`` or ``ui_schema`` key of the
    resolved payload). The resolved payload is the externally stored
    tool result when it can be loaded, otherwise the message content
    parsed as a JSON object. Its display text replaces ``"content"`` only
    when the stored content is empty or an offloaded placeholder.

    Args:
        message: The message row to convert.

    Returns:
        The snapshot dict described above.
    """
    role = (
        message.role.value
        if isinstance(message.role, AgentChatMessageRole)
        else str(message.role)
    )
    payload: dict[str, object] = {
        "id": str(message.id),
        "role": role,
        "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
    }

    # Metadata is stored JSON, so it may be absent or not an object at all;
    # anything that is not a dict is treated as "no metadata".
    raw_metadata = message.metadata_json
    metadata: dict[str, object] = (
        raw_metadata if isinstance(raw_metadata, dict) else {}
    )

    if role != AgentChatMessageRole.TOOL.value:
        payload["content"] = message.content
        attachments = metadata.get("attachments")
        if isinstance(attachments, list):
            rendered: list[dict[str, object]] = []
            # The index refers to the position in the stored list, so
            # skipped entries must not shift the indices of later ones.
            for index, item in enumerate(attachments):
                if not isinstance(item, dict):
                    continue
                mime_type = item.get("mimeType")
                if not isinstance(mime_type, str) or not mime_type:
                    continue
                rendered.append(
                    {
                        "mimeType": mime_type,
                        "previewPath": (
                            f"/api/v1/agent/runs/{message.session_id}/attachments/"
                            f"{message.id}/{index}"
                        ),
                    }
                )
            if rendered:
                payload["attachments"] = rendered
        return payload

    tool_call_id = metadata.get("tool_call_id")
    has_tool_call_id = isinstance(tool_call_id, str) and bool(tool_call_id)
    if has_tool_call_id:
        payload["toolCallId"] = tool_call_id

    # Inline result: the message content itself, when it is a JSON object.
    parsed_content: dict[str, object] | None = None
    try:
        decoded = json.loads(message.content)
    except (TypeError, ValueError):
        decoded = None
    if isinstance(decoded, dict):
        parsed_content = decoded

    # External result: only read objects from the configured bucket whose
    # path passes the safety check and is tied to this message.
    hydrated_content: dict[str, object] | None = None
    storage = self._tool_result_storage
    storage_bucket = metadata.get("storage_bucket")
    storage_path = metadata.get("storage_path")
    if (
        storage is not None
        and isinstance(storage_bucket, str)
        and isinstance(storage_path, str)
        and storage_bucket == config.storage.bucket
        and _is_safe_storage_path(storage_path)
    ):
        message_session_id = getattr(message, "session_id", None)
        in_session_prefix = (
            message_session_id is not None
            and storage_path.startswith(f"tool-results/{message_session_id}/")
        )
        # Legacy layout: any tool-results path named after the tool call.
        # An empty id is rejected so "/.json" cannot match.
        is_legacy_path = (
            has_tool_call_id
            and storage_path.startswith("tool-results/")
            and storage_path.endswith(f"/{tool_call_id}.json")
        )
        if in_session_prefix or is_legacy_path:
            try:
                hydrated_content = await storage.read_json(
                    bucket=storage_bucket,
                    path=storage_path,
                )
            except Exception:
                # Hydration is best-effort: a storage failure falls back to
                # the inline content instead of failing the snapshot.
                hydrated_content = None
            if not isinstance(hydrated_content, dict):
                hydrated_content = None

    resolved_content = hydrated_content or parsed_content
    payload["content"] = message.content
    if resolved_content is None:
        return payload

    ui = resolved_content.get("ui")
    if not isinstance(ui, dict):
        ui = resolved_content.get("ui_schema")
    if isinstance(ui, dict):
        payload["ui"] = ui

    display_content = resolved_content.get("content")
    if not isinstance(display_content, str):
        nested_result = resolved_content.get("result")
        if isinstance(nested_result, dict):
            nested_content = nested_result.get("content")
            if isinstance(nested_content, str):
                display_content = nested_content

    # Never overwrite real stored content; only fill in for an empty value
    # or an offloaded placeholder.
    stored_content = message.content
    if (
        isinstance(display_content, str)
        and display_content.strip()
        and (
            not stored_content
            or _looks_like_offloaded_placeholder(str(stored_content))
        )
    ):
        payload["content"] = display_content
    return payload
|
2026-03-12 00:18:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _has_title(title: object) -> bool:
|
|
|
|
|
return isinstance(title, str) and bool(title.strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _derive_session_title(content_text: str) -> str | None:
|
|
|
|
|
normalized = " ".join(content_text.split())
|
|
|
|
|
if not normalized:
|
|
|
|
|
return None
|
|
|
|
|
return normalized[:80]
|
2026-03-12 09:29:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_safe_storage_path(path: str) -> bool:
|
|
|
|
|
normalized = path.strip()
|
|
|
|
|
if not normalized:
|
|
|
|
|
return False
|
|
|
|
|
if normalized.startswith("/"):
|
|
|
|
|
return False
|
|
|
|
|
if ".." in normalized:
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _looks_like_offloaded_placeholder(content: str) -> bool:
|
|
|
|
|
normalized = content.strip().lower()
|
|
|
|
|
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
|