1e22f27de2
- Add invite code display and binding functionality via API - Fix notification unread count sync on auth state change - Improve notification mark read with server state validation - Add auth state listener to trigger notification refresh - Add YaoCoinConverter for coin-to-yao type conversion - Remove YaoLegend from divination screens (UI cleanup) - Abbreviate relation labels in yao detail view - Add re-register notice to account delete screen - Update 'coins' terminology to 'points' in localization - Fix backend points consumption to only run in CHAT mode - Add HttpxAuthNoiseFilter to suppress auth endpoint logging - Fix notification static_schema import path - Add test coverage for notification bloc error handling - Update AGENTS.md page header rules and image handling - Delete deprecated run-dev.sh script
529 lines
18 KiB
Python
529 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import date, datetime, time, timedelta, timezone
|
|
from decimal import Decimal
|
|
from typing import Any, Protocol
|
|
from uuid import UUID, uuid4
|
|
|
|
from sqlalchemy import Select, func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from core.http.errors import ApiProblemError
|
|
from models.agent_chat_message import AgentChatMessage
|
|
from models.agent_chat_session import AgentChatSession
|
|
from models.system_agents import SystemAgents
|
|
from schemas.enums import AgentChatMessageRole
|
|
from schemas.domain.chat_message import (
|
|
AgentChatMessage as AgentChatMessageSchema,
|
|
AgentChatMessageMetadata,
|
|
)
|
|
|
|
|
|
class ToolResultPayloadStorage(Protocol):
    """Structural type for an object that can asynchronously read JSON payloads.

    Any object exposing a matching ``read_json`` coroutine satisfies this
    protocol; no inheritance is required.
    """

    async def read_json(
        self,
        *,
        bucket: str,
        path: str,
    ) -> dict[str, object] | None:
        """Read the JSON document addressed by ``bucket`` and ``path``.

        Returns a dict, or ``None``.
        NOTE(review): presumably ``None`` means "no such object" -- confirm
        against the concrete storage implementations.
        """
        ...
|
|
|
|
|
|
class AgentRepository:
|
|
def __init__(
|
|
self,
|
|
session: AsyncSession,
|
|
*,
|
|
tool_result_storage: ToolResultPayloadStorage | None = None,
|
|
) -> None:
|
|
self._session: AsyncSession = session
|
|
self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage
|
|
|
|
async def get_session_owner(self, *, session_id: str) -> str:
    """Return the owner's user id (as a string) of a non-deleted chat session.

    Raises:
        ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when ``session_id``
            is not a valid UUID string; 404 ``AGENT_SESSION_NOT_FOUND`` when
            no session with that id and ``deleted_at IS NULL`` exists.
    """
    try:
        parsed_id = UUID(session_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        ) from exc

    owner_query = select(AgentChatSession.user_id).where(
        AgentChatSession.id == parsed_id,
        AgentChatSession.deleted_at.is_(None),
    )
    result = await self._session.execute(owner_query)
    owner = result.scalar_one_or_none()
    if owner is not None:
        return str(owner)

    # Missing rows and soft-deleted rows are reported identically.
    raise ApiProblemError(
        status_code=404,
        code="AGENT_SESSION_NOT_FOUND",
        detail="Session not found",
    )
|
|
|
|
async def create_session_for_user(
    self, *, user_id: str, session_id: str | None = None
) -> str:
    """Add a chat session row for ``user_id`` and return its id as a string.

    Args:
        user_id: UUID string of the owning user.
        session_id: Optional UUID string to use as the session id. When
            omitted, ``id=None`` is passed to the model and the id is read
            back after flush/refresh.
            NOTE(review): presumably the model or database generates the id
            in that case -- confirm on ``AgentChatSession``.

    Raises:
        ApiProblemError: 422 ``AGENT_USER_ID_INVALID`` or
            ``AGENT_SESSION_ID_INVALID`` when the corresponding argument is
            not a valid UUID string.
    """
    try:
        owner_uuid = UUID(user_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_USER_ID_INVALID",
            detail="Invalid user_id",
        ) from exc

    if session_id is None:
        requested_uuid = None
    else:
        try:
            requested_uuid = UUID(session_id)
        except ValueError as exc:
            raise ApiProblemError(
                status_code=422,
                code="AGENT_SESSION_ID_INVALID",
                detail="Invalid session_id",
            ) from exc

    new_session = AgentChatSession(id=requested_uuid, user_id=owner_uuid)
    db = self._session
    db.add(new_session)
    # Flush and refresh before reading ``id``; the transaction is not
    # committed here (see ``commit``).
    await db.flush()
    await db.refresh(new_session)
    return str(new_session.id)
