Files
eryao/backend/src/v1/agent/repository.py
T
qzl 1e22f27de2 feat: integrate invite API and improve notification handling
- Add invite code display and binding functionality via API
- Fix notification unread count sync on auth state change
- Improve notification mark read with server state validation
- Add auth state listener to trigger notification refresh
- Add YaoCoinConverter for coin-to-yao type conversion
- Remove YaoLegend from divination screens (UI cleanup)
- Abbreviate relation labels in yao detail view
- Add re-register notice to account delete screen
- Update 'coins' terminology to 'points' in localization
- Fix backend points consumption to only run in CHAT mode
- Add HttpxAuthNoiseFilter to suppress auth endpoint logging
- Fix notification static_schema import path
- Add test coverage for notification bloc error handling
- Update AGENTS.md page header rules and image handling
- Delete deprecated run-dev.sh script
2026-04-13 14:52:22 +08:00

529 lines
18 KiB
Python

from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from typing import Any, Protocol
from uuid import UUID, uuid4
from sqlalchemy import Select, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from core.http.errors import ApiProblemError
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.system_agents import SystemAgents
from schemas.enums import AgentChatMessageRole
from schemas.domain.chat_message import (
AgentChatMessage as AgentChatMessageSchema,
AgentChatMessageMetadata,
)
class ToolResultPayloadStorage(Protocol):
    """Structural interface for an object that can read stored JSON payloads.

    Any object exposing a compatible async ``read_json`` method satisfies
    this protocol; no inheritance is required.
    """

    async def read_json(
        self,
        *,
        bucket: str,
        path: str,
    ) -> dict[str, object] | None:
        """Read the JSON object stored at ``path`` inside ``bucket``.

        Returns a dict, or ``None``; the conditions under which ``None`` is
        returned are up to the implementation and are not defined here.
        """
        ...
