Files
social-app/backend/src/v1/agent/repository.py
T

222 lines
8.1 KiB
Python
Raw Normal View History

from __future__ import annotations

from datetime import date, datetime, time, timedelta, timezone
import json
import logging
from typing import Protocol
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession


class ToolResultPayloadStorage(Protocol):
    """Structural interface for a store that can return JSON payloads.

    Any object exposing a matching async ``read_json`` method satisfies this
    protocol; no inheritance is required.
    """

    async def read_json(
        self,
        *,
        bucket: str,
        path: str,
    ) -> dict[str, object] | None:
        """Read the JSON object stored at ``path`` inside ``bucket``.

        Args:
            bucket: Name of the storage bucket to read from.
            path: Location of the object within that bucket.

        Returns:
            The decoded JSON object, or ``None``. What ``None`` signifies is
            up to the implementation (presumably "nothing stored there" --
            confirm against the concrete storage class).
        """
        ...
class AgentRepository:
    """Database access for agent chat sessions and their message history.

    Wraps an ``AsyncSession``. Write methods only flush; nothing is committed
    until ``commit()`` is called. Identifiers arrive as strings and malformed
    ones are reported as ``HTTPException`` with status 422.
    """

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        """Bind the repository to a database session.

        Args:
            session: Async SQLAlchemy session used for every query.
            tool_result_storage: Optional store consulted when building
                snapshots of tool messages whose metadata names a
                ``storage_bucket`` and ``storage_path``.
        """
        self._session = session
        self._tool_result_storage = tool_result_storage

    @staticmethod
    def _parse_uuid(value: str, *, field: str) -> UUID:
        """Parse ``value`` as a UUID, raising a 422 that names ``field``."""
        if isinstance(value, UUID):
            return value
        try:
            return UUID(value)
        except (ValueError, TypeError, AttributeError) as exc:
            # UUID() raises TypeError/AttributeError for non-string input;
            # report those the same way as a malformed string.
            raise HTTPException(
                status_code=422, detail=f"Invalid {field}"
            ) from exc

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the id of the user who owns ``session_id``.

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID, 404 if
                no session row has that id.
        """
        session_uuid = self._parse_uuid(session_id, field="session_id")
        # NOTE(review): unlike get_latest_session_id_for_user, this lookup
        # does not filter on deleted_at -- confirm that is intended.
        stmt = select(AgentChatSession.user_id).where(
            AgentChatSession.id == session_uuid
        )
        owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if owner_id is None:
            raise HTTPException(status_code=404, detail="Session not found")
        return str(owner_id)

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Add a chat session for ``user_id`` and return its id as a string.

        The row is flushed and refreshed but not committed.

        Args:
            user_id: Owner of the new session, as a UUID string.
            session_id: Optional UUID string to use as the session id. When
                omitted, ``None`` is passed to the model and the returned id
                is whatever the row holds after flush/refresh.

        Raises:
            HTTPException: 422 if either identifier is not a valid UUID.
        """
        user_uuid = self._parse_uuid(user_id, field="user_id")
        session_uuid = (
            None
            if session_id is None
            else self._parse_uuid(session_id, field="session_id")
        )
        session = AgentChatSession(id=session_uuid, user_id=user_uuid)
        self._session.add(session)
        await self._session.flush()
        await self._session.refresh(session)
        return str(session.id)

    async def commit(self) -> None:
        """Commit the underlying database session."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying database session."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Delete the session row with ``session_id`` if it exists, then flush.

        A missing session is not an error; the flush still runs.

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._parse_uuid(session_id, field="session_id")
        session = await self._session.get(AgentChatSession, session_uuid)
        if session is not None:
            await self._session.delete(session)
        await self._session.flush()

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Return one UTC calendar day of messages for a session.

        The day chosen is the most recent one that has at least one message
        whose ``deleted_at`` is unset; when ``before`` is given, only days
        strictly earlier than it are considered.

        Args:
            session_id: Session to read, as a UUID string.
            before: Exclusive upper bound on the day, or ``None`` for the
                latest day.

        Returns:
            ``None`` if no qualifying message exists, otherwise a dict with
            ``"day"`` (ISO date), ``"hasMore"`` (whether any earlier day has
            messages) and ``"messages"`` (snapshots ordered by ``seq``).

        Raises:
            HTTPException: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._parse_uuid(session_id, field="session_id")
        visible = (
            AgentChatMessage.session_id == session_uuid,
            AgentChatMessage.deleted_at.is_(None),
        )

        # Find the newest qualifying timestamp with a bounded query instead
        # of loading every timestamp in the session. NULL timestamps are
        # excluded explicitly because some databases sort them first in
        # descending order.
        latest_stmt = select(AgentChatMessage.created_at).where(
            *visible, AgentChatMessage.created_at.is_not(None)
        )
        if before is not None:
            # "UTC day < before" is the same as "timestamp < before 00:00 UTC".
            cutoff = datetime.combine(before, time.min, tzinfo=timezone.utc)
            latest_stmt = latest_stmt.where(AgentChatMessage.created_at < cutoff)
        latest_stmt = latest_stmt.order_by(
            AgentChatMessage.created_at.desc()
        ).limit(1)
        latest = (await self._session.execute(latest_stmt)).scalar_one_or_none()
        if latest is None:
            return None

        target_day = latest.astimezone(timezone.utc).date()
        start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
        end = start + timedelta(days=1)

        message_stmt = (
            select(AgentChatMessage)
            .where(*visible)
            .where(AgentChatMessage.created_at >= start)
            .where(AgentChatMessage.created_at < end)
            .order_by(AgentChatMessage.seq.asc())
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()

        # An earlier day exists exactly when some visible message predates
        # the start of the chosen day.
        older_stmt = (
            select(AgentChatMessage.id)
            .where(*visible, AgentChatMessage.created_at < start)
            .limit(1)
        )
        has_more = (await self._session.execute(older_stmt)).first() is not None

        snapshot_messages = [
            await self._to_snapshot_message(message) for message in messages
        ]
        return {
            "day": target_day.isoformat(),
            "hasMore": has_more,
            "messages": snapshot_messages,
        }

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the user's most recently active session id, if any.

        Sessions with ``deleted_at`` set are ignored; recency is determined
        by ``last_activity_at``.

        Raises:
            HTTPException: 422 if ``user_id`` is not a valid UUID.
        """
        user_uuid = self._parse_uuid(user_id, field="user_id")
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return None if latest_id is None else str(latest_id)

    @staticmethod
    def _decode_tool_content(raw: object) -> dict[str, object] | None:
        """Decode ``raw`` as JSON and return it only if it is an object."""
        try:
            decoded = json.loads(raw)
        except (TypeError, ValueError):
            return None
        return decoded if isinstance(decoded, dict) else None

    @staticmethod
    def _extract_tool_display(
        resolved: dict[str, object] | None,
    ) -> dict[str, object]:
        """Pick the ``content`` / ``ui`` fields to expose from a tool result.

        A string at ``result.content`` is used as content unless a top-level
        string ``content`` is present, which takes precedence. ``ui`` is
        included only when it is a dict.
        """
        display: dict[str, object] = {}
        if not resolved:
            return display
        result = resolved.get("result")
        if isinstance(result, dict):
            nested_content = result.get("content")
            if isinstance(nested_content, str):
                display["content"] = nested_content
        ui = resolved.get("ui")
        if isinstance(ui, dict):
            display["ui"] = ui
        top_level_content = resolved.get("content")
        if isinstance(top_level_content, str):
            display["content"] = top_level_content
        return display

    async def _load_stored_tool_result(
        self, metadata: dict[str, object]
    ) -> dict[str, object] | None:
        """Fetch the tool result referenced by ``metadata`` from storage.

        Returns ``None`` when no storage is configured, when the metadata
        lacks a string ``storage_bucket`` / ``storage_path``, when the read
        fails, or when the stored value is not a JSON object.
        """
        if self._tool_result_storage is None:
            return None
        bucket = metadata.get("storage_bucket")
        path = metadata.get("storage_path")
        if not (isinstance(bucket, str) and isinstance(path, str)):
            return None
        try:
            stored = await self._tool_result_storage.read_json(
                bucket=bucket, path=path
            )
        except Exception:
            # Best effort: history must still render from the inline content
            # when the store is unavailable, but the failure should be visible.
            logging.getLogger(__name__).warning(
                "Could not read stored tool result %s/%s; "
                "falling back to inline content",
                bucket,
                path,
                exc_info=True,
            )
            return None
        return stored if isinstance(stored, dict) else None

    async def _to_snapshot_message(
        self, message: AgentChatMessage
    ) -> dict[str, object]:
        """Convert a message row into the dict returned in history snapshots.

        Every snapshot carries ``id``, ``role``, ``timestamp`` (UTC ISO
        string) and ``content``. Tool messages may additionally carry
        ``toolCallId`` and ``ui``; their content is taken from the stored
        payload when one can be loaded, otherwise from the message content
        parsed as JSON, and finally from the raw message content.
        """
        role = (
            message.role.value
            if isinstance(message.role, AgentChatMessageRole)
            else str(message.role)
        )
        payload: dict[str, object] = {
            "id": str(message.id),
            "role": role,
            "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
        }
        if role != AgentChatMessageRole.TOOL.value:
            payload["content"] = message.content
            return payload

        raw_metadata = message.metadata_json
        metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
        tool_call_id = metadata.get("tool_call_id")
        if isinstance(tool_call_id, str) and tool_call_id:
            payload["toolCallId"] = tool_call_id

        # Prefer the payload kept in external storage; fall back to whatever
        # JSON object is embedded in the message itself.
        resolved = await self._load_stored_tool_result(
            metadata
        ) or self._decode_tool_content(message.content)
        payload.update(self._extract_tool_display(resolved))
        if "content" not in payload:
            payload["content"] = message.content
        return payload