2026-03-05 15:34:37 +08:00
|
|
|
from __future__ import annotations

from collections.abc import Iterable
from datetime import date, datetime, time, timedelta, timezone
import json
import logging
from typing import Protocol
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
|
|
|
|
|
|
|
|
|
|
|
2026-03-08 17:07:09 +08:00
|
|
|
class ToolResultPayloadStorage(Protocol):
    """Structural interface for a store that can return JSON payloads.

    Any object exposing a matching ``read_json`` coroutine satisfies this
    protocol; no inheritance is required.
    """

    async def read_json(
        self,
        *,
        bucket: str,
        path: str,
    ) -> dict[str, object] | None:
        """Return the JSON object stored at ``bucket``/``path``, or ``None``."""
        ...
|
|
|
|
|
|
|
|
|
|
|
2026-03-05 15:34:37 +08:00
|
|
|
class AgentRepository:
    """Data-access layer for agent chat sessions and their messages.

    All database work goes through the ``AsyncSession`` supplied at
    construction time. Methods flush but never commit on their own; callers
    decide transaction boundaries through :meth:`commit` and
    :meth:`rollback`.
    """

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        """Bind the repository to a database session.

        Args:
            session: Async SQLAlchemy session used for every query.
            tool_result_storage: Optional store consulted when building
                snapshots of tool messages whose metadata carries
                ``storage_bucket`` and ``storage_path``.
        """
        self._session = session
        self._tool_result_storage = tool_result_storage

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_uuid(value: str, *, field: str) -> UUID:
        """Convert ``value`` to a ``UUID``.

        Args:
            value: Textual UUID supplied by the caller.
            field: Name of the parameter, used in the error detail.

        Raises:
            HTTPException: 422 with detail ``"Invalid <field>"`` when
                ``value`` is not a valid UUID string.
        """
        try:
            return UUID(value)
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=f"Invalid {field}") from exc

    @staticmethod
    def _distinct_utc_days(timestamps: Iterable[datetime | None]) -> list[date]:
        """Return the distinct UTC calendar days of ``timestamps``.

        Days appear in first-seen order, so input sorted newest-first yields
        days newest-first. ``None`` entries are skipped. A set tracks the
        days already emitted so membership checks stay O(1).
        """
        seen: set[date] = set()
        days: list[date] = []
        for created_at in timestamps:
            if created_at is None:
                continue
            day = created_at.astimezone(timezone.utc).date()
            if day not in seen:
                seen.add(day)
                days.append(day)
        return days

    @staticmethod
    def _decode_json_object(raw: object) -> dict[str, object] | None:
        """Parse ``raw`` as JSON and return it only if it is a JSON object.

        Returns ``None`` for non-text input, malformed JSON, or JSON whose
        top-level value is not an object.
        """
        if not isinstance(raw, (str, bytes, bytearray)):
            return None
        try:
            decoded = json.loads(raw)
        except (TypeError, ValueError):
            return None
        return decoded if isinstance(decoded, dict) else None

    @staticmethod
    def _extract_tool_display(resolved: dict[str, object]) -> dict[str, object]:
        """Pick the snapshot fields (``content``, ``ui``) out of a tool result.

        A string at ``resolved["content"]`` takes precedence over a string at
        ``resolved["result"]["content"]``. ``ui`` is included only when it is
        a dict.
        """
        display: dict[str, object] = {}
        result = resolved.get("result")
        if isinstance(result, dict):
            result_content = result.get("content")
            if isinstance(result_content, str):
                display["content"] = result_content
        ui = resolved.get("ui")
        if isinstance(ui, dict):
            display["ui"] = ui
        top_level_content = resolved.get("content")
        if isinstance(top_level_content, str):
            display["content"] = top_level_content
        return display

    async def _read_stored_tool_result(
        self, metadata: dict[str, object]
    ) -> dict[str, object] | None:
        """Load the externally stored tool result referenced by ``metadata``.

        Returns ``None`` when no storage is configured, when the metadata has
        no string ``storage_bucket``/``storage_path`` pair, when the read
        fails, or when the stored value is not a dict.
        """
        storage = self._tool_result_storage
        if storage is None:
            return None
        bucket = metadata.get("storage_bucket")
        path = metadata.get("storage_path")
        if not (isinstance(bucket, str) and isinstance(path, str)):
            return None
        try:
            stored = await storage.read_json(bucket=bucket, path=path)
        except Exception:
            # Hydration is best-effort: the caller falls back to the inline
            # message content. Log so storage failures stay visible.
            logging.getLogger(__name__).warning(
                "Failed to read tool result payload (bucket=%s, path=%s)",
                bucket,
                path,
                exc_info=True,
            )
            return None
        return stored if isinstance(stored, dict) else None

    # ------------------------------------------------------------------
    # Sessions
    # ------------------------------------------------------------------

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the ``user_id`` of the session as a string.

        Raises:
            HTTPException: 422 if ``session_id`` is not a UUID, 404 if no
                session row has that id.
        """
        session_uuid = self._parse_uuid(session_id, field="session_id")
        stmt = select(AgentChatSession.user_id).where(
            AgentChatSession.id == session_uuid
        )
        owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if owner_id is None:
            raise HTTPException(status_code=404, detail="Session not found")
        return str(owner_id)

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Insert a new chat session for ``user_id`` and return its id.

        Args:
            user_id: UUID string of the owning user.
            session_id: Optional UUID string to use as the session id. When
                omitted, ``id=None`` is passed to the model.
                NOTE(review): id generation is presumably handled by the
                model/database default - confirm.

        Raises:
            HTTPException: 422 if ``user_id`` or ``session_id`` is not a
                valid UUID string.
        """
        # Validate user_id first so error precedence matches prior behaviour.
        user_uuid = self._parse_uuid(user_id, field="user_id")
        session_uuid = (
            None
            if session_id is None
            else self._parse_uuid(session_id, field="session_id")
        )

        session = AgentChatSession(id=session_uuid, user_id=user_uuid)
        self._session.add(session)
        await self._session.flush()
        # Reload the row so the returned id reflects what was persisted.
        await self._session.refresh(session)
        return str(session.id)

    async def commit(self) -> None:
        """Commit the underlying database session."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying database session."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Delete the session with ``session_id`` if it exists, then flush.

        A missing session is not an error.

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID string.
        """
        session_uuid = self._parse_uuid(session_id, field="session_id")
        session = await self._session.get(AgentChatSession, session_uuid)
        if session is not None:
            await self._session.delete(session)
        await self._session.flush()

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the id of the user's most recently active session.

        Only sessions whose ``deleted_at`` is NULL are considered, ordered by
        ``last_activity_at`` descending. Returns ``None`` when there is none.

        Raises:
            HTTPException: 422 if ``user_id`` is not a valid UUID string.
        """
        user_uuid = self._parse_uuid(user_id, field="user_id")
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return None if latest_id is None else str(latest_id)

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def persist_user_message(
        self,
        *,
        session_id: str,
        run_id: str,
        content_text: str,
        metadata: dict[str, object] | None,
    ) -> None:
        """Append a USER message to a session and update session counters.

        The message receives ``seq = message_count + 1``. If the session has
        no non-blank title, one is derived from ``content_text``. ``run_id``
        is written into the stored metadata under ``"run_id"``, replacing any
        caller-supplied value for that key. The caller's ``metadata`` dict is
        copied, not mutated.

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID string,
                404 if the session does not exist.
        """
        session_uuid = self._parse_uuid(session_id, field="session_id")

        # SELECT ... FOR UPDATE on the session row; presumably so concurrent
        # writers cannot compute the same next seq - confirm with the schema.
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            .with_for_update()
        )
        session_row = (await self._session.execute(stmt)).scalar_one_or_none()
        if session_row is None:
            raise HTTPException(status_code=404, detail="Session not found")

        next_seq = int(session_row.message_count or 0) + 1

        if not _has_title(session_row.title):
            session_title = _derive_session_title(content_text)
            if session_title is not None:
                session_row.title = session_title

        payload_metadata = dict(metadata or {})
        payload_metadata["run_id"] = run_id

        message = AgentChatMessage(
            session_id=session_uuid,
            seq=next_seq,
            role=AgentChatMessageRole.USER,
            content=content_text,
            metadata_json=payload_metadata,
        )
        self._session.add(message)
        session_row.message_count = next_seq
        session_row.last_activity_at = datetime.now(timezone.utc)
        await self._session.flush()

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Return one UTC day of non-deleted messages for a session.

        Args:
            session_id: UUID string of the session.
            before: When ``None``, the most recent day with messages is
                returned; otherwise the most recent day strictly earlier than
                ``before``.

        Returns:
            ``None`` when no matching day exists, else a dict with ``"day"``
            (ISO date), ``"hasMore"`` (whether an earlier day has messages)
            and ``"messages"`` (snapshots ordered by ``seq`` ascending).

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID string.
        """
        session_uuid = self._parse_uuid(session_id, field="session_id")

        timestamp_stmt = (
            select(AgentChatMessage.created_at)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .order_by(AgentChatMessage.created_at.desc())
        )
        rows = (await self._session.execute(timestamp_stmt)).scalars().all()
        unique_days = self._distinct_utc_days(rows)
        if not unique_days:
            return None

        if before is None:
            target_day: date | None = unique_days[0]
        else:
            # First match wins; with newest-first rows this is the latest
            # day strictly before the cursor.
            target_day = next((day for day in unique_days if day < before), None)
        if target_day is None:
            return None

        start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
        end = start + timedelta(days=1)
        message_stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.created_at >= start)
            .where(AgentChatMessage.created_at < end)
            .order_by(AgentChatMessage.seq.asc())
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()
        has_more = any(day < target_day for day in unique_days)

        # Snapshots are built one at a time, in seq order.
        snapshot_messages: list[dict[str, object]] = [
            await self._to_snapshot_message(message) for message in messages
        ]
        return {
            "day": target_day.isoformat(),
            "hasMore": has_more,
            "messages": snapshot_messages,
        }

    async def _to_snapshot_message(
        self, message: AgentChatMessage
    ) -> dict[str, object]:
        """Convert a message row into the dict shape returned in history.

        Every snapshot has ``id``, ``role``, ``timestamp`` (UTC ISO-8601) and
        ``content``. Tool messages may also carry ``toolCallId`` and ``ui``;
        their content comes from the stored payload if one can be read, else
        from the inline JSON content, else from the raw content. Any message
        may carry ``attachments`` (dict entries only) from its metadata.
        """
        role = (
            message.role.value
            if isinstance(message.role, AgentChatMessageRole)
            else str(message.role)
        )
        # Normalise once so a non-dict metadata_json cannot break .get().
        raw_metadata = message.metadata_json
        metadata: dict[str, object] = (
            raw_metadata if isinstance(raw_metadata, dict) else {}
        )

        payload: dict[str, object] = {
            "id": str(message.id),
            "role": role,
            "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
        }

        if role == AgentChatMessageRole.TOOL.value:
            tool_call_id = metadata.get("tool_call_id")
            if isinstance(tool_call_id, str) and tool_call_id:
                payload["toolCallId"] = tool_call_id

            # Prefer the externally stored payload; fall back to inline JSON.
            resolved_content = await self._read_stored_tool_result(
                metadata
            ) or self._decode_json_object(message.content)
            if resolved_content is not None:
                payload.update(self._extract_tool_display(resolved_content))
            payload.setdefault("content", message.content)
        else:
            payload["content"] = message.content

        attachments = metadata.get("attachments")
        if isinstance(attachments, list):
            rendered = [item for item in attachments if isinstance(item, dict)]
            if rendered:
                payload["attachments"] = rendered
        return payload
|
2026-03-12 00:18:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _has_title(title: object) -> bool:
|
|
|
|
|
return isinstance(title, str) and bool(title.strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _derive_session_title(content_text: str) -> str | None:
|
|
|
|
|
normalized = " ".join(content_text.split())
|
|
|
|
|
if not normalized:
|
|
|
|
|
return None
|
|
|
|
|
return normalized[:80]
|