0abf51e837
- Add consumer_registry and pipeline_registry for runtime orchestration - Add Visibility schema for message filtering - Add PipelineSpec for agent pipeline configuration - Add automation job models and configuration - Remove memory_prompt.py (consolidated into memory system) - Update runtime components: context_loader, context_service, orchestrator, runner, tasks - Update toolkit: tool_config, tool_middleware, custom tools (calendar, user_lookup) - Add auth_helpers and calendar_domain utilities - Add system_agents.yaml configuration
94 lines
3.0 KiB
Python
94 lines
3.0 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
from decimal import Decimal
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
|
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
|
|
|
|
|
class MessageRepository:
    """Write access to ``AgentChatMessage`` rows through an ``AsyncSession``.

    Operations add and flush but never commit, so transaction boundaries stay
    with whoever owns the session.
    """

    def __init__(self, session: AsyncSession) -> None:
        # Async ORM session that every operation is issued against.
        self._session = session

    async def append_message(
        self,
        *,
        session_id: UUID,
        seq: int,
        role: AgentChatMessageRole,
        content: str,
        model_code: str | None = None,
        tool_name: str | None = None,
        metadata: dict[str, object] | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
        latency_ms: int | None = None,
        visibility_mask: int = 0,
    ) -> AgentChatMessage:
        """Create one chat message row, flush it, and return the ORM object.

        Args:
            session_id: Chat session the message belongs to.
            seq: Position of the message within the session. It is supplied
                by the caller; ``SessionRepository.next_message_seq`` computes
                a candidate value.
            role: Message role, stored as given.
            content: Message text.
            model_code: Optional model identifier.
            tool_name: Optional tool name.
            metadata: Optional free-form metadata, stored on the model's
                ``metadata_json`` attribute. NOTE(review): presumably renamed
                because ``metadata`` is reserved on SQLAlchemy declarative
                models.
            input_tokens: Input token count, stored as given.
            output_tokens: Output token count, stored as given.
            cost: Cost of the message, stored as given.
            latency_ms: Optional latency in milliseconds.
            visibility_mask: Coerced with ``int()``. Negative values are
                clamped to ``0``.

        Returns:
            The newly added ``AgentChatMessage``.
        """
        # Coerce, then floor at zero so a negative mask never reaches the row.
        safe_mask = max(int(visibility_mask), 0)
        record = AgentChatMessage(
            session_id=session_id,
            seq=seq,
            role=role,
            content=content,
            model_code=model_code,
            tool_name=tool_name,
            metadata_json=metadata,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            latency_ms=latency_ms,
            visibility_mask=safe_mask,
        )
        db = self._session
        db.add(record)
        # Flush rather than commit, so the INSERT is issued inside the
        # caller's transaction.
        await db.flush()
        return record
|
|
|
|
|
|
class SessionRepository:
    """Read and update ``AgentChatSession`` rows through an ``AsyncSession``.

    Updates are flushed but never committed. Committing is left to the owner
    of the session.
    """

    def __init__(self, session: AsyncSession) -> None:
        # Async ORM session that every operation is issued against.
        self._session = session

    async def get_session(self, *, session_id: UUID) -> AgentChatSession | None:
        """Return the chat session whose primary key is ``session_id``.

        Returns ``None`` when no such row exists.
        """
        found = await self._session.get(AgentChatSession, session_id)
        return found

    async def lock_session_for_update(
        self, *, session_id: UUID
    ) -> AgentChatSession | None:
        """Load the chat session row with ``SELECT ... FOR UPDATE``.

        On backends that support row locking, the lock lasts until the
        surrounding transaction ends. This serialises concurrent writers on
        the same session.

        Returns:
            The locked session, or ``None`` when no row matches ``session_id``.
        """
        query = select(AgentChatSession).where(AgentChatSession.id == session_id)
        locking_query = query.with_for_update()
        result = await self._session.execute(locking_query)
        return result.scalar_one_or_none()

    async def next_message_seq(self, *, session_id: UUID) -> int:
        """Return one more than the highest message ``seq`` in ``session_id``.

        A session with no messages yields ``1``.

        NOTE(review): this read is separate from the later insert. Presumably
        callers hold ``lock_session_for_update`` first, so that concurrent
        writers cannot compute the same value. Confirm this at the call sites.
        """
        # COALESCE turns the NULL from MAX over zero rows into 0.
        highest_seq = func.coalesce(func.max(AgentChatMessage.seq), 0)
        query = select(highest_seq).where(AgentChatMessage.session_id == session_id)
        result = await self._session.execute(query)
        return int(result.scalar_one()) + 1

    async def update_runtime_state(
        self,
        *,
        chat_session: AgentChatSession,
        status: AgentChatSessionStatus,
        state_snapshot: dict[str, object],
        message_delta: int,
        token_delta: int = 0,
        cost_delta: Decimal = Decimal("0"),
    ) -> None:
        """Record a runtime step on ``chat_session`` and flush it.

        This method:

        - overwrites ``status`` and ``state_snapshot``;
        - stamps ``last_activity_at`` with the current UTC time;
        - adds the deltas to the running ``message_count``, ``total_tokens``
          and ``total_cost`` counters.

        NOTE(review): ``state_snapshot`` is assigned as-is. If a caller mutates
        the existing snapshot dict in place and passes the same object back,
        SQLAlchemy may not detect a change unless the column uses a
        mutation-tracking type. Confirm this against the model definition.
        """
        touched_at = datetime.now(timezone.utc)
        chat_session.status = status
        chat_session.state_snapshot = state_snapshot
        chat_session.last_activity_at = touched_at
        # Counters accumulate; they are never replaced outright.
        chat_session.message_count += message_delta
        chat_session.total_tokens += token_delta
        chat_session.total_cost += cost_delta
        db = self._session
        await db.flush()
|