|
|
|
|
async def commit(self) -> None:
|
|
await self._session.commit()
|
|
|
|
async def rollback(self) -> None:
|
|
await self._session.rollback()
|
|
|
|
async def delete_session(self, *, session_id: str) -> None:
    """Soft-delete a chat session by stamping ``deleted_at`` with the current UTC time.

    The row is selected with ``FOR UPDATE``. The call is a silent no-op when
    the session does not exist or already has ``deleted_at`` set. The change
    is flushed, not committed.

    Raises:
        ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when ``session_id``
            is not a valid UUID string.
    """
    try:
        target_id = UUID(session_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        ) from exc

    lock_query = (
        select(AgentChatSession)
        .where(AgentChatSession.id == target_id)
        .with_for_update()
    )
    row = (await self._session.execute(lock_query)).scalar_one_or_none()

    # Nothing to do for unknown or already-deleted sessions.
    if row is None or row.deleted_at is not None:
        return

    row.deleted_at = datetime.now(timezone.utc)
    await self._session.flush()
|
|
|
|
async def persist_user_message(
    self,
    *,
    session_id: str,
    content: str,
    metadata: AgentChatMessageMetadata | None,
    visibility_mask: int,
) -> None:
    """Append a USER message to a chat session and update the session counters.

    The session row is selected with ``FOR UPDATE``; the new message gets
    ``seq = message_count + 1``.
    NOTE(review): presumably the row lock is what keeps concurrent writers
    from computing the same ``seq`` -- confirm.

    Side effects on the session row: ``message_count`` is set to the new
    ``seq``, ``last_activity_at`` is set to the current UTC time, and
    ``title`` is derived from ``content`` when the session has no non-blank
    title yet. Changes are flushed, not committed.

    Args:
        session_id: UUID string of the target session.
        content: Message text.
        metadata: Optional metadata model; stored as ``model_dump(by_alias=True)``.
        visibility_mask: Bit mask stored on the message; negative values are
            clamped to 0.

    Raises:
        ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when ``session_id``
            is not a valid UUID string; 404 ``AGENT_SESSION_NOT_FOUND`` when
            the session does not exist or has been soft-deleted.
    """
    try:
        session_uuid = UUID(session_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        ) from exc

    stmt = (
        select(AgentChatSession)
        .where(AgentChatSession.id == session_uuid)
        .with_for_update()
    )
    session_row = (await self._session.execute(stmt)).scalar_one_or_none()

    # A soft-deleted session is treated like a missing one, matching
    # ``get_session_owner``. The check runs after the row lock is taken, so a
    # concurrent ``delete_session`` cannot slip in between check and insert.
    if session_row is None or session_row.deleted_at is not None:
        raise ApiProblemError(
            status_code=404,
            code="AGENT_SESSION_NOT_FOUND",
            detail="Session not found",
        )

    next_seq = int(session_row.message_count or 0) + 1

    # Only the first message with usable text names the session.
    if not _has_title(session_row.title):
        session_title = _derive_session_title(content)
        if session_title is not None:
            session_row.title = session_title

    message = AgentChatMessage(
        id=uuid4(),
        session_id=session_uuid,
        seq=next_seq,
        role=AgentChatMessageRole.USER,
        content=content,
        visibility_mask=max(int(visibility_mask), 0),
        metadata_json=(
            metadata.model_dump(by_alias=True) if metadata is not None else None
        ),
    )
    self._session.add(message)
    session_row.message_count = next_seq
    session_row.last_activity_at = datetime.now(timezone.utc)
    await self._session.flush()
|
|
|
|
async def get_assistant_message_count(self, *, session_id: str) -> int:
    """Count the non-deleted ASSISTANT messages stored for a session.

    Raises:
        ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when ``session_id``
            is not a valid UUID string.
    """
    try:
        chat_uuid = UUID(session_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        ) from exc

    count_query = select(func.count(AgentChatMessage.id)).where(
        AgentChatMessage.session_id == chat_uuid,
        AgentChatMessage.deleted_at.is_(None),
        AgentChatMessage.role == AgentChatMessageRole.ASSISTANT,
    )
    result = await self._session.execute(count_query)
    return int(result.scalar_one())
