Files
eryao/backend/src/v1/agent/repository.py
T

517 lines
18 KiB
Python
Raw Normal View History

2026-04-03 16:56:47 +08:00
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from typing import Any, Protocol
2026-04-03 16:56:47 +08:00
from uuid import UUID, uuid4
from sqlalchemy import Select, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from core.http.errors import ApiProblemError
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.system_agents import SystemAgents
from schemas.enums import AgentChatMessageRole
from schemas.domain.chat_message import (
AgentChatMessage as AgentChatMessageSchema,
AgentChatMessageMetadata,
)
class ToolResultPayloadStorage(Protocol):
    """Structural interface for an object store that can load a JSON payload.

    Any object exposing a matching async ``read_json`` method satisfies this
    protocol; no inheritance is required.
    """

    async def read_json(
        self,
        *,
        bucket: str,
        path: str,
    ) -> dict[str, object] | None:
        """Load the JSON object stored under ``bucket`` / ``path``.

        May return ``None``; when that happens (e.g. a missing object) is up
        to the implementation.
        """
class AgentRepository:
    """Data-access layer for agent chat sessions, chat messages and system agents.

    Every query runs on the injected ``AsyncSession``. Mutating methods only
    ``flush()``; transaction boundaries are left to the caller through
    :meth:`commit` and :meth:`rollback`. Malformed identifiers are reported as
    a 422 ``ApiProblemError`` and missing sessions (where checked) as a 404.
    """

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        """Bind the repository to a database session.

        Args:
            session: Async SQLAlchemy session used for every query.
            tool_result_storage: Optional payload store kept on the instance.
                NOTE(review): not read by any method of this class as shown.
        """
        self._session: AsyncSession = session
        self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage

    # ------------------------------------------------------------------
    # Shared identifier parsing / errors
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_session_uuid(session_id: str) -> UUID:
        """Parse ``session_id`` as a UUID.

        Raises:
            ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when malformed.
        """
        try:
            return UUID(session_id)
        except ValueError as exc:
            raise ApiProblemError(
                status_code=422,
                code="AGENT_SESSION_ID_INVALID",
                detail="Invalid session_id",
            ) from exc

    @staticmethod
    def _parse_user_uuid(user_id: str) -> UUID:
        """Parse ``user_id`` as a UUID.

        Raises:
            ApiProblemError: 422 ``AGENT_USER_ID_INVALID`` when malformed.
        """
        try:
            return UUID(user_id)
        except ValueError as exc:
            raise ApiProblemError(
                status_code=422,
                code="AGENT_USER_ID_INVALID",
                detail="Invalid user_id",
            ) from exc

    @staticmethod
    def _session_not_found() -> ApiProblemError:
        """Build the 404 error used when a session is absent or soft-deleted."""
        return ApiProblemError(
            status_code=404,
            code="AGENT_SESSION_NOT_FOUND",
            detail="Session not found",
        )

    # ------------------------------------------------------------------
    # Sessions
    # ------------------------------------------------------------------

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owning user id of a session that is not soft-deleted.

        Raises:
            ApiProblemError: 422 for a malformed ``session_id``; 404 when the
                session does not exist or has been soft-deleted.
        """
        session_uuid = self._parse_session_uuid(session_id)
        stmt = (
            select(AgentChatSession.user_id)
            .where(AgentChatSession.id == session_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
        )
        owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if owner_id is None:
            raise self._session_not_found()
        return str(owner_id)

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Create a chat session for ``user_id`` and return its id as a string.

        Args:
            user_id: Owner of the new session (UUID string).
            session_id: Optional caller-chosen id (UUID string). When omitted,
                the id is left to the model/database default.

        Raises:
            ApiProblemError: 422 when either identifier is malformed.
        """
        user_uuid = self._parse_user_uuid(user_id)
        session_uuid = (
            self._parse_session_uuid(session_id) if session_id is not None else None
        )
        session = AgentChatSession(user_id=user_uuid)
        if session_uuid is not None:
            # Only pin the primary key when the caller supplied one.
            session.id = session_uuid
        self._session.add(session)
        await self._session.flush()
        # Reload server-generated columns (including the id when defaulted).
        await self._session.refresh(session)
        return str(session.id)

    async def commit(self) -> None:
        """Commit the underlying session's transaction."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying session's transaction."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Soft-delete a session by stamping ``deleted_at`` with the UTC now.

        A missing or already-deleted session is a silent no-op. The row is
        selected ``FOR UPDATE`` before it is modified.

        Raises:
            ApiProblemError: 422 for a malformed ``session_id``.
        """
        session_uuid = self._parse_session_uuid(session_id)
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            .with_for_update()
        )
        session = (await self._session.execute(stmt)).scalar_one_or_none()
        if session is None or session.deleted_at is not None:
            return
        session.deleted_at = datetime.now(timezone.utc)
        await self._session.flush()

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def persist_user_message(
        self,
        *,
        session_id: str,
        content: str,
        metadata: AgentChatMessageMetadata | None,
        visibility_mask: int,
    ) -> None:
        """Append a user message to a live session and bump its counters.

        The session row is locked (``SELECT ... FOR UPDATE``) while the next
        ``seq`` is derived from ``message_count``, presumably so concurrent
        writers serialize on it. If the session has no title yet, one is
        derived from ``content``. Negative ``visibility_mask`` values are
        clamped to 0.

        Raises:
            ApiProblemError: 422 for a malformed ``session_id``; 404 when the
                session does not exist or has been soft-deleted.
        """
        session_uuid = self._parse_session_uuid(session_id)
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            # Soft-deleted sessions are treated as missing, the same way
            # get_session_owner treats them.
            .where(AgentChatSession.deleted_at.is_(None))
            .with_for_update()
        )
        session_row = (await self._session.execute(stmt)).scalar_one_or_none()
        if session_row is None:
            raise self._session_not_found()

        next_seq = int(session_row.message_count or 0) + 1
        if not _has_title(session_row.title):
            session_title = _derive_session_title(content)
            if session_title is not None:
                session_row.title = session_title

        self._session.add(
            AgentChatMessage(
                id=uuid4(),
                session_id=session_uuid,
                seq=next_seq,
                role=AgentChatMessageRole.USER,
                content=content,
                visibility_mask=max(int(visibility_mask), 0),
                metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
            )
        )
        session_row.message_count = next_seq
        session_row.last_activity_at = datetime.now(timezone.utc)
        await self._session.flush()

    async def get_user_message_count(self, *, session_id: str) -> int:
        """Count non-deleted USER-role messages in a session.

        Raises:
            ApiProblemError: 422 for a malformed ``session_id``.
        """
        session_uuid = self._parse_session_uuid(session_id)
        stmt = (
            select(func.count(AgentChatMessage.id))
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.role == AgentChatMessageRole.USER)
        )
        count = (await self._session.execute(stmt)).scalar_one()
        return int(count)

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
        visibility_mask: int | None = None,
    ) -> dict[str, object] | None:
        """Return one UTC calendar day of a session's history.

        The day chosen is the one holding the newest visible message created
        strictly before ``before`` at 00:00 UTC; with ``before=None`` it is the
        day of the newest visible message overall.

        Returns:
            ``None`` when no such message exists; otherwise a dict with
            ``"day"`` (ISO date string), ``"hasMore"`` (whether any visible
            message predates that day) and ``"messages"`` (that day's message
            snapshots in ascending ``seq`` order).

        Raises:
            ApiProblemError: 422 for a malformed ``session_id``.
        """
        session_uuid = self._parse_session_uuid(session_id)

        newest_stmt = (
            select(AgentChatMessage.created_at)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .order_by(AgentChatMessage.created_at.desc())
            .limit(1)
        )
        newest_stmt = self._apply_visibility_filter(
            stmt=newest_stmt,
            visibility_mask=visibility_mask,
        )
        if before is not None:
            before_start = datetime.combine(before, time.min, tzinfo=timezone.utc)
            newest_stmt = newest_stmt.where(AgentChatMessage.created_at < before_start)
        target_created_at = (
            await self._session.execute(newest_stmt)
        ).scalar_one_or_none()
        if target_created_at is None:
            return None

        target_day = target_created_at.astimezone(timezone.utc).date()
        day_start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
        day_end = day_start + timedelta(days=1)

        message_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .where(AgentChatMessage.created_at >= day_start)
                .where(AgentChatMessage.created_at < day_end)
                .order_by(AgentChatMessage.seq.asc())
            ),
            visibility_mask=visibility_mask,
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()

        has_more_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage.id)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .where(AgentChatMessage.created_at < day_start)
                .limit(1)
            ),
            visibility_mask=visibility_mask,
        )
        has_more = (
            await self._session.execute(has_more_stmt)
        ).scalar_one_or_none() is not None

        return {
            "day": target_day.isoformat(),
            "hasMore": has_more,
            "messages": [await self._to_snapshot_message(m) for m in messages],
        }

    async def get_session_messages(
        self,
        *,
        session_id: str,
        visibility_mask: int | None = None,
    ) -> list[dict[str, object]]:
        """Return snapshots of every visible, non-deleted message (ascending ``seq``).

        Raises:
            ApiProblemError: 422 for a malformed ``session_id``.
        """
        session_uuid = self._parse_session_uuid(session_id)
        message_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .order_by(AgentChatMessage.seq.asc())
            ),
            visibility_mask=visibility_mask,
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()
        return [await self._to_snapshot_message(m) for m in messages]

    async def get_recent_messages_by_user_window(
        self,
        *,
        session_id: str,
        user_message_limit: int,
        visibility_mask: int | None = None,
    ) -> list[dict[str, object]]:
        """Return the tail of a session covering its last N visible user messages.

        The window starts at the ``user_message_limit``-th most recent visible
        USER message (limit clamped to at least 1) and includes every visible
        message after it. With fewer user messages, all visible messages are
        returned. Snapshots are in ascending ``seq`` order.

        Raises:
            ApiProblemError: 422 for a malformed ``session_id``.
        """
        session_uuid = self._parse_session_uuid(session_id)
        safe_user_limit = max(int(user_message_limit), 1)

        # Find the seq of the N-th most recent visible user message in SQL
        # instead of loading the whole session into memory to count in Python.
        cutoff_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage.seq)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .where(AgentChatMessage.role == AgentChatMessageRole.USER)
                .order_by(AgentChatMessage.seq.desc())
                .offset(safe_user_limit - 1)
                .limit(1)
            ),
            visibility_mask=visibility_mask,
        )
        cutoff_seq = (await self._session.execute(cutoff_stmt)).scalar_one_or_none()

        message_stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .order_by(AgentChatMessage.seq.asc())
        )
        if cutoff_seq is not None:
            message_stmt = message_stmt.where(AgentChatMessage.seq >= cutoff_seq)
        message_stmt = self._apply_visibility_filter(
            stmt=message_stmt,
            visibility_mask=visibility_mask,
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()
        return [await self._to_snapshot_message(m) for m in messages]

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the id of the user's non-deleted session with the greatest
        ``last_activity_at``, or ``None`` if the user has no sessions.

        Raises:
            ApiProblemError: 422 for a malformed ``user_id``.
        """
        user_uuid = self._parse_user_uuid(user_id)
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return None if latest_id is None else str(latest_id)

    async def get_latest_assistant_messages_by_user_sessions(
        self,
        *,
        user_id: str,
        visibility_mask: int | None = None,
        session_limit: int = 50,
    ) -> list[dict[str, object]]:
        """Collect, per recent session, the newest assistant message carrying
        non-empty ``metadata.agent_output.divination_derived``.

        Looks at the user's ``session_limit`` (min 1) most recently active
        non-deleted sessions, and within each at most the 20 newest visible
        assistant messages. Results are ordered newest first.

        Raises:
            ApiProblemError: 422 for a malformed ``user_id``.
        """
        user_uuid = self._parse_user_uuid(user_id)
        safe_limit = max(int(session_limit), 1)
        session_stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(safe_limit)
        )
        session_ids = (await self._session.execute(session_stmt)).scalars().all()

        # Keep the source datetime next to each snapshot so the final ordering
        # does not depend on how the schema serializes ``timestamp`` strings.
        found: list[tuple[datetime, dict[str, object]]] = []
        for session_uuid in session_ids:
            message_stmt = self._apply_visibility_filter(
                stmt=(
                    select(AgentChatMessage)
                    .where(AgentChatMessage.session_id == session_uuid)
                    .where(AgentChatMessage.deleted_at.is_(None))
                    .where(AgentChatMessage.role == AgentChatMessageRole.ASSISTANT)
                    .order_by(AgentChatMessage.created_at.desc())
                    .limit(20)
                ),
                visibility_mask=visibility_mask,
            )
            candidates = (await self._session.execute(message_stmt)).scalars().all()
            for message in candidates:
                snapshot = await self._to_snapshot_message(message)
                if self._has_divination_derived(snapshot):
                    found.append((message.created_at, snapshot))
                    break

        found.sort(key=lambda pair: pair[0], reverse=True)
        return [snapshot for _, snapshot in found]

    # ------------------------------------------------------------------
    # System agents
    # ------------------------------------------------------------------

    async def get_system_agent_config(
        self, *, agent_type: str
    ) -> dict[str, object] | None:
        """Look up a system agent by type (trimmed, lower-cased).

        Returns ``None`` for a blank type or an unknown agent; otherwise a dict
        with ``agent_type``, ``status`` (stringified) and ``config`` (``{}``
        when the stored config is not a dict).
        """
        normalized_type = agent_type.strip().lower()
        if not normalized_type:
            return None
        stmt = select(SystemAgents).where(SystemAgents.agent_type == normalized_type)
        row = (await self._session.execute(stmt)).scalar_one_or_none()
        if row is None:
            return None
        config_payload = row.config if isinstance(row.config, dict) else {}
        return {
            "agent_type": normalized_type,
            "status": str(row.status),
            "config": config_payload,
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _to_snapshot_message(
        self, message: AgentChatMessage
    ) -> dict[str, object]:
        """Validate an ORM message through the API schema and dump it as JSON data.

        Missing token counts and cost default to 0; the timestamp is
        ``created_at`` converted to UTC in ISO format. ``None`` fields are
        excluded from the output.
        """
        role = (
            message.role.value
            if isinstance(message.role, AgentChatMessageRole)
            else str(message.role)
        )
        payload_model = AgentChatMessageSchema.model_validate(
            {
                "id": str(message.id),
                "session_id": str(message.session_id),
                "seq": int(message.seq),
                "role": role,
                "content": message.content,
                "model_code": message.model_code,
                "tool_name": message.tool_name,
                "input_tokens": int(message.input_tokens or 0),
                "output_tokens": int(message.output_tokens or 0),
                "cost": str(message.cost if message.cost is not None else Decimal("0")),
                "latency_ms": message.latency_ms,
                "metadata": message.metadata_json,
                "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
            }
        )
        return payload_model.model_dump(mode="json", exclude_none=True)

    @staticmethod
    def _has_divination_derived(snapshot: dict[str, object]) -> bool:
        """Return True when ``snapshot["metadata"]["agent_output"]["divination_derived"]``
        is a non-empty dict (every intermediate level must also be a dict)."""
        metadata = snapshot.get("metadata")
        if not isinstance(metadata, dict):
            return False
        agent_output = metadata.get("agent_output")
        if not isinstance(agent_output, dict):
            return False
        derived = agent_output.get("divination_derived")
        return isinstance(derived, dict) and bool(derived)

    def _apply_visibility_filter(
        self,
        *,
        stmt: Select[Any],
        visibility_mask: int | None,
    ) -> Select[Any]:
        """Restrict ``stmt`` to messages sharing at least one bit with the mask.

        ``None`` or a mask that clamps to 0 leaves the statement unfiltered.
        """
        if visibility_mask is None:
            return stmt
        required_mask = max(int(visibility_mask), 0)
        if required_mask == 0:
            return stmt
        return stmt.where(
            (AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
        )
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
def _derive_session_title(content_text: str) -> str | None:
normalized = " ".join(content_text.split())
if not normalized:
return None
return normalized[:80]