feat: 实现起卦、设置与积分系统

This commit is contained in:
qzl
2026-04-03 16:56:47 +08:00
parent 31594558eb
commit f245eec5f6
170 changed files with 20728 additions and 328 deletions
+405
View File
@@ -0,0 +1,405 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from typing import Protocol
from uuid import UUID, uuid4
from sqlalchemy import Select, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from core.http.errors import ApiProblemError
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.system_agents import SystemAgents
from schemas.enums import AgentChatMessageRole
from schemas.domain.chat_message import (
AgentChatMessage as AgentChatMessageSchema,
AgentChatMessageMetadata,
)
class ToolResultPayloadStorage(Protocol):
async def read_json(
self, *, bucket: str, path: str
) -> dict[str, object] | None: ...
class AgentRepository:
def __init__(
self,
session: AsyncSession,
*,
tool_result_storage: ToolResultPayloadStorage | None = None,
) -> None:
self._session: AsyncSession = session
self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage
async def get_session_owner(self, *, session_id: str) -> str:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_SESSION_ID_INVALID",
detail="Invalid session_id",
) from exc
stmt = select(AgentChatSession.user_id).where(
AgentChatSession.id == session_uuid
)
owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
if owner_id is None:
raise ApiProblemError(
status_code=404,
code="AGENT_SESSION_NOT_FOUND",
detail="Session not found",
)
return str(owner_id)
async def create_session_for_user(
self, *, user_id: str, session_id: str | None = None
) -> str:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_USER_ID_INVALID",
detail="Invalid user_id",
) from exc
session_uuid = None
if session_id is not None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_SESSION_ID_INVALID",
detail="Invalid session_id",
) from exc
session = AgentChatSession(
id=session_uuid,
user_id=user_uuid,
)
self._session.add(session)
await self._session.flush()
await self._session.refresh(session)
return str(session.id)
async def commit(self) -> None:
await self._session.commit()
async def rollback(self) -> None:
await self._session.rollback()
async def delete_session(self, *, session_id: str) -> None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_SESSION_ID_INVALID",
detail="Invalid session_id",
) from exc
session = await self._session.get(AgentChatSession, session_uuid)
if session is not None:
await self._session.delete(session)
await self._session.flush()
async def persist_user_message(
self,
*,
session_id: str,
content: str,
metadata: AgentChatMessageMetadata | None,
visibility_mask: int,
) -> None:
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_SESSION_ID_INVALID",
detail="Invalid session_id",
) from exc
stmt = (
select(AgentChatSession)
.where(AgentChatSession.id == session_uuid)
.with_for_update()
)
session_row = (await self._session.execute(stmt)).scalar_one_or_none()
if session_row is None:
raise ApiProblemError(
status_code=404,
code="AGENT_SESSION_NOT_FOUND",
detail="Session not found",
)
next_seq = int(session_row.message_count or 0) + 1
if not _has_title(session_row.title):
session_title = _derive_session_title(content)
if session_title is not None:
session_row.title = session_title
message = OrmAgentChatMessage(
id=uuid4(),
session_id=session_uuid,
seq=next_seq,
role=AgentChatMessageRole.USER,
content=content,
visibility_mask=max(int(visibility_mask), 0),
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
)
self._session.add(message)
session_row.message_count = next_seq
session_row.last_activity_at = datetime.now(timezone.utc)
await self._session.flush()
async def get_user_message_count(self, *, session_id: str) -> int:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_SESSION_ID_INVALID",
detail="Invalid session_id",
) from exc
stmt = (
select(func.count(AgentChatMessage.id))
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.role == AgentChatMessageRole.USER)
)
count = (await self._session.execute(stmt)).scalar_one()
return int(count)
async def get_history_day(
self,
*,
session_id: str,
before: date | None,
visibility_mask: int | None = None,
) -> dict[str, object] | None:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_SESSION_ID_INVALID",
detail="Invalid session_id",
) from exc
before_start = (
datetime.combine(before, time.min, tzinfo=timezone.utc)
if before is not None
else None
)
target_created_at_stmt = (
select(AgentChatMessage.created_at)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.order_by(AgentChatMessage.created_at.desc())
.limit(1)
)
target_created_at_stmt = self._apply_visibility_filter(
stmt=target_created_at_stmt,
visibility_mask=visibility_mask,
)
if before_start is not None:
target_created_at_stmt = target_created_at_stmt.where(
AgentChatMessage.created_at < before_start
)
target_created_at = (
await self._session.execute(target_created_at_stmt)
).scalar_one_or_none()
if target_created_at is None:
return None
target_day = target_created_at.astimezone(timezone.utc).date()
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
end = start + timedelta(days=1)
message_stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.created_at >= start)
.where(AgentChatMessage.created_at < end)
.order_by(AgentChatMessage.seq.asc())
)
message_stmt = self._apply_visibility_filter(
stmt=message_stmt,
visibility_mask=visibility_mask,
)
messages = (await self._session.execute(message_stmt)).scalars().all()
has_more_stmt = (
select(AgentChatMessage.id)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.created_at < start)
.limit(1)
)
has_more_stmt = self._apply_visibility_filter(
stmt=has_more_stmt,
visibility_mask=visibility_mask,
)
has_more = (
await self._session.execute(has_more_stmt)
).scalar_one_or_none() is not None
snapshot_messages: list[dict[str, object]] = []
for message in messages:
snapshot_messages.append(await self._to_snapshot_message(message))
return {
"day": target_day.isoformat(),
"hasMore": has_more,
"messages": snapshot_messages,
}
async def get_recent_messages_by_user_window(
self,
*,
session_id: str,
user_message_limit: int,
visibility_mask: int | None = None,
) -> list[dict[str, object]]:
try:
session_uuid = UUID(session_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_SESSION_ID_INVALID",
detail="Invalid session_id",
) from exc
safe_user_limit = max(int(user_message_limit), 1)
message_stmt = (
select(AgentChatMessage)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.order_by(AgentChatMessage.seq.desc())
)
message_stmt = self._apply_visibility_filter(
stmt=message_stmt,
visibility_mask=visibility_mask,
)
messages_desc = (await self._session.execute(message_stmt)).scalars().all()
if not messages_desc:
return []
selected_desc: list[AgentChatMessage] = []
user_count = 0
for message in messages_desc:
selected_desc.append(message)
role = (
message.role.value
if isinstance(message.role, AgentChatMessageRole)
else str(message.role)
)
if role == AgentChatMessageRole.USER.value:
user_count += 1
if user_count >= safe_user_limit:
break
selected = list(reversed(selected_desc))
snapshot_messages: list[dict[str, object]] = []
for message in selected:
snapshot_messages.append(await self._to_snapshot_message(message))
return snapshot_messages
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
try:
user_uuid = UUID(user_id)
except ValueError as exc:
raise ApiProblemError(
status_code=422,
code="AGENT_USER_ID_INVALID",
detail="Invalid user_id",
) from exc
stmt = (
select(AgentChatSession.id)
.where(AgentChatSession.user_id == user_uuid)
.where(AgentChatSession.deleted_at.is_(None))
.order_by(AgentChatSession.last_activity_at.desc())
.limit(1)
)
latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
if latest_id is None:
return None
return str(latest_id)
async def get_system_agent_config(
self, *, agent_type: str
) -> dict[str, object] | None:
normalized_type = agent_type.strip().lower()
if not normalized_type:
return None
stmt = select(SystemAgents).where(SystemAgents.agent_type == normalized_type)
row = (await self._session.execute(stmt)).scalar_one_or_none()
if row is None:
return None
config_payload = row.config if isinstance(row.config, dict) else {}
return {
"agent_type": normalized_type,
"status": str(row.status),
"config": config_payload,
}
async def _to_snapshot_message(
self, message: AgentChatMessage
) -> dict[str, object]:
role = (
message.role.value
if isinstance(message.role, AgentChatMessageRole)
else str(message.role)
)
payload_model = AgentChatMessageSchema.model_validate(
{
"id": str(message.id),
"seq": int(message.seq),
"role": role,
"content": message.content,
"model_code": message.model_code,
"tool_name": message.tool_name,
"input_tokens": int(message.input_tokens or 0),
"output_tokens": int(message.output_tokens or 0),
"cost": str(message.cost if message.cost is not None else Decimal("0")),
"latency_ms": message.latency_ms,
"metadata": message.metadata_json,
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
}
)
return payload_model.model_dump(mode="json", exclude_none=True)
def _apply_visibility_filter(
self,
*,
stmt: Select,
visibility_mask: int | None,
) -> Select:
if visibility_mask is None:
return stmt
required_mask = max(int(visibility_mask), 0)
if required_mask == 0:
return stmt
return stmt.where(
(AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
)
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
def _derive_session_title(content_text: str) -> str | None:
normalized = " ".join(content_text.split())
if not normalized:
return None
return normalized[:80]