Files
social-app/backend/src/core/agentscope/events/persistence.py
T

94 lines
3.0 KiB
Python
Raw Normal View History

from __future__ import annotations

from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
class MessageRepository:
    """Write access to ``AgentChatMessage`` rows through one ``AsyncSession``.

    Methods add and flush objects but never commit; the caller owns the
    transaction.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Keep a reference to the async ``session`` used for all writes."""
        self._session = session

    async def append_message(
        self,
        *,
        session_id: UUID,
        seq: int,
        role: AgentChatMessageRole,
        content: str,
        model_code: str | None = None,
        tool_name: str | None = None,
        metadata: dict[str, object] | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
        latency_ms: int | None = None,
        visibility_mask: int = 0,
    ) -> AgentChatMessage:
        """Build a message row, add it to the session, flush, and return it.

        Args:
            session_id: Owning chat session id.
            seq: Position of the message within the session.
            role: Author role of the message.
            content: Message text.
            model_code: Optional model identifier.
            tool_name: Optional tool identifier.
            metadata: Stored on the model's ``metadata_json`` attribute.
            input_tokens: Prompt token count.
            output_tokens: Completion token count.
            cost: Monetary cost of producing the message.
            latency_ms: Optional latency in milliseconds.
            visibility_mask: Coerced with ``int()`` and floored at 0.

        Returns:
            The ``AgentChatMessage`` instance after ``flush()``.
        """
        # A negative mask is never persisted; it is floored to zero.
        clamped_mask = max(int(visibility_mask), 0)

        columns = dict(
            session_id=session_id,
            seq=seq,
            role=role,
            content=content,
            model_code=model_code,
            tool_name=tool_name,
            metadata_json=metadata,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            latency_ms=latency_ms,
            visibility_mask=clamped_mask,
        )
        row = AgentChatMessage(**columns)

        db = self._session
        db.add(row)
        await db.flush()
        return row
class SessionRepository:
    """Read, lock and update ``AgentChatSession`` rows through one ``AsyncSession``.

    Methods may flush but never commit; the caller owns the transaction.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Keep a reference to the async ``session`` used for all queries."""
        self._session = session

    async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
        """Return the chat session with primary key ``session_id``, or ``None``."""
        return await self._session.get(AgentChatSession, session_id)

    async def lock_session_for_update(
        self, *, session_id: UUID
    ) -> AgentChatSession | None:
        """Load the chat session with ``SELECT ... FOR UPDATE``.

        Returns ``None`` when no row matches. The row lock is presumably held
        until the caller's transaction ends (database-dependent).
        """
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_id)
            .with_for_update()
        )
        return (await self._session.execute(stmt)).scalar_one_or_none()

    async def next_message_seq(self, *, session_id: UUID) -> int:
        """Return ``MAX(seq) + 1`` over the session's messages, or 1 if none exist.

        This SELECT takes no lock, so two concurrent transactions can compute
        the same value. Callers presumably serialize on
        ``lock_session_for_update`` first; verify at call sites.
        """
        stmt = select(func.coalesce(func.max(AgentChatMessage.seq), 0)).where(
            AgentChatMessage.session_id == session_id
        )
        current = (await self._session.execute(stmt)).scalar_one()
        return int(current) + 1

    async def update_runtime_state(
        self,
        *,
        chat_session: AgentChatSession,
        status: AgentChatSessionStatus,
        state_snapshot: dict[str, object],
        message_delta: int,
        token_delta: int = 0,
        cost_delta: Decimal = Decimal("0"),
    ) -> None:
        """Apply a runtime-state transition to ``chat_session`` and flush it.

        Sets ``status`` and ``state_snapshot``, stamps ``last_activity_at``
        with the current UTC time, and increments the message, token and cost
        counters by the given deltas.

        Args:
            chat_session: Session instance attached to this repository's session.
            status: New status to store.
            state_snapshot: Snapshot to persist; always written, even if it is
                the same dict object that was mutated in place.
            message_delta: Added to ``message_count``.
            token_delta: Added to ``total_tokens``.
            cost_delta: Added to ``total_cost``.
        """
        chat_session.status = status
        chat_session.state_snapshot = state_snapshot
        # The ORM decides whether to write a column by comparing new and old
        # values. If ``state_snapshot`` is a plain (non-Mutable) JSON column
        # (not visible here) and the caller mutated the loaded dict in place,
        # old and new compare equal and the snapshot would be silently left
        # out of the UPDATE. Mark it modified so it is always persisted.
        flag_modified(chat_session, "state_snapshot")
        chat_session.last_activity_at = datetime.now(timezone.utc)
        chat_session.message_count += message_delta
        chat_session.total_tokens += token_delta
        chat_session.total_cost += cost_delta
        await self._session.flush()