366 lines
14 KiB
Python
366 lines
14 KiB
Python
from __future__ import annotations

import json
import logging
from datetime import date, datetime, time, timedelta, timezone
from typing import Protocol
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from core.config.settings import config
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
from services.base.supabase import supabase_service
|
|
|
|
|
|
class ToolResultPayloadStorage(Protocol):
|
|
async def read_json(
|
|
self, *, bucket: str, path: str
|
|
) -> dict[str, object] | None: ...
|
|
|
|
|
|
logger = logging.getLogger(__name__)


def _parse_uuid(value: str, field_name: str) -> UUID:
    """Parse ``value`` as a UUID for the request field ``field_name``.

    Raises:
        HTTPException: 422 with detail ``"Invalid <field_name>"`` when
            ``value`` is not a well-formed UUID string.
    """
    try:
        return UUID(value)
    except ValueError as exc:
        raise HTTPException(
            status_code=422, detail=f"Invalid {field_name}"
        ) from exc


class AgentRepository:
    """Data access for agent chat sessions and their messages.

    Methods that modify data only ``flush``; nothing is committed until the
    caller invokes :meth:`commit` (or discards the work with :meth:`rollback`).
    Malformed ids surface as ``HTTPException(422)``; methods that must find a
    session raise ``HTTPException(404)`` when it does not exist.
    """

    # Lifetime of signed URLs handed out for user message attachments.
    _ATTACHMENT_URL_TTL_SECONDS = 3600

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        """Bind the repository to an async database session.

        Args:
            session: SQLAlchemy async session used for all queries and writes.
            tool_result_storage: Optional backend used to hydrate tool results
                stored outside the message row; without it, tool messages are
                rendered from their stored ``content`` only.
        """
        self._session = session
        self._tool_result_storage = tool_result_storage

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the user id that owns ``session_id``, as a string.

        Raises:
            HTTPException: 422 for a malformed ``session_id``, 404 when no
                session row exists.
        """
        session_uuid = _parse_uuid(session_id, "session_id")
        stmt = select(AgentChatSession.user_id).where(
            AgentChatSession.id == session_uuid
        )
        owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if owner_id is None:
            raise HTTPException(status_code=404, detail="Session not found")
        return str(owner_id)

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Insert a new chat session for ``user_id`` and return its id.

        Args:
            user_id: Owner's UUID string.
            session_id: Optional explicit UUID for the session. When omitted,
                ``id`` is passed as ``None`` and left to the model/database
                default (assumed to generate one — confirm in the model).

        Raises:
            HTTPException: 422 if either id is malformed.
        """
        user_uuid = _parse_uuid(user_id, "user_id")
        session_uuid = (
            _parse_uuid(session_id, "session_id") if session_id is not None else None
        )

        session = AgentChatSession(id=session_uuid, user_id=user_uuid)
        self._session.add(session)
        await self._session.flush()
        # Reload so database-generated values (e.g. the id) are populated.
        await self._session.refresh(session)
        return str(session.id)

    async def commit(self) -> None:
        """Commit the underlying session's transaction."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying session's transaction."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Hard-delete the session row; a missing session is a no-op.

        Raises:
            HTTPException: 422 for a malformed ``session_id``.
        """
        session_uuid = _parse_uuid(session_id, "session_id")
        session = await self._session.get(AgentChatSession, session_uuid)
        if session is not None:
            await self._session.delete(session)
            await self._session.flush()

    async def persist_user_message(
        self,
        *,
        session_id: str,
        run_id: str,
        content_text: str,
        metadata: dict[str, object] | None,
    ) -> None:
        """Append a USER message to the session at the next sequence number.

        Also titles an untitled session from ``content_text`` and bumps the
        session's ``message_count`` and ``last_activity_at``. ``run_id`` is
        stored in the message metadata, overriding any ``run_id`` key the
        caller supplied.

        Raises:
            HTTPException: 422 for a malformed ``session_id``, 404 when the
                session does not exist.
        """
        session_uuid = _parse_uuid(session_id, "session_id")

        # SELECT ... FOR UPDATE locks the session row so concurrent calls for
        # the same session compute distinct sequence numbers.
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            .with_for_update()
        )
        session_row = (await self._session.execute(stmt)).scalar_one_or_none()
        if session_row is None:
            raise HTTPException(status_code=404, detail="Session not found")

        next_seq = int(session_row.message_count or 0) + 1

        if not _has_title(session_row.title):
            session_title = _derive_session_title(content_text)
            if session_title is not None:
                session_row.title = session_title

        # Copy so the caller's dict is never mutated.
        payload_metadata = dict(metadata or {})
        payload_metadata["run_id"] = run_id

        self._session.add(
            AgentChatMessage(
                session_id=session_uuid,
                seq=next_seq,
                role=AgentChatMessageRole.USER,
                content=content_text,
                metadata_json=payload_metadata,
            )
        )
        session_row.message_count = next_seq
        session_row.last_activity_at = datetime.now(timezone.utc)
        await self._session.flush()

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Return one UTC calendar day of non-deleted messages for a session.

        Args:
            session_id: Session UUID string.
            before: When ``None``, the most recent day with messages is used;
                otherwise the most recent day strictly earlier than ``before``.

        Returns:
            ``{"day": <ISO date>, "hasMore": bool, "messages": [...]}`` where
            ``hasMore`` says whether even older days exist, or ``None`` when
            there is no matching day.

        Raises:
            HTTPException: 422 for a malformed ``session_id``.
        """
        session_uuid = _parse_uuid(session_id, "session_id")

        timestamp_stmt = (
            select(AgentChatMessage.created_at)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .order_by(AgentChatMessage.created_at.desc())
        )
        rows = (await self._session.execute(timestamp_stmt)).scalars().all()

        # Distinct UTC days in first-seen order, i.e. newest first given the
        # descending sort. The set keeps membership checks O(1) per row.
        unique_days: list[date] = []
        seen_days: set[date] = set()
        for created_at in rows:
            if created_at is None:
                continue
            day = created_at.astimezone(timezone.utc).date()
            if day not in seen_days:
                seen_days.add(day)
                unique_days.append(day)

        if not unique_days:
            return None

        if before is None:
            target_day = unique_days[0]
        else:
            target_day = next((day for day in unique_days if day < before), None)
            if target_day is None:
                return None

        start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
        end = start + timedelta(days=1)
        message_stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.created_at >= start)
            .where(AgentChatMessage.created_at < end)
            .order_by(AgentChatMessage.seq.asc())
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()

        snapshot_messages = [
            await self._to_snapshot_message(message) for message in messages
        ]
        return {
            "day": target_day.isoformat(),
            "hasMore": any(day < target_day for day in unique_days),
            "messages": snapshot_messages,
        }

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the id of the user's most recently active live session.

        Sessions whose ``last_activity_at`` is NULL sort after all dated ones.

        Raises:
            HTTPException: 422 for a malformed ``user_id``.
        """
        user_uuid = _parse_uuid(user_id, "user_id")
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            # PostgreSQL puts NULLs first in DESC order by default, which would
            # let a session with no recorded activity win over active ones.
            .order_by(AgentChatSession.last_activity_at.desc().nulls_last())
            .limit(1)
        )
        latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return str(latest_id) if latest_id is not None else None

    async def _to_snapshot_message(
        self, message: AgentChatMessage
    ) -> dict[str, object]:
        """Convert a message row into the dict returned in history snapshots.

        Always contains ``id``, ``role``, ``timestamp`` (ISO-8601, UTC) and
        ``content``. Tool messages may also carry ``toolCallId`` and ``ui``;
        user messages with a usable attachment get ``content`` as a list of
        content blocks.
        """
        role = (
            message.role.value
            if isinstance(message.role, AgentChatMessageRole)
            else str(message.role)
        )
        payload: dict[str, object] = {
            "id": str(message.id),
            "role": role,
            "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
        }

        if role == AgentChatMessageRole.TOOL.value:
            await self._fill_tool_payload(payload, message)
        else:
            payload["content"] = message.content

        if role == AgentChatMessageRole.USER.value:
            blocks = await self._user_content_blocks(message)
            if blocks is not None:
                payload["content"] = blocks

        return payload

    @staticmethod
    def _metadata_of(message: AgentChatMessage) -> dict[str, object]:
        """Return the message's metadata dict, or ``{}`` if absent or not a dict."""
        metadata = message.metadata_json
        return metadata if isinstance(metadata, dict) else {}

    @staticmethod
    def _parse_json_object(text: object) -> dict[str, object] | None:
        """Decode ``text`` as JSON, returning it only if it is a JSON object."""
        try:
            decoded = json.loads(text)
        except (TypeError, ValueError):
            return None
        return decoded if isinstance(decoded, dict) else None

    async def _fill_tool_payload(
        self, payload: dict[str, object], message: AgentChatMessage
    ) -> None:
        """Add ``toolCallId``, ``content`` and ``ui`` entries for a tool message.

        The structured result comes from storage when it can be hydrated,
        otherwise from ``message.content`` parsed as a JSON object. Its display
        text replaces the stored content only when that content is empty or
        is the offloaded placeholder.
        """
        metadata = self._metadata_of(message)
        tool_call_id = metadata.get("tool_call_id")
        if isinstance(tool_call_id, str) and tool_call_id:
            payload["toolCallId"] = tool_call_id

        hydrated = await self._hydrate_tool_result(message, metadata, tool_call_id)
        resolved = hydrated or self._parse_json_object(message.content)

        payload["content"] = message.content
        if resolved is None:
            return

        # Newer payloads use "ui"; "ui_schema" is accepted as a fallback.
        ui = resolved.get("ui")
        if not isinstance(ui, dict):
            ui = resolved.get("ui_schema")
        if isinstance(ui, dict):
            payload["ui"] = ui

        display = resolved.get("content")
        if not isinstance(display, str):
            nested = resolved.get("result")
            nested_content = nested.get("content") if isinstance(nested, dict) else None
            if isinstance(nested_content, str):
                display = nested_content

        stored = message.content
        if (
            isinstance(display, str)
            and display.strip()
            and (not stored or _looks_like_offloaded_placeholder(str(stored)))
        ):
            payload["content"] = display

    async def _hydrate_tool_result(
        self,
        message: AgentChatMessage,
        metadata: dict[str, object],
        tool_call_id: object,
    ) -> dict[str, object] | None:
        """Read the offloaded tool result referenced by ``metadata``.

        Returns ``None`` when no storage backend is configured, the storage
        location is missing or not permitted, the read fails, or the payload
        is not a non-empty dict.
        """
        if self._tool_result_storage is None:
            return None
        bucket = metadata.get("storage_bucket")
        path = metadata.get("storage_path")
        if not (isinstance(bucket, str) and isinstance(path, str)):
            return None
        if not self._is_permitted_tool_result_location(
            message, bucket, path, tool_call_id
        ):
            return None
        try:
            result = await self._tool_result_storage.read_json(
                bucket=bucket, path=path
            )
        except Exception:  # noqa: BLE001 - hydration is best-effort
            logger.warning(
                "Failed to hydrate tool result for message %s",
                message.id,
                exc_info=True,
            )
            return None
        return result if isinstance(result, dict) and result else None

    @staticmethod
    def _is_permitted_tool_result_location(
        message: AgentChatMessage,
        bucket: str,
        path: str,
        tool_call_id: object,
    ) -> bool:
        """Decide whether an offloaded tool result may be read for ``message``.

        The bucket must equal the configured storage bucket, and the path must
        pass ``_is_safe_storage_path``. The path must also either live under
        ``tool-results/<session_id>/`` or be a legacy
        ``tool-results/.../<tool_call_id>.json`` path.
        """
        if bucket != config.storage.bucket or not _is_safe_storage_path(path):
            return False
        session_id = getattr(message, "session_id", None)
        if session_id is not None and path.startswith(f"tool-results/{session_id}/"):
            return True
        return (
            isinstance(tool_call_id, str)
            and bool(tool_call_id)
            and path.startswith("tool-results/")
            and path.endswith(f"/{tool_call_id}.json")
        )

    async def _user_content_blocks(
        self, message: AgentChatMessage
    ) -> list[dict[str, object]] | None:
        """Build text + attachment content blocks for a user message.

        Returns ``None`` (leave ``content`` as-is) when there is no complete
        ``user_message_attachments`` entry, its path is unsafe, or signing the
        URL fails.
        """
        attachment = self._metadata_of(message).get("user_message_attachments")
        if not isinstance(attachment, dict):
            return None
        bucket = attachment.get("bucket")
        path = attachment.get("path")
        mime_type = attachment.get("mime_type")
        if not (
            isinstance(bucket, str)
            and isinstance(path, str)
            and isinstance(mime_type, str)
        ):
            return None
        # Message metadata may originate from the caller of
        # persist_user_message, so apply the same path check used for tool
        # results before signing a URL for it.
        if not _is_safe_storage_path(path):
            return None
        try:
            signed_url = await supabase_service.create_signed_url(
                bucket=bucket,
                path=path,
                expires_in_seconds=self._ATTACHMENT_URL_TTL_SECONDS,
            )
        except Exception:  # noqa: BLE001 - attachments are best-effort
            logger.warning(
                "Failed to sign attachment URL for message %s",
                message.id,
                exc_info=True,
            )
            return None

        attachment_block: dict[str, object] = {
            "type": "binary",
            "mimeType": mime_type,
            "url": signed_url,
        }
        text = message.content
        if isinstance(text, str) and text.strip():
            return [{"type": "text", "text": text}, attachment_block]
        return [attachment_block]
|
|
|
|
|
|
def _has_title(title: object) -> bool:
|
|
return isinstance(title, str) and bool(title.strip())
|
|
|
|
|
|
def _derive_session_title(content_text: str) -> str | None:
|
|
normalized = " ".join(content_text.split())
|
|
if not normalized:
|
|
return None
|
|
return normalized[:80]
|
|
|
|
|
|
def _is_safe_storage_path(path: str) -> bool:
|
|
normalized = path.strip()
|
|
if not normalized:
|
|
return False
|
|
if normalized.startswith("/"):
|
|
return False
|
|
if ".." in normalized:
|
|
return False
|
|
return True
|
|
|
|
|
|
def _looks_like_offloaded_placeholder(content: str) -> bool:
|
|
normalized = content.strip().lower()
|
|
return normalized in {'{"offloaded":true}', '{"offloaded": true}'}
|