|
|
|
|
async def get_history_day(
    self,
    *,
    session_id: str,
    before: date | None,
    visibility_mask: int | None = None,
) -> dict[str, object] | None:
    """Return one UTC calendar day of chat history for a session.

    The day is chosen as the UTC date of the newest visible, non-deleted
    message; when ``before`` is given, only messages created strictly before
    midnight UTC of that date are considered.

    Returns:
        ``None`` when no message qualifies, otherwise a dict with
        ``"day"`` (ISO date string), ``"hasMore"`` (whether any visible,
        non-deleted message exists before the start of that day) and
        ``"messages"`` (that day's messages ordered by ``seq`` ascending,
        dumped with ``mode="json"``, aliases, and ``None`` fields excluded).

    Raises:
        ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when ``session_id``
            is not a valid UUID string.
    """
    try:
        chat_uuid = UUID(session_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        ) from exc

    if before is None:
        upper_bound = None
    else:
        upper_bound = datetime.combine(before, time.min, tzinfo=timezone.utc)

    # Step 1: find the timestamp of the newest qualifying message.
    newest_stmt = (
        select(AgentChatMessage.created_at)
        .where(AgentChatMessage.session_id == chat_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .order_by(AgentChatMessage.created_at.desc())
        .limit(1)
    )
    newest_stmt = self._apply_visibility_filter(
        stmt=newest_stmt, visibility_mask=visibility_mask
    )
    if upper_bound is not None:
        newest_stmt = newest_stmt.where(AgentChatMessage.created_at < upper_bound)
    newest_created_at = (
        await self._session.execute(newest_stmt)
    ).scalar_one_or_none()
    if newest_created_at is None:
        return None

    # Step 2: load every message inside that UTC day, [day_start, day_end).
    day = newest_created_at.astimezone(timezone.utc).date()
    day_start = datetime.combine(day, time.min, tzinfo=timezone.utc)
    day_end = day_start + timedelta(days=1)
    day_stmt = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == chat_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .where(AgentChatMessage.created_at >= day_start)
        .where(AgentChatMessage.created_at < day_end)
        .order_by(AgentChatMessage.seq.asc())
    )
    day_stmt = self._apply_visibility_filter(
        stmt=day_stmt, visibility_mask=visibility_mask
    )
    day_rows = (await self._session.execute(day_stmt)).scalars().all()

    # Step 3: probe for anything older so the caller knows whether to page on.
    older_stmt = (
        select(AgentChatMessage.id)
        .where(AgentChatMessage.session_id == chat_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .where(AgentChatMessage.created_at < day_start)
        .limit(1)
    )
    older_stmt = self._apply_visibility_filter(
        stmt=older_stmt, visibility_mask=visibility_mask
    )
    older_hit = (await self._session.execute(older_stmt)).scalar_one_or_none()

    serialized: list[dict[str, object]] = [
        (await self._to_chat_message_schema(row)).model_dump(
            mode="json", by_alias=True, exclude_none=True
        )
        for row in day_rows
    ]
    return {
        "day": day.isoformat(),
        "hasMore": older_hit is not None,
        "messages": serialized,
    }
|
|
|
|
async def get_session_messages(
    self,
    *,
    session_id: str,
    visibility_mask: int | None = None,
) -> list[AgentChatMessageSchema]:
    """Return every visible, non-deleted message of a session, oldest ``seq`` first.

    Raises:
        ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when ``session_id``
            is not a valid UUID string.
    """
    try:
        chat_uuid = UUID(session_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        ) from exc

    base_query = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == chat_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .order_by(AgentChatMessage.seq.asc())
    )
    filtered_query = self._apply_visibility_filter(
        stmt=base_query, visibility_mask=visibility_mask
    )
    rows = (await self._session.execute(filtered_query)).scalars().all()
    return [await self._to_chat_message_schema(row) for row in rows]
