from __future__ import annotations from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from typing import Any, Protocol from uuid import UUID, uuid4 from sqlalchemy import Select, func, select from sqlalchemy.ext.asyncio import AsyncSession from core.http.errors import ApiProblemError from models.agent_chat_message import AgentChatMessage from models.agent_chat_session import AgentChatSession from models.system_agents import SystemAgents from schemas.enums import AgentChatMessageRole from schemas.domain.chat_message import ( AgentChatMessage as AgentChatMessageSchema, AgentChatMessageMetadata, ) class ToolResultPayloadStorage(Protocol): async def read_json( self, *, bucket: str, path: str ) -> dict[str, object] | None: ... class AgentRepository: def __init__( self, session: AsyncSession, *, tool_result_storage: ToolResultPayloadStorage | None = None, ) -> None: self._session: AsyncSession = session self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage async def get_session_owner(self, *, session_id: str) -> str: try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc stmt = ( select(AgentChatSession.user_id) .where(AgentChatSession.id == session_uuid) .where(AgentChatSession.deleted_at.is_(None)) ) owner_id = (await self._session.execute(stmt)).scalar_one_or_none() if owner_id is None: raise ApiProblemError( status_code=404, code="AGENT_SESSION_NOT_FOUND", detail="Session not found", ) return str(owner_id) async def create_session_for_user( self, *, user_id: str, session_id: str | None = None ) -> str: try: user_uuid = UUID(user_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_USER_ID_INVALID", detail="Invalid user_id", ) from exc session_uuid = None if session_id is not None: try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc session = AgentChatSession( id=session_uuid, user_id=user_uuid, ) self._session.add(session) await self._session.flush() await self._session.refresh(session) return str(session.id) async def commit(self) -> None: await self._session.commit() async def rollback(self) -> None: await self._session.rollback() async def delete_session(self, *, session_id: str) -> None: try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc stmt = ( select(AgentChatSession) .where(AgentChatSession.id == session_uuid) .with_for_update() ) session = (await self._session.execute(stmt)).scalar_one_or_none() if session is None: return if session.deleted_at is not None: return session.deleted_at = datetime.now(timezone.utc) await self._session.flush() async def persist_user_message( self, *, session_id: str, content: str, metadata: AgentChatMessageMetadata | None, visibility_mask: int, ) -> None: from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc stmt = ( select(AgentChatSession) .where(AgentChatSession.id == session_uuid) .with_for_update() ) session_row = (await self._session.execute(stmt)).scalar_one_or_none() if session_row is None: raise ApiProblemError( status_code=404, code="AGENT_SESSION_NOT_FOUND", detail="Session not found", ) next_seq = int(session_row.message_count or 0) + 1 if not _has_title(session_row.title): session_title = _derive_session_title(content) if session_title is not None: session_row.title = session_title message = OrmAgentChatMessage( id=uuid4(), session_id=session_uuid, seq=next_seq, role=AgentChatMessageRole.USER, content=content, visibility_mask=max(int(visibility_mask), 0), metadata_json=metadata.model_dump(by_alias=True) if metadata else None, ) self._session.add(message) session_row.message_count = next_seq session_row.last_activity_at = datetime.now(timezone.utc) await self._session.flush() async def get_assistant_message_count(self, *, session_id: str) -> int: try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc stmt = ( select(func.count(AgentChatMessage.id)) .where(AgentChatMessage.session_id == session_uuid) .where(AgentChatMessage.deleted_at.is_(None)) .where(AgentChatMessage.role == AgentChatMessageRole.ASSISTANT) ) count = (await self._session.execute(stmt)).scalar_one() return int(count) async def get_history_day( self, *, session_id: str, before: date | None, visibility_mask: int | None = None, ) -> dict[str, object] | None: try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc before_start = ( datetime.combine(before, time.min, tzinfo=timezone.utc) if before is not None else None ) target_created_at_stmt = ( select(AgentChatMessage.created_at) .where(AgentChatMessage.session_id == session_uuid) .where(AgentChatMessage.deleted_at.is_(None)) .order_by(AgentChatMessage.created_at.desc()) .limit(1) ) target_created_at_stmt = self._apply_visibility_filter( stmt=target_created_at_stmt, visibility_mask=visibility_mask, ) if before_start is not None: target_created_at_stmt = target_created_at_stmt.where( AgentChatMessage.created_at < before_start ) target_created_at = ( await self._session.execute(target_created_at_stmt) ).scalar_one_or_none() if target_created_at is None: return None target_day = target_created_at.astimezone(timezone.utc).date() start = datetime.combine(target_day, time.min, tzinfo=timezone.utc) end = start + timedelta(days=1) message_stmt = ( select(AgentChatMessage) .where(AgentChatMessage.session_id == session_uuid) .where(AgentChatMessage.deleted_at.is_(None)) .where(AgentChatMessage.created_at >= start) .where(AgentChatMessage.created_at < end) .order_by(AgentChatMessage.seq.asc()) ) message_stmt = self._apply_visibility_filter( stmt=message_stmt, visibility_mask=visibility_mask, ) messages = (await self._session.execute(message_stmt)).scalars().all() has_more_stmt = ( select(AgentChatMessage.id) .where(AgentChatMessage.session_id == session_uuid) .where(AgentChatMessage.deleted_at.is_(None)) .where(AgentChatMessage.created_at < start) .limit(1) ) has_more_stmt = self._apply_visibility_filter( stmt=has_more_stmt, visibility_mask=visibility_mask, ) has_more = ( await self._session.execute(has_more_stmt) ).scalar_one_or_none() is not None snapshot_messages: list[dict[str, object]] = [] for message in messages: snapshot_messages.append( (await self._to_chat_message_schema(message)).model_dump( mode="json", by_alias=True, exclude_none=True ) ) return { "day": target_day.isoformat(), "hasMore": has_more, "messages": snapshot_messages, } async def get_session_messages( self, *, session_id: str, visibility_mask: int | None = None, ) -> list[AgentChatMessageSchema]: try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc message_stmt = ( select(AgentChatMessage) .where(AgentChatMessage.session_id == session_uuid) .where(AgentChatMessage.deleted_at.is_(None)) .order_by(AgentChatMessage.seq.asc()) ) message_stmt = self._apply_visibility_filter( stmt=message_stmt, visibility_mask=visibility_mask, ) messages = (await self._session.execute(message_stmt)).scalars().all() snapshot_messages: list[AgentChatMessageSchema] = [] for message in messages: snapshot_messages.append(await self._to_chat_message_schema(message)) return snapshot_messages async def get_recent_messages_by_user_window( self, *, session_id: str, user_message_limit: int, visibility_mask: int | None = None, ) -> list[dict[str, object]]: try: session_uuid = UUID(session_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_SESSION_ID_INVALID", detail="Invalid session_id", ) from exc safe_user_limit = max(int(user_message_limit), 1) message_stmt = ( select(AgentChatMessage) .where(AgentChatMessage.session_id == session_uuid) .where(AgentChatMessage.deleted_at.is_(None)) .order_by(AgentChatMessage.seq.desc()) ) message_stmt = self._apply_visibility_filter( stmt=message_stmt, visibility_mask=visibility_mask, ) messages_desc = (await self._session.execute(message_stmt)).scalars().all() if not messages_desc: return [] selected_desc: list[AgentChatMessage] = [] user_count = 0 for message in messages_desc: selected_desc.append(message) role = ( message.role.value if isinstance(message.role, AgentChatMessageRole) else str(message.role) ) if role == AgentChatMessageRole.USER.value: user_count += 1 if user_count >= safe_user_limit: break selected = list(reversed(selected_desc)) snapshot_messages: list[dict[str, object]] = [] for message in selected: snapshot_messages.append( (await self._to_chat_message_schema(message)).model_dump( mode="json", by_alias=True, exclude_none=True ) ) return snapshot_messages async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: try: user_uuid = UUID(user_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_USER_ID_INVALID", detail="Invalid user_id", ) from exc stmt = ( select(AgentChatSession.id) .where(AgentChatSession.user_id == user_uuid) .where(AgentChatSession.deleted_at.is_(None)) .order_by(AgentChatSession.last_activity_at.desc()) .limit(1) ) latest_id = (await self._session.execute(stmt)).scalar_one_or_none() if latest_id is None: return None return str(latest_id) async def get_latest_assistant_messages_by_user_sessions( self, *, user_id: str, visibility_mask: int | None = None, session_limit: int = 50, ) -> list[AgentChatMessageSchema]: try: user_uuid = UUID(user_id) except ValueError as exc: raise ApiProblemError( status_code=422, code="AGENT_USER_ID_INVALID", detail="Invalid user_id", ) from exc safe_limit = max(int(session_limit), 1) session_stmt = ( select(AgentChatSession.id) .where(AgentChatSession.user_id == user_uuid) .where(AgentChatSession.deleted_at.is_(None)) .order_by(AgentChatSession.last_activity_at.desc()) .limit(safe_limit) ) session_ids = (await self._session.execute(session_stmt)).scalars().all() if not session_ids: return [] snapshots: list[AgentChatMessageSchema] = [] for session_id in session_ids: message_stmt = ( select(AgentChatMessage) .where(AgentChatMessage.session_id == session_id) .where(AgentChatMessage.deleted_at.is_(None)) .where(AgentChatMessage.role == AgentChatMessageRole.ASSISTANT) .order_by(AgentChatMessage.created_at.desc()) .limit(20) ) message_stmt = self._apply_visibility_filter( stmt=message_stmt, visibility_mask=visibility_mask, ) candidate_messages = ( (await self._session.execute(message_stmt)).scalars().all() ) if not candidate_messages: continue selected_snapshot: AgentChatMessageSchema | None = None for message in candidate_messages: snapshot = await self._to_chat_message_schema(message) metadata = ( snapshot.metadata.model_dump(mode="json", exclude_none=True) if snapshot.metadata is not None else None ) if not isinstance(metadata, dict): continue agent_output = metadata.get("agent_output") if not isinstance(agent_output, dict): continue derived = agent_output.get("divination_derived") if isinstance(derived, dict) and derived: selected_snapshot = snapshot break if selected_snapshot is not None: snapshots.append(selected_snapshot) snapshots.sort( key=lambda item: str(item.timestamp), reverse=True, ) return snapshots async def get_system_agent_config( self, *, agent_type: str ) -> dict[str, object] | None: normalized_type = agent_type.strip().lower() if not normalized_type: return None stmt = select(SystemAgents).where(SystemAgents.agent_type == normalized_type) row = (await self._session.execute(stmt)).scalar_one_or_none() if row is None: return None config_payload = row.config if isinstance(row.config, dict) else {} return { "agent_type": normalized_type, "status": str(row.status), "config": config_payload, } async def _to_chat_message_schema( self, message: AgentChatMessage ) -> AgentChatMessageSchema: role = ( message.role.value if isinstance(message.role, AgentChatMessageRole) else str(message.role) ) payload_model = AgentChatMessageSchema.model_validate( { "id": str(message.id), "session_id": str(message.session_id), "seq": int(message.seq), "role": role, "content": message.content, "model_code": message.model_code, "tool_name": message.tool_name, "input_tokens": int(message.input_tokens or 0), "output_tokens": int(message.output_tokens or 0), "cost": str(message.cost if message.cost is not None else Decimal("0")), "latency_ms": message.latency_ms, "metadata": message.metadata_json, "timestamp": message.created_at.astimezone(timezone.utc).isoformat(), } ) return payload_model def _apply_visibility_filter( self, *, stmt: Select[Any], visibility_mask: int | None, ) -> Select[Any]: if visibility_mask is None: return stmt required_mask = max(int(visibility_mask), 0) if required_mask == 0: return stmt return stmt.where( (AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0 ) def _has_title(title: object) -> bool: return isinstance(title, str) and bool(title.strip()) def _derive_session_title(content_text: str) -> str | None: normalized = " ".join(content_text.split()) if not normalized: return None return normalized[:80]