"""Data-access layer for agent chat sessions and their messages.

Provides session ownership lookup, creation and deletion, appending user
messages, paging message history one UTC day at a time, and converting stored
messages into client-facing snapshot dictionaries. Invalid identifiers and
missing sessions are reported as FastAPI ``HTTPException``s.
"""

from __future__ import annotations

from datetime import date, datetime, time, timedelta, timezone
import json
import logging
from typing import Protocol
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from core.config.settings import config
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
from services.base.supabase import supabase_service

logger = logging.getLogger(__name__)

# Lifetime (seconds) of the signed URLs generated for user-message attachments.
_ATTACHMENT_URL_TTL_SECONDS = 3600

# Key prefix every hydratable tool-result storage path must start with.
_TOOL_RESULTS_PREFIX = "tool-results/"


class ToolResultPayloadStorage(Protocol):
    """Structural interface for reading tool-result JSON payloads from storage."""

    async def read_json(
        self, *, bucket: str, path: str
    ) -> dict[str, object] | None:
        """Return the decoded JSON object at ``path`` in ``bucket`` (or ``None``)."""
        ...


class AgentRepository:
    """Async SQLAlchemy repository for agent chat sessions and messages.

    Identifier arguments are UUID strings; malformed ones raise
    ``HTTPException(422)``. Write methods only ``flush``; nothing is committed
    until the caller invokes :meth:`commit` (or discards with :meth:`rollback`).
    """

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        """Bind the repository to a database session.

        Args:
            session: Session used for every query and write.
            tool_result_storage: Optional backend used to load tool-message
                payloads referenced by ``storage_bucket``/``storage_path``
                metadata. When omitted, tool messages are rendered from their
                stored ``content`` only.
        """
        self._session = session
        self._tool_result_storage = tool_result_storage

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owning user's id for ``session_id`` as a string.

        Raises:
            HTTPException: 422 if ``session_id`` is not a UUID; 404 if the
                session does not exist.
        """
        session_uuid = _parse_uuid(session_id, field="session_id")
        stmt = select(AgentChatSession.user_id).where(
            AgentChatSession.id == session_uuid
        )
        owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if owner_id is None:
            raise HTTPException(status_code=404, detail="Session not found")
        return str(owner_id)

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Insert a chat session owned by ``user_id`` and return its id.

        Args:
            user_id: Owner's UUID string.
            session_id: Optional caller-chosen UUID string for the new session.
                When ``None``, ``id=None`` is passed to the model (presumably so
                a model/database default generates the id — confirm).

        Returns:
            The persisted session id as a string.

        Raises:
            HTTPException: 422 if ``user_id`` or ``session_id`` is malformed.
        """
        user_uuid = _parse_uuid(user_id, field="user_id")
        session_uuid = (
            _parse_uuid(session_id, field="session_id")
            if session_id is not None
            else None
        )
        session = AgentChatSession(id=session_uuid, user_id=user_uuid)
        self._session.add(session)
        # Flush + refresh so the generated primary key is readable pre-commit.
        await self._session.flush()
        await self._session.refresh(session)
        return str(session.id)

    async def commit(self) -> None:
        """Commit the underlying session's transaction."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying session's transaction."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Delete the session row for ``session_id`` if present.

        A missing session is silently ignored.

        Raises:
            HTTPException: 422 if ``session_id`` is not a UUID.
        """
        session_uuid = _parse_uuid(session_id, field="session_id")
        session = await self._session.get(AgentChatSession, session_uuid)
        if session is not None:
            await self._session.delete(session)
            await self._session.flush()

    async def persist_user_message(
        self,
        *,
        session_id: str,
        run_id: str,
        content_text: str,
        metadata: dict[str, object] | None,
    ) -> None:
        """Append a user message to a session and update the session row.

        The session row is selected ``FOR UPDATE`` so concurrent appends to the
        same session serialize while the next ``seq`` (``message_count + 1``)
        is allocated. ``run_id`` is stored in the message metadata, and an
        untitled session gets a title derived from ``content_text``.

        Raises:
            HTTPException: 422 if ``session_id`` is not a UUID; 404 if the
                session does not exist.
        """
        session_uuid = _parse_uuid(session_id, field="session_id")
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            .with_for_update()
        )
        session_row = (await self._session.execute(stmt)).scalar_one_or_none()
        if session_row is None:
            raise HTTPException(status_code=404, detail="Session not found")

        next_seq = int(session_row.message_count or 0) + 1
        if not _has_title(session_row.title):
            session_title = _derive_session_title(content_text)
            if session_title is not None:
                session_row.title = session_title

        payload_metadata = dict(metadata or {})
        payload_metadata["run_id"] = run_id
        message = AgentChatMessage(
            session_id=session_uuid,
            seq=next_seq,
            role=AgentChatMessageRole.USER,
            content=content_text,
            metadata_json=payload_metadata,
        )
        self._session.add(message)
        session_row.message_count = next_seq
        session_row.last_activity_at = datetime.now(timezone.utc)
        await self._session.flush()

    async def get_history_day(
        self, *, session_id: str, before: date | None
    ) -> dict[str, object] | None:
        """Return one UTC calendar day of non-deleted messages for a session.

        Args:
            session_id: Session UUID string.
            before: ``None`` selects the most recent day that has messages;
                otherwise the most recent such day strictly earlier than
                ``before``.

        Returns:
            ``{"day": <ISO date>, "hasMore": <bool>, "messages": [...]}`` where
            ``messages`` are snapshot dicts ordered by ``seq`` and ``hasMore``
            is true when an older day with messages exists; ``None`` when no
            day matches.

        Raises:
            HTTPException: 422 if ``session_id`` is not a UUID.
        """
        session_uuid = _parse_uuid(session_id, field="session_id")
        timestamp_stmt = (
            select(AgentChatMessage.created_at)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .order_by(AgentChatMessage.created_at.desc())
        )
        rows = (await self._session.execute(timestamp_stmt)).scalars().all()

        # Rows arrive newest-first; dict.fromkeys keeps first-seen order, so
        # this yields the distinct UTC days newest-first in a single O(n) pass.
        unique_days = list(
            dict.fromkeys(
                created_at.astimezone(timezone.utc).date()
                for created_at in rows
                if created_at is not None
            )
        )
        if not unique_days:
            return None

        target_day: date | None
        if before is None:
            target_day = unique_days[0]
        else:
            target_day = next((day for day in unique_days if day < before), None)
        if target_day is None:
            return None

        start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
        end = start + timedelta(days=1)
        message_stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.created_at >= start)
            .where(AgentChatMessage.created_at < end)
            .order_by(AgentChatMessage.seq.asc())
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()
        has_more = any(day < target_day for day in unique_days)
        snapshot_messages = [
            await self._to_snapshot_message(message) for message in messages
        ]
        return {
            "day": target_day.isoformat(),
            "hasMore": has_more,
            "messages": snapshot_messages,
        }

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the user's most recently active non-deleted session id.

        Returns:
            The session id as a string, or ``None`` if the user has no sessions.

        Raises:
            HTTPException: 422 if ``user_id`` is not a UUID.
        """
        user_uuid = _parse_uuid(user_id, field="user_id")
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            # PostgreSQL sorts NULLs first under DESC; without NULLS LAST a
            # session that never recorded activity would outrank active ones.
            .order_by(AgentChatSession.last_activity_at.desc().nulls_last())
            .limit(1)
        )
        latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if latest_id is None:
            return None
        return str(latest_id)

    async def _to_snapshot_message(
        self, message: AgentChatMessage
    ) -> dict[str, object]:
        """Convert a stored message into a client-facing snapshot dict.

        Every snapshot carries ``id``, ``role``, ``timestamp`` (UTC ISO-8601)
        and ``content``. Tool messages may additionally carry ``toolCallId`` and
        ``ui``; user messages with a valid attachment reference get ``content``
        as a list of blocks that includes a signed URL.
        """
        role = (
            message.role.value
            if isinstance(message.role, AgentChatMessageRole)
            else str(message.role)
        )
        payload: dict[str, object] = {
            "id": str(message.id),
            "role": role,
            "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
        }
        if role == AgentChatMessageRole.TOOL.value:
            await self._apply_tool_content(payload, message)
        else:
            payload["content"] = message.content
        if role == AgentChatMessageRole.USER.value:
            await self._apply_user_attachments(payload, message)
        return payload

    async def _apply_tool_content(
        self, payload: dict[str, object], message: AgentChatMessage
    ) -> None:
        """Fill ``toolCallId``, ``content`` and ``ui`` on a tool-message payload."""
        metadata = message.metadata_json or {}
        tool_call_id = metadata.get("tool_call_id")
        if isinstance(tool_call_id, str) and tool_call_id:
            payload["toolCallId"] = tool_call_id

        parsed_content = _decode_json_object(message.content)
        hydrated_content = await self._hydrate_tool_result(message, metadata)
        # Prefer the payload loaded from storage; fall back to inline JSON.
        resolved_content = hydrated_content or parsed_content

        payload["content"] = message.content
        if resolved_content is None:
            return

        ui = resolved_content.get("ui")
        if not isinstance(ui, dict):
            ui = resolved_content.get("ui_schema")
        if isinstance(ui, dict):
            payload["ui"] = ui

        display_content = _extract_display_content(resolved_content)
        # Replace the stored text only when it is empty or merely the
        # offload placeholder, so genuine inline content is kept.
        if (
            display_content is not None
            and display_content.strip()
            and (
                not message.content
                or _looks_like_offloaded_placeholder(str(message.content))
            )
        ):
            payload["content"] = display_content

    async def _hydrate_tool_result(
        self, message: AgentChatMessage, metadata: dict[str, object]
    ) -> dict[str, object] | None:
        """Load the stored payload referenced by a tool message, if permitted.

        A read is attempted only when a storage backend is configured, the
        bucket equals ``config.storage.bucket``, the path passes
        ``_is_safe_storage_path``, and the path is either under this message's
        session directory or is a legacy path named after its tool call id.
        Read failures are logged and treated as "no payload".
        """
        storage = self._tool_result_storage
        if storage is None:
            return None
        storage_bucket = metadata.get("storage_bucket")
        storage_path = metadata.get("storage_path")
        if not (isinstance(storage_bucket, str) and isinstance(storage_path, str)):
            return None
        if storage_bucket != config.storage.bucket:
            return None
        if not _is_safe_storage_path(storage_path):
            return None

        message_session_id = getattr(message, "session_id", None)
        in_session_dir = message_session_id is not None and storage_path.startswith(
            f"{_TOOL_RESULTS_PREFIX}{message_session_id}/"
        )
        # Legacy layout: accepted anywhere under the tool-results prefix as long
        # as the file is named after this message's tool_call_id.
        tool_call_id = metadata.get("tool_call_id")
        is_legacy_path = (
            storage_path.startswith(_TOOL_RESULTS_PREFIX)
            and isinstance(tool_call_id, str)
            and storage_path.endswith(f"/{tool_call_id}.json")
        )
        if not (in_session_dir or is_legacy_path):
            return None

        try:
            hydrated = await storage.read_json(
                bucket=storage_bucket, path=storage_path
            )
        except Exception:  # noqa: BLE001 - hydration is best effort
            logger.warning(
                "Failed to read tool result payload %s/%s",
                storage_bucket,
                storage_path,
                exc_info=True,
            )
            return None
        # Backends may return non-object JSON despite the protocol signature.
        return hydrated if isinstance(hydrated, dict) else None

    async def _apply_user_attachments(
        self, payload: dict[str, object], message: AgentChatMessage
    ) -> None:
        """Rewrite ``content`` as blocks when a user message has an attachment.

        Requires ``metadata_json["user_message_attachments"]`` to be a dict with
        string ``bucket``, ``path`` and ``mime_type``. If signing the URL fails,
        the failure is logged and ``payload`` is left unchanged.
        """
        metadata = message.metadata_json or {}
        attachment = metadata.get("user_message_attachments")
        if not isinstance(attachment, dict):
            return
        bucket = attachment.get("bucket")
        path = attachment.get("path")
        mime_type = attachment.get("mime_type")
        if not (
            isinstance(bucket, str)
            and isinstance(path, str)
            and isinstance(mime_type, str)
        ):
            return

        try:
            signed_url = await supabase_service.create_signed_url(
                bucket=bucket,
                path=path,
                expires_in_seconds=_ATTACHMENT_URL_TTL_SECONDS,
            )
        except Exception:  # noqa: BLE001 - attachment signing is best effort
            logger.warning(
                "Failed to sign attachment URL for message %s",
                message.id,
                exc_info=True,
            )
            return

        attachment_block = {
            "type": "binary",
            "mimeType": mime_type,
            "url": signed_url,
        }
        existing_content = message.content
        if isinstance(existing_content, str) and existing_content.strip():
            payload["content"] = [
                {"type": "text", "text": existing_content},
                attachment_block,
            ]
        else:
            payload["content"] = [attachment_block]


def _parse_uuid(value: str, *, field: str) -> UUID:
    """Parse ``value`` as a UUID, raising ``HTTPException(422)`` naming ``field``."""
    try:
        return UUID(value)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=f"Invalid {field}") from exc


def _has_title(title: object) -> bool:
    """Return True when ``title`` is a string with non-whitespace content."""
    return isinstance(title, str) and bool(title.strip())


def _derive_session_title(content_text: str) -> str | None:
    """Build a session title from message text.

    Whitespace runs collapse to single spaces, and the result is cut to 80
    characters. Trailing whitespace left by the cut is removed. Returns
    ``None`` for blank text.
    """
    normalized = " ".join(content_text.split())
    if not normalized:
        return None
    # ``normalized`` has no leading space, so the slice stays non-empty.
    return normalized[:80].rstrip()


def _is_safe_storage_path(path: str) -> bool:
    """Reject blank, absolute (leading ``/``) or ``..``-containing paths."""
    normalized = path.strip()
    return (
        bool(normalized)
        and not normalized.startswith("/")
        and ".." not in normalized
    )


def _looks_like_offloaded_placeholder(content: str) -> bool:
    """Return True when ``content`` is the ``{"offloaded": true}`` marker text."""
    normalized = content.strip().lower()
    return normalized in {'{"offloaded":true}', '{"offloaded": true}'}


def _decode_json_object(raw: object) -> dict[str, object] | None:
    """Decode ``raw`` as JSON and return it only if it is an object."""
    try:
        decoded = json.loads(raw)
    except (TypeError, ValueError):
        return None
    return decoded if isinstance(decoded, dict) else None


def _extract_display_content(content: dict[str, object]) -> str | None:
    """Return ``content["content"]`` or else ``content["result"]["content"]``.

    Only string values are returned; anything else yields ``None``.
    """
    direct = content.get("content")
    if isinstance(direct, str):
        return direct
    nested_result = content.get("result")
    if isinstance(nested_result, dict):
        nested_content = nested_result.get("content")
        if isinstance(nested_content, str):
            return nested_content
    return None