|
|
|
|
async def get_recent_messages_by_user_window(
    self,
    *,
    session_id: str,
    user_message_limit: int,
    visibility_mask: int | None = None,
) -> list[dict[str, object]]:
    """Return the tail of a session covering the last N user messages.

    The window starts at the N-th most recent visible, non-deleted USER
    message (N = ``user_message_limit``, clamped to at least 1) and includes
    every visible, non-deleted message from there to the end of the session,
    ordered by ``seq`` ascending. When the session holds fewer than N user
    messages, all of its visible messages are returned.

    The window boundary is resolved in SQL (a one-row lookup of the cutoff
    ``seq``) so that only the rows inside the window are loaded, rather than
    the full session history.

    Returns:
        Messages dumped with ``mode="json"``, aliases, and ``None`` fields
        excluded; an empty list when nothing qualifies.

    Raises:
        ApiProblemError: 422 ``AGENT_SESSION_ID_INVALID`` when ``session_id``
            is not a valid UUID string.
    """
    try:
        session_uuid = UUID(session_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        ) from exc

    safe_user_limit = max(int(user_message_limit), 1)

    # seq of the N-th newest user message; no row means fewer than N exist.
    cutoff_stmt = (
        select(AgentChatMessage.seq)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .where(AgentChatMessage.role == AgentChatMessageRole.USER)
        .order_by(AgentChatMessage.seq.desc())
        .offset(safe_user_limit - 1)
        .limit(1)
    )
    cutoff_stmt = self._apply_visibility_filter(
        stmt=cutoff_stmt,
        visibility_mask=visibility_mask,
    )
    cutoff_seq = (await self._session.execute(cutoff_stmt)).scalar_one_or_none()

    message_stmt = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == session_uuid)
        .where(AgentChatMessage.deleted_at.is_(None))
        .order_by(AgentChatMessage.seq.asc())
    )
    message_stmt = self._apply_visibility_filter(
        stmt=message_stmt,
        visibility_mask=visibility_mask,
    )
    if cutoff_seq is not None:
        message_stmt = message_stmt.where(AgentChatMessage.seq >= cutoff_seq)
    messages = (await self._session.execute(message_stmt)).scalars().all()

    return [
        (await self._to_chat_message_schema(message)).model_dump(
            mode="json", by_alias=True, exclude_none=True
        )
        for message in messages
    ]
|
|
|
|
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
    """Return the id of the user's non-deleted session with the greatest ``last_activity_at``.

    Returns ``None`` when the user has no non-deleted session.
    NOTE(review): ordering is a plain ``DESC`` on ``last_activity_at``; how
    rows with a NULL value sort depends on the database -- confirm the column
    is always populated.

    Raises:
        ApiProblemError: 422 ``AGENT_USER_ID_INVALID`` when ``user_id`` is not
            a valid UUID string.
    """
    try:
        owner_uuid = UUID(user_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_USER_ID_INVALID",
            detail="Invalid user_id",
        ) from exc

    latest_query = (
        select(AgentChatSession.id)
        .where(
            AgentChatSession.user_id == owner_uuid,
            AgentChatSession.deleted_at.is_(None),
        )
        .order_by(AgentChatSession.last_activity_at.desc())
        .limit(1)
    )
    result = await self._session.execute(latest_query)
    newest = result.scalar_one_or_none()
    return None if newest is None else str(newest)
|
|
|
|
async def get_latest_assistant_messages_by_user_sessions(
    self,
    *,
    user_id: str,
    visibility_mask: int | None = None,
    session_limit: int = 50,
    candidate_limit: int = 20,
) -> list[AgentChatMessageSchema]:
    """Return, per recent session, the newest assistant message carrying divination data.

    For each of the user's ``session_limit`` most recently active non-deleted
    sessions, the newest ``candidate_limit`` visible, non-deleted ASSISTANT
    messages are scanned (newest ``created_at`` first). The first one whose
    dumped metadata has a non-empty dict at
    ``agent_output.divination_derived`` is kept. Sessions without such a
    message contribute nothing.

    Args:
        user_id: UUID string of the user.
        visibility_mask: Optional bit mask passed to the visibility filter.
        session_limit: Maximum number of sessions inspected (clamped to >= 1).
        candidate_limit: Maximum number of assistant messages inspected per
            session (clamped to >= 1). The default of 20 matches the
            previously hard-coded value.

    Returns:
        Selected messages sorted by ``str(timestamp)`` descending.

    Raises:
        ApiProblemError: 422 ``AGENT_USER_ID_INVALID`` when ``user_id`` is not
            a valid UUID string.
    """
    try:
        user_uuid = UUID(user_id)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            code="AGENT_USER_ID_INVALID",
            detail="Invalid user_id",
        ) from exc

    safe_limit = max(int(session_limit), 1)
    safe_candidate_limit = max(int(candidate_limit), 1)

    session_stmt = (
        select(AgentChatSession.id)
        .where(AgentChatSession.user_id == user_uuid)
        .where(AgentChatSession.deleted_at.is_(None))
        .order_by(AgentChatSession.last_activity_at.desc())
        .limit(safe_limit)
    )
    session_ids = (await self._session.execute(session_stmt)).scalars().all()
    if not session_ids:
        return []

    snapshots: list[AgentChatMessageSchema] = []
    # NOTE(review): this issues one query per session (up to ``session_limit``).
    for session_id in session_ids:
        message_stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_id)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.role == AgentChatMessageRole.ASSISTANT)
            .order_by(AgentChatMessage.created_at.desc())
            .limit(safe_candidate_limit)
        )
        message_stmt = self._apply_visibility_filter(
            stmt=message_stmt,
            visibility_mask=visibility_mask,
        )
        candidate_messages = (
            (await self._session.execute(message_stmt)).scalars().all()
        )

        for message in candidate_messages:
            snapshot = await self._to_chat_message_schema(message)
            metadata = (
                snapshot.metadata.model_dump(mode="json", exclude_none=True)
                if snapshot.metadata is not None
                else None
            )
            if not isinstance(metadata, dict):
                continue
            agent_output = metadata.get("agent_output")
            if not isinstance(agent_output, dict):
                continue
            derived = agent_output.get("divination_derived")
            if isinstance(derived, dict) and derived:
                # Candidates are newest-first, so the first hit is the latest.
                snapshots.append(snapshot)
                break

    snapshots.sort(
        key=lambda item: str(item.timestamp),
        reverse=True,
    )
    return snapshots