class AgentRepository:
    """Data-access layer for agent chat sessions, messages and agent config.

    Every query runs on the injected ``AsyncSession``. Mutating methods only
    ``flush``; the transaction boundary is controlled by the caller through
    :meth:`commit` and :meth:`rollback`.
    """

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        """Bind the repository to a database session.

        Args:
            session: Async SQLAlchemy session used for every statement.
            tool_result_storage: Optional payload reader. It is kept on the
                instance, but no method visible in this class reads it.
        """
        self._session: AsyncSession = session
        self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage

    # ------------------------------------------------------------------
    # Identifier parsing helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_uuid(raw: str, *, code: str, detail: str) -> UUID:
        """Parse ``raw`` as a UUID, raising a 422 ``ApiProblemError`` if malformed."""
        try:
            return UUID(raw)
        except ValueError as exc:
            raise ApiProblemError(
                status_code=422,
                code=code,
                detail=detail,
            ) from exc

    @classmethod
    def _parse_session_id(cls, session_id: str) -> UUID:
        """Parse a session id; malformed input yields ``AGENT_SESSION_ID_INVALID``."""
        return cls._parse_uuid(
            session_id,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        )

    @classmethod
    def _parse_user_id(cls, user_id: str) -> UUID:
        """Parse a user id; malformed input yields ``AGENT_USER_ID_INVALID``."""
        return cls._parse_uuid(
            user_id,
            code="AGENT_USER_ID_INVALID",
            detail="Invalid user_id",
        )

    @staticmethod
    def _role_value(message: AgentChatMessage) -> str:
        """Return the message role as a plain string, whether enum or raw value."""
        role = message.role
        return role.value if isinstance(role, AgentChatMessageRole) else str(role)

    # ------------------------------------------------------------------
    # Sessions
    # ------------------------------------------------------------------

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owning user id (as a string) of a non-deleted session.

        Raises:
            ApiProblemError: 422 if ``session_id`` is not a UUID, 404 if no
                non-deleted session with that id exists.
        """
        session_uuid = self._parse_session_id(session_id)
        stmt = (
            select(AgentChatSession.user_id)
            .where(AgentChatSession.id == session_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
        )
        owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if owner_id is None:
            raise ApiProblemError(
                status_code=404,
                code="AGENT_SESSION_NOT_FOUND",
                detail="Session not found",
            )
        return str(owner_id)

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Insert a new chat session for ``user_id`` and return its id.

        Args:
            user_id: Owner of the session (UUID string).
            session_id: Optional explicit id for the new session; when omitted
                ``None`` is passed as the row id.

        Raises:
            ApiProblemError: 422 if either id is not a valid UUID. The user id
                is validated first.
        """
        user_uuid = self._parse_user_id(user_id)
        session_uuid = (
            self._parse_session_id(session_id) if session_id is not None else None
        )
        session = AgentChatSession(
            id=session_uuid,
            user_id=user_uuid,
        )
        self._session.add(session)
        await self._session.flush()
        # Reload so the id is populated on the instance before it is returned.
        await self._session.refresh(session)
        return str(session.id)

    async def commit(self) -> None:
        """Commit the underlying database session."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying database session."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Soft-delete a session by stamping ``deleted_at`` with the current UTC time.

        A missing or already-deleted session is silently ignored.

        Raises:
            ApiProblemError: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._parse_session_id(session_id)
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            .with_for_update()
        )
        session = (await self._session.execute(stmt)).scalar_one_or_none()
        if session is None or session.deleted_at is not None:
            return
        session.deleted_at = datetime.now(timezone.utc)
        await self._session.flush()

    # ------------------------------------------------------------------
    # Messages: writes
    # ------------------------------------------------------------------

    async def persist_user_message(
        self,
        *,
        session_id: str,
        content: str,
        metadata: AgentChatMessageMetadata | None,
        visibility_mask: int,
    ) -> None:
        """Append a USER message to a session and update the session counters.

        The session row is selected ``FOR UPDATE``, the next sequence number
        is ``message_count + 1``, and a title is derived from ``content``
        when the session does not have one yet.

        Args:
            session_id: Target session (UUID string).
            content: Message text.
            metadata: Optional metadata, stored via ``model_dump(by_alias=True)``.
            visibility_mask: Stored on the message; negative values become 0.

        Raises:
            ApiProblemError: 422 if ``session_id`` is not a valid UUID, 404 if
                the session row does not exist.
        """
        session_uuid = self._parse_session_id(session_id)
        # NOTE(review): unlike get_session_owner, this lookup does not exclude
        # soft-deleted sessions - confirm callers check ownership first.
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            .with_for_update()
        )
        session_row = (await self._session.execute(stmt)).scalar_one_or_none()
        if session_row is None:
            raise ApiProblemError(
                status_code=404,
                code="AGENT_SESSION_NOT_FOUND",
                detail="Session not found",
            )

        next_seq = int(session_row.message_count or 0) + 1

        if not _has_title(session_row.title):
            session_title = _derive_session_title(content)
            if session_title is not None:
                session_row.title = session_title

        self._session.add(
            AgentChatMessage(
                id=uuid4(),
                session_id=session_uuid,
                seq=next_seq,
                role=AgentChatMessageRole.USER,
                content=content,
                visibility_mask=max(int(visibility_mask), 0),
                metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
            )
        )
        session_row.message_count = next_seq
        session_row.last_activity_at = datetime.now(timezone.utc)
        await self._session.flush()

    # ------------------------------------------------------------------
    # Messages: reads
    # ------------------------------------------------------------------

    async def get_assistant_message_count(self, *, session_id: str) -> int:
        """Count the non-deleted ASSISTANT messages of a session.

        Raises:
            ApiProblemError: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._parse_session_id(session_id)
        stmt = (
            select(func.count(AgentChatMessage.id))
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.role == AgentChatMessageRole.ASSISTANT)
        )
        return int((await self._session.execute(stmt)).scalar_one())

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
        visibility_mask: int | None = None,
    ) -> dict[str, object] | None:
        """Return one UTC calendar day of messages, newest day first.

        The day is that of the most recent matching message; when ``before``
        is given, only messages created strictly before 00:00 UTC of that
        date are considered.

        Returns:
            ``None`` when no message matches, otherwise a dict with keys
            ``"day"`` (ISO date), ``"hasMore"`` (whether older matching
            messages exist) and ``"messages"`` (JSON dumps ordered by ``seq``).

        Raises:
            ApiProblemError: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._parse_session_id(session_id)

        # Step 1: find the timestamp of the newest eligible message.
        newest_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage.created_at)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
            ),
            visibility_mask=visibility_mask,
        )
        if before is not None:
            before_start = datetime.combine(before, time.min, tzinfo=timezone.utc)
            newest_stmt = newest_stmt.where(AgentChatMessage.created_at < before_start)
        newest_stmt = newest_stmt.order_by(AgentChatMessage.created_at.desc()).limit(1)

        target_created_at = (
            await self._session.execute(newest_stmt)
        ).scalar_one_or_none()
        if target_created_at is None:
            return None

        # Step 2: load every message inside that UTC day, [start, end).
        target_day = target_created_at.astimezone(timezone.utc).date()
        start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
        end = start + timedelta(days=1)
        day_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .where(AgentChatMessage.created_at >= start)
                .where(AgentChatMessage.created_at < end)
                .order_by(AgentChatMessage.seq.asc())
            ),
            visibility_mask=visibility_mask,
        )
        messages = (await self._session.execute(day_stmt)).scalars().all()

        # Step 3: probe for anything older so the caller can keep paging.
        older_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage.id)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .where(AgentChatMessage.created_at < start)
                .limit(1)
            ),
            visibility_mask=visibility_mask,
        )
        has_more = (
            await self._session.execute(older_stmt)
        ).scalar_one_or_none() is not None

        return {
            "day": target_day.isoformat(),
            "hasMore": has_more,
            "messages": [
                (await self._to_chat_message_schema(message)).model_dump(
                    mode="json", by_alias=True, exclude_none=True
                )
                for message in messages
            ],
        }

    async def get_session_messages(
        self,
        *,
        session_id: str,
        visibility_mask: int | None = None,
    ) -> list[AgentChatMessageSchema]:
        """Return all non-deleted messages of a session ordered by ``seq``.

        Raises:
            ApiProblemError: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._parse_session_id(session_id)
        message_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .order_by(AgentChatMessage.seq.asc())
            ),
            visibility_mask=visibility_mask,
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()
        return [await self._to_chat_message_schema(message) for message in messages]

    async def get_recent_messages_by_user_window(
        self,
        *,
        session_id: str,
        user_message_limit: int,
        visibility_mask: int | None = None,
    ) -> list[dict[str, object]]:
        """Return the tail of a session that spans the last N user messages.

        Messages are walked newest-first and collected until
        ``user_message_limit`` USER messages have been seen (values below 1
        are treated as 1); the collected slice is returned in ascending
        ``seq`` order as JSON dumps.

        Raises:
            ApiProblemError: 422 if ``session_id`` is not a valid UUID.
        """
        session_uuid = self._parse_session_id(session_id)
        safe_user_limit = max(int(user_message_limit), 1)
        message_stmt = self._apply_visibility_filter(
            stmt=(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .where(AgentChatMessage.deleted_at.is_(None))
                .order_by(AgentChatMessage.seq.desc())
            ),
            visibility_mask=visibility_mask,
        )
        messages_desc = (await self._session.execute(message_stmt)).scalars().all()
        if not messages_desc:
            return []

        selected_desc: list[AgentChatMessage] = []
        user_count = 0
        for message in messages_desc:
            selected_desc.append(message)
            if self._role_value(message) == AgentChatMessageRole.USER.value:
                user_count += 1
                # Stop on the Nth user message so the window starts with it.
                if user_count >= safe_user_limit:
                    break

        return [
            (await self._to_chat_message_schema(message)).model_dump(
                mode="json", by_alias=True, exclude_none=True
            )
            for message in reversed(selected_desc)
        ]

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the id of the user's most recently active non-deleted session.

        Returns ``None`` when the user has no such session.

        Raises:
            ApiProblemError: 422 if ``user_id`` is not a valid UUID.
        """
        user_uuid = self._parse_user_id(user_id)
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
        return None if latest_id is None else str(latest_id)

    @staticmethod
    def _has_divination_derived(snapshot: AgentChatMessageSchema) -> bool:
        """Tell whether metadata carries a non-empty ``agent_output.divination_derived`` dict."""
        if snapshot.metadata is None:
            return False
        metadata = snapshot.metadata.model_dump(mode="json", exclude_none=True)
        if not isinstance(metadata, dict):
            return False
        agent_output = metadata.get("agent_output")
        if not isinstance(agent_output, dict):
            return False
        derived = agent_output.get("divination_derived")
        return isinstance(derived, dict) and bool(derived)

    async def get_latest_assistant_messages_by_user_sessions(
        self,
        *,
        user_id: str,
        visibility_mask: int | None = None,
        session_limit: int = 50,
    ) -> list[AgentChatMessageSchema]:
        """Return, per recent session, the newest assistant message with derived data.

        For each of the user's ``session_limit`` most recently active
        non-deleted sessions (values below 1 are treated as 1), the 20 newest
        ASSISTANT messages are inspected and the first one whose metadata has
        a non-empty ``agent_output.divination_derived`` dict is kept. Sessions
        without such a message contribute nothing.

        Returns:
            The kept messages sorted by ``str(timestamp)`` descending.

        Raises:
            ApiProblemError: 422 if ``user_id`` is not a valid UUID.
        """
        user_uuid = self._parse_user_id(user_id)
        safe_limit = max(int(session_limit), 1)
        session_stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(safe_limit)
        )
        session_ids = (await self._session.execute(session_stmt)).scalars().all()
        if not session_ids:
            return []

        snapshots: list[AgentChatMessageSchema] = []
        for session_uuid in session_ids:
            message_stmt = self._apply_visibility_filter(
                stmt=(
                    select(AgentChatMessage)
                    .where(AgentChatMessage.session_id == session_uuid)
                    .where(AgentChatMessage.deleted_at.is_(None))
                    .where(AgentChatMessage.role == AgentChatMessageRole.ASSISTANT)
                    .order_by(AgentChatMessage.created_at.desc())
                    .limit(20)
                ),
                visibility_mask=visibility_mask,
            )
            candidate_messages = (
                (await self._session.execute(message_stmt)).scalars().all()
            )
            for message in candidate_messages:
                snapshot = await self._to_chat_message_schema(message)
                if self._has_divination_derived(snapshot):
                    snapshots.append(snapshot)
                    break  # Only the newest qualifying message per session.

        snapshots.sort(
            key=lambda item: str(item.timestamp),
            reverse=True,
        )
        return snapshots

    # ------------------------------------------------------------------
    # Agent configuration
    # ------------------------------------------------------------------

    async def get_system_agent_config(
        self, *, agent_type: str
    ) -> dict[str, object] | None:
        """Look up a ``SystemAgents`` row by its normalised agent type.

        ``agent_type`` is stripped and lower-cased before the lookup.

        Returns:
            ``None`` for a blank type or unknown agent, otherwise a dict with
            ``"agent_type"``, ``"status"`` and ``"config"`` (an empty dict if
            the stored config is not a dict).
        """
        normalized_type = agent_type.strip().lower()
        if not normalized_type:
            return None
        stmt = select(SystemAgents).where(SystemAgents.agent_type == normalized_type)
        row = (await self._session.execute(stmt)).scalar_one_or_none()
        if row is None:
            return None
        return {
            "agent_type": normalized_type,
            "status": str(row.status),
            "config": row.config if isinstance(row.config, dict) else {},
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _to_chat_message_schema(
        self, message: AgentChatMessage
    ) -> AgentChatMessageSchema:
        """Convert an ORM message row into its validated schema object.

        Missing token counts become 0, a missing cost becomes ``"0"``, and
        the timestamp is ``created_at`` rendered as a UTC ISO-8601 string.
        """
        cost = message.cost if message.cost is not None else Decimal("0")
        return AgentChatMessageSchema.model_validate(
            {
                "id": str(message.id),
                "session_id": str(message.session_id),
                "seq": int(message.seq),
                "role": self._role_value(message),
                "content": message.content,
                "model_code": message.model_code,
                "tool_name": message.tool_name,
                "input_tokens": int(message.input_tokens or 0),
                "output_tokens": int(message.output_tokens or 0),
                "cost": str(cost),
                "latency_ms": message.latency_ms,
                "metadata": message.metadata_json,
                "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
            }
        )

    def _apply_visibility_filter(
        self,
        *,
        stmt: Select[Any],
        visibility_mask: int | None,
    ) -> Select[Any]:
        """Restrict ``stmt`` to messages sharing at least one bit with the mask.

        ``None``, zero and negative masks leave the statement unchanged.
        """
        if visibility_mask is None:
            return stmt
        required_mask = max(int(visibility_mask), 0)
        if required_mask == 0:
            return stmt
        return stmt.where(
            (AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
        )
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
def _derive_session_title(content_text: str) -> str | None:
normalized = " ".join(content_text.split())
if not normalized:
return None
return normalized[:80]