from __future__ import annotations

from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus


class MessageRepository:
    """Persistence helper for ``AgentChatMessage`` rows."""

    def __init__(self, session: AsyncSession) -> None:
        """Keep a reference to the async session used for all writes."""
        self._session = session

    async def append_message(
        self,
        *,
        session_id: UUID,
        seq: int,
        role: AgentChatMessageRole,
        content: str,
        model_code: str | None = None,
        tool_name: str | None = None,
        metadata: dict[str, object] | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
        latency_ms: int | None = None,
        visibility_mask: int = 0,
    ) -> AgentChatMessage:
        """Create a chat message, add it to the session and flush.

        Args:
            session_id: Identifier of the chat session the message belongs to.
            seq: Sequence number to store on the message.
            role: Role recorded for the message.
            content: Message text.
            model_code: Optional model identifier.
            tool_name: Optional tool name.
            metadata: Optional mapping stored in ``metadata_json``.
            input_tokens: Input token count to record.
            output_tokens: Output token count to record.
            cost: Cost to record for the message.
            latency_ms: Optional latency in milliseconds.
            visibility_mask: Coerced with ``int()``; negative values are
                stored as ``0``.

        Returns:
            The ``AgentChatMessage`` instance that was added and flushed.
        """
        # Negative masks are clamped to zero before being stored.
        clamped_mask = max(int(visibility_mask), 0)

        record = AgentChatMessage(
            session_id=session_id,
            seq=seq,
            role=role,
            content=content,
            model_code=model_code,
            tool_name=tool_name,
            metadata_json=metadata,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            latency_ms=latency_ms,
            visibility_mask=clamped_mask,
        )
        self._session.add(record)
        await self._session.flush()
        return record


class SessionRepository:
    """Persistence helper for ``AgentChatSession`` rows and their counters."""

    def __init__(self, session: AsyncSession) -> None:
        """Keep a reference to the async session used for all queries."""
        self._session = session

    async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
        """Return the chat session with the given primary key, or ``None``."""
        found = await self._session.get(AgentChatSession, session_id)
        return found

    async def lock_session_for_update(
        self, *, session_id: UUID
    ) -> AgentChatSession | None:
        """Select the chat session by id using ``with_for_update()``.

        Returns:
            The matching ``AgentChatSession``, or ``None`` when no row matches.
        """
        query = select(AgentChatSession)
        query = query.where(AgentChatSession.id == session_id)
        query = query.with_for_update()

        outcome = await self._session.execute(query)
        return outcome.scalar_one_or_none()

    async def next_message_seq(self, *, session_id: UUID) -> int:
        """Return one more than the highest stored ``seq`` for the session.

        When the session has no messages the maximum is coalesced to ``0``,
        so the result is ``1``.
        """
        # NOTE(review): this query takes no lock itself; presumably callers
        # hold the session row lock while allocating a seq - confirm.
        highest_seq_expr = func.coalesce(func.max(AgentChatMessage.seq), 0)
        query = select(highest_seq_expr).where(
            AgentChatMessage.session_id == session_id
        )

        outcome = await self._session.execute(query)
        highest_seq = outcome.scalar_one()
        return int(highest_seq) + 1

    async def update_runtime_state(
        self,
        *,
        chat_session: AgentChatSession,
        status: AgentChatSessionStatus,
        state_snapshot: dict[str, object],
        message_delta: int,
        token_delta: int = 0,
        cost_delta: Decimal = Decimal("0"),
    ) -> None:
        """Apply a status change and counter increments, then flush.

        Args:
            chat_session: The session object to mutate in place.
            status: New value for ``status``.
            state_snapshot: New value for ``state_snapshot``.
            message_delta: Amount added to ``message_count``.
            token_delta: Amount added to ``total_tokens``.
            cost_delta: Amount added to ``total_cost``.
        """
        # Replace the state fields and stamp the activity time in UTC.
        chat_session.status = status
        chat_session.state_snapshot = state_snapshot
        chat_session.last_activity_at = datetime.now(timezone.utc)

        # Running totals are adjusted relative to their current values.
        chat_session.message_count += message_delta
        chat_session.total_tokens += token_delta
        chat_session.total_cost += cost_delta

        await self._session.flush()