|
|
|
|
async def get_system_agent_config(
    self, *, agent_type: str
) -> dict[str, object] | None:
    """Look up the ``SystemAgents`` row for an agent type.

    ``agent_type`` is stripped and lower-cased before the lookup.

    Returns:
        ``None`` when the normalized type is empty or no row matches;
        otherwise a dict with ``"agent_type"`` (the normalized value),
        ``"status"`` (``str(row.status)``) and ``"config"`` (the row's
        ``config`` if it is a dict, else an empty dict).
    """
    type_key = agent_type.strip().lower()
    if type_key == "":
        return None

    lookup = select(SystemAgents).where(SystemAgents.agent_type == type_key)
    agent_row = (await self._session.execute(lookup)).scalar_one_or_none()
    if agent_row is None:
        return None

    raw_config = agent_row.config
    return {
        "agent_type": type_key,
        "status": str(agent_row.status),
        # Anything other than a dict is replaced by an empty config.
        "config": raw_config if isinstance(raw_config, dict) else {},
    }
|
|
|
|
async def _to_chat_message_schema(
    self, message: AgentChatMessage
) -> AgentChatMessageSchema:
    """Convert an ORM chat message into its validated schema object.

    Ids are stringified, missing token counts become 0, a missing cost
    becomes ``"0"``, and ``created_at`` is rendered as a UTC ISO-8601 string.
    The coroutine awaits nothing itself; it is ``async`` so callers can
    ``await`` it.
    """
    raw_role = message.role
    if isinstance(raw_role, AgentChatMessageRole):
        role_text = raw_role.value
    else:
        role_text = str(raw_role)

    fields = {
        "id": str(message.id),
        "session_id": str(message.session_id),
        "seq": int(message.seq),
        "role": role_text,
        "content": message.content,
        "model_code": message.model_code,
        "tool_name": message.tool_name,
        "input_tokens": int(message.input_tokens or 0),
        "output_tokens": int(message.output_tokens or 0),
        "cost": str(Decimal("0") if message.cost is None else message.cost),
        "latency_ms": message.latency_ms,
        "metadata": message.metadata_json,
        "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
    }
    return AgentChatMessageSchema.model_validate(fields)
|
|
|
|
def _apply_visibility_filter(
|
|
self,
|
|
*,
|
|
stmt: Select[Any],
|
|
visibility_mask: int | None,
|
|
) -> Select[Any]:
|
|
if visibility_mask is None:
|
|
return stmt
|
|
required_mask = max(int(visibility_mask), 0)
|
|
if required_mask == 0:
|
|
return stmt
|
|
return stmt.where(
|
|
(AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
|
|
)
|
|
|
|
|
|
def _has_title(title: object) -> bool:
|
|
return isinstance(title, str) and bool(title.strip())
|
|
|
|
|
|
def _derive_session_title(content_text: str) -> str | None:
|
|
normalized = " ".join(content_text.split())
|
|
if not normalized:
|
|
return None
|
|
return normalized[:80]
|