feat: 实现起卦、设置与积分系统
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import dashscope
|
||||
from dashscope.audio.asr import Recognition, RecognitionCallback
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsrService:
|
||||
def __init__(self) -> None:
|
||||
self._api_key: str | None = None
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self._api_key is None:
|
||||
dashscope_key = config.llm.provider_keys.get("dashscope")
|
||||
if not dashscope_key:
|
||||
raise ValueError(
|
||||
"DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
|
||||
)
|
||||
self._api_key = dashscope_key
|
||||
return self._api_key
|
||||
|
||||
async def transcribe_file(self, file_path: str, filename: str) -> str:
|
||||
try:
|
||||
dashscope.api_key = self._get_api_key()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
class SyncCallback(RecognitionCallback):
|
||||
error: str | None = None
|
||||
|
||||
def on_error(self, result: Any) -> None:
|
||||
self.error = str(result)
|
||||
|
||||
callback = SyncCallback()
|
||||
recognizer = Recognition(
|
||||
model="fun-asr-realtime-2026-02-28",
|
||||
callback=callback,
|
||||
format="wav",
|
||||
sample_rate=16000,
|
||||
)
|
||||
|
||||
result: Any = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: recognizer.call(file=file_path),
|
||||
)
|
||||
|
||||
if callback.error:
|
||||
raise RuntimeError(f"ASR error: {callback.error}")
|
||||
status_code = self._extract_field(result, "status_code")
|
||||
if status_code != 200:
|
||||
message = self._extract_field(result, "message")
|
||||
raise RuntimeError(f"ASR transcription failed: {message}")
|
||||
|
||||
sentence = self._extract_sentence_payload(result)
|
||||
if sentence is None:
|
||||
request_id = self._extract_field(result, "request_id")
|
||||
logger.warning(
|
||||
"ASR returned empty result", extra={"request_id": request_id}
|
||||
)
|
||||
return ""
|
||||
|
||||
if isinstance(sentence, dict):
|
||||
transcription = sentence.get("text", "")
|
||||
elif isinstance(sentence, list):
|
||||
transcription = " ".join(
|
||||
item.get("text", "") for item in sentence if isinstance(item, dict)
|
||||
)
|
||||
else:
|
||||
transcription = str(sentence) if sentence else ""
|
||||
|
||||
logger.info(
|
||||
"ASR transcription completed",
|
||||
extra={"filename": filename, "transcript_length": len(transcription)},
|
||||
)
|
||||
return transcription
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("ASR transcription error")
|
||||
raise RuntimeError(f"ASR transcription failed: {exc}") from exc
|
||||
|
||||
def _extract_sentence_payload(self, result: Any) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
output = result.get("output")
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
if output is not None:
|
||||
return getattr(output, "sentence", None)
|
||||
return result.get("sentence")
|
||||
|
||||
get_sentence = getattr(result, "get_sentence", None)
|
||||
if callable(get_sentence):
|
||||
sentence = get_sentence()
|
||||
if sentence is not None:
|
||||
return sentence
|
||||
|
||||
output = getattr(result, "output", None)
|
||||
if output is None:
|
||||
return None
|
||||
if isinstance(output, dict):
|
||||
return output.get("sentence")
|
||||
return getattr(output, "sentence", None)
|
||||
|
||||
def _extract_field(self, result: Any, field: str) -> Any | None:
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return getattr(result, field, None)
|
||||
|
||||
|
||||
asr_service = AsrService()
|
||||
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agentscope.events import RedisStreamBus
|
||||
from core.agentscope.tools.tool_result_storage import (
|
||||
create_tool_result_storage,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.agent.repository import AgentRepository
|
||||
from v1.agent.service import AgentService
|
||||
from v1.points.repository import PointsRepository
|
||||
from v1.points.service import PointsService
|
||||
|
||||
DEDUP_WAIT_RETRIES = 20
|
||||
DEDUP_WAIT_SECONDS = 0.05
|
||||
DEDUP_LOCK_SECONDS = 300
|
||||
DEDUP_INFLIGHT_MARKER = "__inflight__"
|
||||
RUN_CANCEL_SIGNAL_TTL_SECONDS = 1800
|
||||
|
||||
|
||||
def _event_stream_block_ms() -> int:
|
||||
configured = int(config.agent_runtime.redis_stream_block_ms)
|
||||
socket_timeout = float(config.redis.socket_timeout)
|
||||
socket_timeout_ms = max(int(socket_timeout * 1000), 1)
|
||||
safe_max = max(socket_timeout_ms - 100, 1)
|
||||
return max(1, min(configured, safe_max))
|
||||
|
||||
|
||||
class TaskiqQueueClient:
|
||||
def __init__(self) -> None:
|
||||
self._redis: Redis | None = None
|
||||
|
||||
async def _get_redis(self) -> Redis:
|
||||
if self._redis is None:
|
||||
self._redis = await get_or_init_redis_client()
|
||||
return self._redis
|
||||
|
||||
@staticmethod
|
||||
def _select_queue_task(command: dict[str, object]) -> Any:
|
||||
from core.agentscope.runtime.tasks import (
|
||||
run_command_task_agent,
|
||||
run_command_task_general,
|
||||
)
|
||||
|
||||
queue = str(command.get("queue", "agent")).strip().lower()
|
||||
if queue == "general":
|
||||
return run_command_task_general
|
||||
return run_command_task_agent
|
||||
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
redis_client = await self._get_redis()
|
||||
redis_key = None
|
||||
if dedup_key:
|
||||
redis_key = f"agent:dedup:{dedup_key}"
|
||||
locked = await redis_client.set(
|
||||
redis_key,
|
||||
DEDUP_INFLIGHT_MARKER,
|
||||
nx=True,
|
||||
ex=DEDUP_LOCK_SECONDS,
|
||||
)
|
||||
if not locked:
|
||||
for _ in range(DEDUP_WAIT_RETRIES):
|
||||
existing = await redis_client.get(redis_key)
|
||||
if existing and existing != DEDUP_INFLIGHT_MARKER:
|
||||
return existing
|
||||
await asyncio.sleep(DEDUP_WAIT_SECONDS)
|
||||
raise RuntimeError("duplicate request is still in progress")
|
||||
|
||||
payload = dict(command)
|
||||
queue_task = self._select_queue_task(payload)
|
||||
try:
|
||||
result = await queue_task.kiq(payload)
|
||||
task_id = str(result.task_id)
|
||||
if redis_key is not None:
|
||||
await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
|
||||
return task_id
|
||||
except Exception:
|
||||
if redis_key is not None:
|
||||
await redis_client.delete(redis_key)
|
||||
raise
|
||||
|
||||
async def request_cancel(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
requested_by: str,
|
||||
) -> None:
|
||||
redis_client = await self._get_redis()
|
||||
cancel_key = f"agent:cancel:{thread_id}:{run_id}"
|
||||
payload = json.dumps(
|
||||
{
|
||||
"requested_by": requested_by,
|
||||
"requested_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
await redis_client.set(
|
||||
cancel_key,
|
||||
payload,
|
||||
ex=RUN_CANCEL_SIGNAL_TTL_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
self._bus: RedisStreamBus | None = None
|
||||
|
||||
async def _get_bus(self) -> RedisStreamBus:
|
||||
if self._bus is None:
|
||||
client = await get_or_init_redis_client()
|
||||
self._bus = RedisStreamBus(
|
||||
client=client,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=_event_stream_block_ms(),
|
||||
)
|
||||
return self._bus
|
||||
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
bus = await self._get_bus()
|
||||
rows = await bus.read(session_id=session_id, last_event_id=last_event_id)
|
||||
return [{**row, "cursor": row.get("id")} for row in rows]
|
||||
|
||||
|
||||
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
|
||||
tool_result_storage = create_tool_result_storage()
|
||||
return AgentService(
|
||||
repository=AgentRepository(session, tool_result_storage=tool_result_storage),
|
||||
queue=TaskiqQueueClient(),
|
||||
stream=RedisEventStream(),
|
||||
points_service=PointsService(repository=PointsRepository(session)),
|
||||
attachment_storage=supabase_service,
|
||||
)
|
||||
@@ -0,0 +1,405 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Protocol
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import Select, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.http.errors import ApiProblemError
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from models.system_agents import SystemAgents
|
||||
from schemas.enums import AgentChatMessageRole
|
||||
from schemas.domain.chat_message import (
|
||||
AgentChatMessage as AgentChatMessageSchema,
|
||||
AgentChatMessageMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ToolResultPayloadStorage(Protocol):
|
||||
async def read_json(
|
||||
self, *, bucket: str, path: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class AgentRepository:
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
tool_result_storage: ToolResultPayloadStorage | None = None,
|
||||
) -> None:
|
||||
self._session: AsyncSession = session
|
||||
self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_SESSION_ID_INVALID",
|
||||
detail="Invalid session_id",
|
||||
) from exc
|
||||
|
||||
stmt = select(AgentChatSession.user_id).where(
|
||||
AgentChatSession.id == session_uuid
|
||||
)
|
||||
owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
code="AGENT_SESSION_NOT_FOUND",
|
||||
detail="Session not found",
|
||||
)
|
||||
return str(owner_id)
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_USER_ID_INVALID",
|
||||
detail="Invalid user_id",
|
||||
) from exc
|
||||
session_uuid = None
|
||||
if session_id is not None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_SESSION_ID_INVALID",
|
||||
detail="Invalid session_id",
|
||||
) from exc
|
||||
|
||||
session = AgentChatSession(
|
||||
id=session_uuid,
|
||||
user_id=user_uuid,
|
||||
)
|
||||
self._session.add(session)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(session)
|
||||
return str(session.id)
|
||||
|
||||
async def commit(self) -> None:
|
||||
await self._session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
await self._session.rollback()
|
||||
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_SESSION_ID_INVALID",
|
||||
detail="Invalid session_id",
|
||||
) from exc
|
||||
session = await self._session.get(AgentChatSession, session_uuid)
|
||||
if session is not None:
|
||||
await self._session.delete(session)
|
||||
await self._session.flush()
|
||||
|
||||
async def persist_user_message(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
visibility_mask: int,
|
||||
) -> None:
|
||||
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
|
||||
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_SESSION_ID_INVALID",
|
||||
detail="Invalid session_id",
|
||||
) from exc
|
||||
|
||||
stmt = (
|
||||
select(AgentChatSession)
|
||||
.where(AgentChatSession.id == session_uuid)
|
||||
.with_for_update()
|
||||
)
|
||||
session_row = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if session_row is None:
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
code="AGENT_SESSION_NOT_FOUND",
|
||||
detail="Session not found",
|
||||
)
|
||||
|
||||
next_seq = int(session_row.message_count or 0) + 1
|
||||
if not _has_title(session_row.title):
|
||||
session_title = _derive_session_title(content)
|
||||
if session_title is not None:
|
||||
session_row.title = session_title
|
||||
|
||||
message = OrmAgentChatMessage(
|
||||
id=uuid4(),
|
||||
session_id=session_uuid,
|
||||
seq=next_seq,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content=content,
|
||||
visibility_mask=max(int(visibility_mask), 0),
|
||||
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
|
||||
)
|
||||
self._session.add(message)
|
||||
session_row.message_count = next_seq
|
||||
session_row.last_activity_at = datetime.now(timezone.utc)
|
||||
await self._session.flush()
|
||||
|
||||
async def get_user_message_count(self, *, session_id: str) -> int:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_SESSION_ID_INVALID",
|
||||
detail="Invalid session_id",
|
||||
) from exc
|
||||
|
||||
stmt = (
|
||||
select(func.count(AgentChatMessage.id))
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.role == AgentChatMessageRole.USER)
|
||||
)
|
||||
count = (await self._session.execute(stmt)).scalar_one()
|
||||
return int(count)
|
||||
|
||||
async def get_history_day(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_SESSION_ID_INVALID",
|
||||
detail="Invalid session_id",
|
||||
) from exc
|
||||
|
||||
before_start = (
|
||||
datetime.combine(before, time.min, tzinfo=timezone.utc)
|
||||
if before is not None
|
||||
else None
|
||||
)
|
||||
|
||||
target_created_at_stmt = (
|
||||
select(AgentChatMessage.created_at)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.order_by(AgentChatMessage.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
target_created_at_stmt = self._apply_visibility_filter(
|
||||
stmt=target_created_at_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
if before_start is not None:
|
||||
target_created_at_stmt = target_created_at_stmt.where(
|
||||
AgentChatMessage.created_at < before_start
|
||||
)
|
||||
target_created_at = (
|
||||
await self._session.execute(target_created_at_stmt)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if target_created_at is None:
|
||||
return None
|
||||
|
||||
target_day = target_created_at.astimezone(timezone.utc).date()
|
||||
|
||||
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
message_stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.created_at >= start)
|
||||
.where(AgentChatMessage.created_at < end)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
message_stmt = self._apply_visibility_filter(
|
||||
stmt=message_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
messages = (await self._session.execute(message_stmt)).scalars().all()
|
||||
has_more_stmt = (
|
||||
select(AgentChatMessage.id)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.where(AgentChatMessage.created_at < start)
|
||||
.limit(1)
|
||||
)
|
||||
has_more_stmt = self._apply_visibility_filter(
|
||||
stmt=has_more_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
has_more = (
|
||||
await self._session.execute(has_more_stmt)
|
||||
).scalar_one_or_none() is not None
|
||||
snapshot_messages: list[dict[str, object]] = []
|
||||
for message in messages:
|
||||
snapshot_messages.append(await self._to_snapshot_message(message))
|
||||
return {
|
||||
"day": target_day.isoformat(),
|
||||
"hasMore": has_more,
|
||||
"messages": snapshot_messages,
|
||||
}
|
||||
|
||||
async def get_recent_messages_by_user_window(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
user_message_limit: int,
|
||||
visibility_mask: int | None = None,
|
||||
) -> list[dict[str, object]]:
|
||||
try:
|
||||
session_uuid = UUID(session_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_SESSION_ID_INVALID",
|
||||
detail="Invalid session_id",
|
||||
) from exc
|
||||
|
||||
safe_user_limit = max(int(user_message_limit), 1)
|
||||
message_stmt = (
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.where(AgentChatMessage.deleted_at.is_(None))
|
||||
.order_by(AgentChatMessage.seq.desc())
|
||||
)
|
||||
message_stmt = self._apply_visibility_filter(
|
||||
stmt=message_stmt,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
messages_desc = (await self._session.execute(message_stmt)).scalars().all()
|
||||
if not messages_desc:
|
||||
return []
|
||||
|
||||
selected_desc: list[AgentChatMessage] = []
|
||||
user_count = 0
|
||||
for message in messages_desc:
|
||||
selected_desc.append(message)
|
||||
role = (
|
||||
message.role.value
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
else str(message.role)
|
||||
)
|
||||
if role == AgentChatMessageRole.USER.value:
|
||||
user_count += 1
|
||||
if user_count >= safe_user_limit:
|
||||
break
|
||||
|
||||
selected = list(reversed(selected_desc))
|
||||
snapshot_messages: list[dict[str, object]] = []
|
||||
for message in selected:
|
||||
snapshot_messages.append(await self._to_snapshot_message(message))
|
||||
return snapshot_messages
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="AGENT_USER_ID_INVALID",
|
||||
detail="Invalid user_id",
|
||||
) from exc
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == user_uuid)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if latest_id is None:
|
||||
return None
|
||||
return str(latest_id)
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None:
|
||||
normalized_type = agent_type.strip().lower()
|
||||
if not normalized_type:
|
||||
return None
|
||||
stmt = select(SystemAgents).where(SystemAgents.agent_type == normalized_type)
|
||||
row = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
return None
|
||||
config_payload = row.config if isinstance(row.config, dict) else {}
|
||||
return {
|
||||
"agent_type": normalized_type,
|
||||
"status": str(row.status),
|
||||
"config": config_payload,
|
||||
}
|
||||
|
||||
async def _to_snapshot_message(
|
||||
self, message: AgentChatMessage
|
||||
) -> dict[str, object]:
|
||||
role = (
|
||||
message.role.value
|
||||
if isinstance(message.role, AgentChatMessageRole)
|
||||
else str(message.role)
|
||||
)
|
||||
payload_model = AgentChatMessageSchema.model_validate(
|
||||
{
|
||||
"id": str(message.id),
|
||||
"seq": int(message.seq),
|
||||
"role": role,
|
||||
"content": message.content,
|
||||
"model_code": message.model_code,
|
||||
"tool_name": message.tool_name,
|
||||
"input_tokens": int(message.input_tokens or 0),
|
||||
"output_tokens": int(message.output_tokens or 0),
|
||||
"cost": str(message.cost if message.cost is not None else Decimal("0")),
|
||||
"latency_ms": message.latency_ms,
|
||||
"metadata": message.metadata_json,
|
||||
"timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
return payload_model.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
def _apply_visibility_filter(
|
||||
self,
|
||||
*,
|
||||
stmt: Select,
|
||||
visibility_mask: int | None,
|
||||
) -> Select:
|
||||
if visibility_mask is None:
|
||||
return stmt
|
||||
required_mask = max(int(visibility_mask), 0)
|
||||
if required_mask == 0:
|
||||
return stmt
|
||||
return stmt.where(
|
||||
(AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
|
||||
)
|
||||
|
||||
|
||||
def _has_title(title: object) -> bool:
|
||||
return isinstance(title, str) and bool(title.strip())
|
||||
|
||||
|
||||
def _derive_session_title(content_text: str) -> str | None:
|
||||
normalized = " ".join(content_text.split())
|
||||
if not normalized:
|
||||
return None
|
||||
return normalized[:80]
|
||||
@@ -0,0 +1,473 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import date
|
||||
from typing import Annotated
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from core.agentscope.events import to_sse_event
|
||||
from core.agentscope.schemas.agui_input import (
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.logging import get_logger
|
||||
from redis.exceptions import TimeoutError as RedisTimeoutError
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
Query,
|
||||
Request,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
AsrTranscribeResponse,
|
||||
AttachmentReference,
|
||||
AttachmentSignedUrlResponse,
|
||||
AttachmentUploadResponse,
|
||||
CancelRunResponse,
|
||||
HistorySnapshotResponse,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
from v1.agent.asr import asr_service
|
||||
from v1.agent.service import AgentService
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
logger = get_logger("v1.agent.router")
|
||||
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||
_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,128}$")
|
||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}
|
||||
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
|
||||
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
|
||||
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
|
||||
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
|
||||
_WAV_HEADER_MIN_BYTES = 12
|
||||
_ALLOWED_AUDIO_CONTENT_TYPES = {
|
||||
"audio/wav",
|
||||
"audio/x-wav",
|
||||
"audio/wave",
|
||||
}
|
||||
|
||||
|
||||
def _looks_like_wav_header(header: bytes) -> bool:
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
return False
|
||||
return header[0:4] == b"RIFF" and header[8:12] == b"WAVE"
|
||||
|
||||
|
||||
async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
try:
|
||||
redis = await get_or_init_redis_client()
|
||||
key = f"agent:sse-active:{user_id}"
|
||||
count = await redis.incr(key)
|
||||
if count == 1:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
elif count > _MAX_SSE_CONNECTIONS_PER_USER:
|
||||
await redis.decr(key)
|
||||
return False
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"SSE slot acquire failed",
|
||||
user_id=user_id,
|
||||
reason=str(exc),
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def _release_sse_slot(*, user_id: str) -> None:
|
||||
try:
|
||||
redis = await get_or_init_redis_client()
|
||||
key = f"agent:sse-active:{user_id}"
|
||||
count = await redis.decr(key)
|
||||
if count <= 0:
|
||||
await redis.delete(key)
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"SSE slot release failed",
|
||||
user_id=user_id,
|
||||
reason=str(exc),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _is_terminal_run_event(event: dict[str, object]) -> bool:
|
||||
raw_event_type = event.get("type")
|
||||
return (
|
||||
isinstance(raw_event_type, str) and raw_event_type in _TERMINAL_RUN_EVENT_TYPES
|
||||
)
|
||||
|
||||
|
||||
def _is_target_run_event(event: dict[str, object], *, target_run_id: str) -> bool:
|
||||
run_id = event.get("runId")
|
||||
return isinstance(run_id, str) and run_id == target_run_id
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs", response_model=TaskAcceptedResponse, status_code=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
async def enqueue_run(
|
||||
request: RunAgentInput,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> TaskAcceptedResponse:
|
||||
try:
|
||||
request = parse_run_input(request.model_dump(by_alias=True, exclude_none=True))
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(code="AGENT_RUN_INPUT_INVALID", detail=str(exc)),
|
||||
) from exc
|
||||
try:
|
||||
validate_run_request_messages_contract(request)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(code="AGENT_RUN_MESSAGES_INVALID", detail=str(exc)),
|
||||
) from exc
|
||||
task = await service.enqueue_run(
|
||||
run_input=request,
|
||||
current_user=current_user,
|
||||
)
|
||||
return TaskAcceptedResponse(
|
||||
taskId=task.task_id,
|
||||
threadId=task.thread_id,
|
||||
runId=task.run_id,
|
||||
created=task.created,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/runs/{thread_id}/cancel",
|
||||
response_model=CancelRunResponse,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
async def cancel_run(
|
||||
thread_id: str,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
run_id: str = Query(
|
||||
alias="runId",
|
||||
min_length=1,
|
||||
max_length=128,
|
||||
pattern=r"^[A-Za-z0-9_-]+$",
|
||||
),
|
||||
) -> CancelRunResponse:
|
||||
canceled = await service.cancel_run(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
return CancelRunResponse(
|
||||
threadId=canceled.thread_id,
|
||||
runId=canceled.run_id,
|
||||
accepted=canceled.accepted,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/runs/{thread_id}/events")
|
||||
async def stream_events(
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
run_id: str | None = Query(default=None, alias="runId"),
|
||||
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
|
||||
idle_limit: int = Query(default=300, ge=1, le=3600),
|
||||
) -> StreamingResponse:
|
||||
if run_id is None or _RUN_ID_RE.fullmatch(run_id) is None:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_INVALID_RUN_ID",
|
||||
detail="Invalid runId",
|
||||
),
|
||||
)
|
||||
|
||||
if last_event_id is not None and (
|
||||
len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
|
||||
):
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_INVALID_LAST_EVENT_ID",
|
||||
detail="Invalid Last-Event-ID",
|
||||
),
|
||||
)
|
||||
|
||||
sse_slot_acquired = await _acquire_sse_slot(user_id=str(current_user.id))
|
||||
if not sse_slot_acquired:
|
||||
raise ApiProblemError(
|
||||
status_code=429,
|
||||
detail=problem_payload(
|
||||
code="AGENT_SSE_CONNECTION_LIMIT",
|
||||
detail="Too many SSE connections",
|
||||
),
|
||||
)
|
||||
|
||||
async def _event_iter() -> AsyncIterator[str]:
|
||||
cursor = last_event_id
|
||||
idle_polls = 0
|
||||
terminal_event_reached = False
|
||||
try:
|
||||
while (
|
||||
not terminal_event_reached
|
||||
and not await request.is_disconnected()
|
||||
and idle_polls < idle_limit
|
||||
):
|
||||
try:
|
||||
rows = await service.stream_events(
|
||||
thread_id=thread_id,
|
||||
last_event_id=cursor,
|
||||
current_user=current_user,
|
||||
)
|
||||
except (TimeoutError, RedisTimeoutError):
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"SSE stream read failed",
|
||||
thread_id=thread_id,
|
||||
user_id=str(current_user.id),
|
||||
reason=str(exc),
|
||||
)
|
||||
break
|
||||
|
||||
if not rows:
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
|
||||
idle_polls = 0
|
||||
for row in rows:
|
||||
row_id = str(row.get("id", ""))
|
||||
event = row.get("event")
|
||||
if not row_id or not isinstance(event, dict):
|
||||
continue
|
||||
cursor = row_id
|
||||
if not _is_target_run_event(event, target_run_id=run_id):
|
||||
continue
|
||||
yield to_sse_event(row_id, event)
|
||||
if _is_terminal_run_event(event):
|
||||
terminal_event_reached = True
|
||||
break
|
||||
|
||||
finally:
|
||||
await _release_sse_slot(user_id=str(current_user.id))
|
||||
|
||||
return StreamingResponse(
|
||||
_event_iter(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history", response_model=HistorySnapshotResponse)
|
||||
async def get_user_history_snapshot(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
thread_id: str | None = Query(default=None, alias="threadId"),
|
||||
before: date | None = Query(default=None),
|
||||
) -> HistorySnapshotResponse:
|
||||
return await service.get_user_history_snapshot(
|
||||
current_user=current_user,
|
||||
thread_id=thread_id,
|
||||
before=before,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/attachments",
|
||||
response_model=AttachmentUploadResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def upload_attachment(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
thread_id: str = Form(alias="threadId"),
|
||||
file: UploadFile = File(),
|
||||
) -> AttachmentUploadResponse:
|
||||
payload = await file.read()
|
||||
if not payload:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_EMPTY",
|
||||
detail="Empty attachment",
|
||||
),
|
||||
)
|
||||
if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
|
||||
raise ApiProblemError(
|
||||
status_code=413,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_TOO_LARGE",
|
||||
detail="Attachment too large",
|
||||
params={"maxBytes": _MAX_ATTACHMENT_UPLOAD_BYTES},
|
||||
),
|
||||
)
|
||||
attachment = await service.upload_attachment(
|
||||
thread_id=thread_id,
|
||||
filename=file.filename,
|
||||
content_type=file.content_type,
|
||||
payload=payload,
|
||||
current_user=current_user,
|
||||
)
|
||||
return AttachmentUploadResponse(
|
||||
attachment=AttachmentReference.model_validate(attachment),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/attachments/signed-url",
|
||||
response_model=AttachmentSignedUrlResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def create_attachment_signed_url(
|
||||
service: Annotated[AgentService, Depends(get_agent_service)],
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
bucket: str = Query(min_length=1, max_length=100),
|
||||
path: str = Query(min_length=1, max_length=500),
|
||||
) -> AttachmentSignedUrlResponse:
|
||||
signed = await service.create_attachment_signed_url(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
current_user=current_user,
|
||||
)
|
||||
return AttachmentSignedUrlResponse(**signed)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/transcribe",
|
||||
response_model=AsrTranscribeResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def transcribe(
|
||||
audio: UploadFile,
|
||||
request: Request,
|
||||
current_user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> AsrTranscribeResponse:
|
||||
temp_path: str | None = None
|
||||
try:
|
||||
if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
detail=problem_payload(
|
||||
code="AGENT_AUDIO_UNSUPPORTED_FORMAT",
|
||||
detail="Unsupported audio format",
|
||||
),
|
||||
)
|
||||
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length is not None:
|
||||
try:
|
||||
declared_length = int(content_length)
|
||||
except ValueError:
|
||||
declared_length = None
|
||||
if (
|
||||
declared_length is not None
|
||||
and declared_length
|
||||
> _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
|
||||
):
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
detail=problem_payload(
|
||||
code="AGENT_AUDIO_TOO_LARGE",
|
||||
detail="Audio file too large",
|
||||
params={"maxBytes": _MAX_TRANSCRIBE_AUDIO_BYTES},
|
||||
),
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
|
||||
total_bytes = 0
|
||||
header = bytearray()
|
||||
while True:
|
||||
chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
|
||||
if not chunk:
|
||||
break
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
detail=problem_payload(
|
||||
code="AGENT_AUDIO_TOO_LARGE",
|
||||
detail="Audio file too large",
|
||||
params={"maxBytes": _MAX_TRANSCRIBE_AUDIO_BYTES},
|
||||
),
|
||||
)
|
||||
if len(header) < _WAV_HEADER_MIN_BYTES:
|
||||
required = _WAV_HEADER_MIN_BYTES - len(header)
|
||||
header.extend(chunk[:required])
|
||||
tmp_file.write(chunk)
|
||||
|
||||
if total_bytes == 0:
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
detail=problem_payload(
|
||||
code="AGENT_AUDIO_EMPTY",
|
||||
detail="Empty audio file",
|
||||
),
|
||||
)
|
||||
if not _looks_like_wav_header(bytes(header)):
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
detail=problem_payload(
|
||||
code="AGENT_AUDIO_UNSUPPORTED_FORMAT",
|
||||
detail="Unsupported audio format",
|
||||
),
|
||||
)
|
||||
|
||||
transcript = await asr_service.transcribe_file(
|
||||
temp_path, audio.filename or "unknown"
|
||||
)
|
||||
|
||||
return AsrTranscribeResponse(transcript=transcript)
|
||||
|
||||
except ApiProblemError:
|
||||
raise
|
||||
except RuntimeError:
|
||||
raise ApiProblemError(
|
||||
status_code=502,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ASR_UNAVAILABLE",
|
||||
detail="ASR service unavailable",
|
||||
),
|
||||
)
|
||||
finally:
|
||||
await audio.close()
|
||||
if temp_path:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import Any, Literal, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from schemas.agent.ui_schema import UiSchemaRenderer
|
||||
|
||||
|
||||
class AgentRepositoryLike(Protocol):
|
||||
async def get_session_owner(self, *, session_id: str) -> str: ...
|
||||
|
||||
async def create_session_for_user(
|
||||
self, *, user_id: str, session_id: str | None = None
|
||||
) -> str: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
async def get_history_day(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
before: date | None,
|
||||
visibility_mask: int | None = None,
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None: ...
|
||||
|
||||
async def persist_user_message(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
content: str,
|
||||
metadata: Any,
|
||||
visibility_mask: int,
|
||||
) -> None: ...
|
||||
|
||||
async def get_user_message_count(self, *, session_id: str) -> int: ...
|
||||
|
||||
async def get_system_agent_config(
|
||||
self, *, agent_type: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class QueueClientLike(Protocol):
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str: ...
|
||||
|
||||
async def request_cancel(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
requested_by: str,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class EventStreamLike(Protocol):
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, object]]: ...
|
||||
|
||||
|
||||
class PointsServiceLike(Protocol):
|
||||
async def ensure_run_points_available(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
) -> int: ...
|
||||
|
||||
async def consume_successful_run_points(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
session_id: UUID,
|
||||
run_id: str,
|
||||
operator_id: UUID | None,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class AttachmentStorageLike(Protocol):
|
||||
async def upload_bytes(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
) -> str: ...
|
||||
|
||||
async def download_bytes(self, *, bucket: str, path: str) -> bytes: ...
|
||||
|
||||
async def create_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
expires_in_seconds: int,
|
||||
) -> str: ...
|
||||
|
||||
def parse_signed_url(self, url: str) -> tuple[str, str]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
thread_id: str
|
||||
run_id: str
|
||||
created: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CancelRequested:
|
||||
thread_id: str
|
||||
run_id: str
|
||||
accepted: bool
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
task_id: str = Field(alias="taskId")
|
||||
thread_id: str = Field(alias="threadId")
|
||||
run_id: str = Field(alias="runId")
|
||||
created: bool
|
||||
|
||||
|
||||
class CancelRunResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
thread_id: str = Field(alias="threadId")
|
||||
run_id: str = Field(alias="runId")
|
||||
accepted: bool
|
||||
|
||||
|
||||
class AsrTranscribeResponse(BaseModel):
|
||||
transcript: str = Field(description="Transcribed text from audio")
|
||||
|
||||
|
||||
class AttachmentReference(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
bucket: str
|
||||
path: str
|
||||
mime_type: str = Field(alias="mimeType")
|
||||
url: str
|
||||
|
||||
|
||||
class AttachmentUploadResponse(BaseModel):
|
||||
attachment: AttachmentReference
|
||||
|
||||
|
||||
class AttachmentSignedUrlResponse(BaseModel):
|
||||
bucket: str
|
||||
path: str
|
||||
url: str
|
||||
|
||||
|
||||
class HistoryMessageAttachment(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
mime_type: str = Field(alias="mimeType")
|
||||
url: str
|
||||
|
||||
|
||||
class HistoryMessage(BaseModel):
|
||||
"""History message schema for /history endpoint response."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
id: str = Field(description="Message UUID")
|
||||
seq: int = Field(description="Message sequence number")
|
||||
role: Literal["user", "assistant"] = Field(
|
||||
description="Message role: user | assistant"
|
||||
)
|
||||
content: str = Field(description="Message text content")
|
||||
attachments: list[HistoryMessageAttachment] = Field(
|
||||
default_factory=list,
|
||||
description="Temporary signed URLs for user-attached images",
|
||||
)
|
||||
ui_schema: UiSchemaRenderer | None = Field(
|
||||
default=None,
|
||||
description="Compiled UI schema from worker ui_hints for frontend rendering",
|
||||
)
|
||||
timestamp: str = Field(description="Message creation timestamp in ISO-8601 format")
|
||||
|
||||
|
||||
class HistorySnapshotResponse(BaseModel):
|
||||
"""Response schema for GET /api/v1/agent/history"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
scope: str = Field(default="history_day")
|
||||
thread_id: str | None = Field(default=None, alias="threadId")
|
||||
day: str | None = None
|
||||
has_more: bool = Field(default=False, alias="hasMore")
|
||||
messages: list[HistoryMessage] = Field(default_factory=list)
|
||||
@@ -0,0 +1,706 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, timezone
|
||||
import hashlib
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ag_ui.core import RunAgentInput
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agentscope.caches.context_messages_cache import (
|
||||
create_context_messages_cache,
|
||||
)
|
||||
from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from schemas.agent.forwarded_props import (
|
||||
parse_forwarded_props_runtime_mode,
|
||||
RuntimeMode,
|
||||
)
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.agent.runtime_config import RuntimeConfig
|
||||
from schemas.domain.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachment,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from v1.agent.schemas import (
|
||||
AgentRepositoryLike,
|
||||
AttachmentStorageLike,
|
||||
CancelRequested,
|
||||
EventStreamLike,
|
||||
HistorySnapshotResponse,
|
||||
PointsServiceLike,
|
||||
QueueClientLike,
|
||||
TaskAccepted,
|
||||
)
|
||||
from v1.agent.utils import (
|
||||
MAX_ATTACHMENT_BYTES,
|
||||
MAX_ATTACHMENTS_PER_MESSAGE,
|
||||
is_safe_attachment_path,
|
||||
mime_to_suffix,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
MAX_RUNS_PER_SESSION = 4
|
||||
|
||||
|
||||
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
|
||||
if owner_id != str(current_user.id):
|
||||
raise ApiProblemError(
|
||||
status_code=403,
|
||||
detail=problem_payload(code="AGENT_FORBIDDEN", detail="Forbidden"),
|
||||
)
|
||||
|
||||
|
||||
class AgentService:
|
||||
_repository: AgentRepositoryLike
|
||||
_queue: QueueClientLike
|
||||
_stream: EventStreamLike
|
||||
_points_service: PointsServiceLike
|
||||
_attachment_storage: AttachmentStorageLike | None
|
||||
|
||||
_SIGNED_URL_EXPIRES_IN_SECONDS = 3600
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
repository: AgentRepositoryLike,
|
||||
queue: QueueClientLike,
|
||||
stream: EventStreamLike,
|
||||
points_service: PointsServiceLike,
|
||||
attachment_storage: AttachmentStorageLike | None = None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._queue = queue
|
||||
self._stream = stream
|
||||
self._points_service = points_service
|
||||
self._attachment_storage = attachment_storage
|
||||
|
||||
async def enqueue_run(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
runtime_config: RuntimeConfig | None = None,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
run_id = run_input.run_id
|
||||
forwarded_props = getattr(run_input, "forwarded_props", None)
|
||||
try:
|
||||
runtime_mode = parse_forwarded_props_runtime_mode(forwarded_props)
|
||||
except ValueError as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(code="AGENT_PAYLOAD_INVALID", detail=str(exc)),
|
||||
) from exc
|
||||
|
||||
if runtime_config is None:
|
||||
from v1.agent.system_agents_config import (
|
||||
build_runtime_config_from_system_agents,
|
||||
)
|
||||
|
||||
runtime_config = build_runtime_config_from_system_agents()
|
||||
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except ApiProblemError as exc:
|
||||
if exc.status_code != 404:
|
||||
raise
|
||||
created = await self._create_session_if_missing(
|
||||
thread_id=thread_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
else:
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
try:
|
||||
await self._enforce_run_preconditions(
|
||||
thread_id=thread_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
except ApiProblemError:
|
||||
if created:
|
||||
await self._repository.rollback()
|
||||
raise
|
||||
|
||||
user_message_text, user_message_metadata = await self._prepare_user_message(
|
||||
run_input=run_input,
|
||||
current_user=current_user,
|
||||
)
|
||||
visibility_mask = await self._resolve_user_message_visibility_mask(
|
||||
runtime_mode=runtime_mode
|
||||
)
|
||||
await self._repository.persist_user_message(
|
||||
session_id=thread_id,
|
||||
content=user_message_text,
|
||||
metadata=user_message_metadata,
|
||||
visibility_mask=visibility_mask,
|
||||
)
|
||||
await self._repository.commit()
|
||||
await self._append_context_cache_user_message(
|
||||
thread_id=thread_id,
|
||||
runtime_mode=runtime_mode,
|
||||
visibility_mask=visibility_mask,
|
||||
content=user_message_text,
|
||||
metadata=user_message_metadata,
|
||||
)
|
||||
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"owner_email": current_user.email,
|
||||
"run_input": run_input.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
"runtime_config": runtime_config.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
"queue": "agent",
|
||||
},
|
||||
dedup_key=None,
|
||||
)
|
||||
return TaskAccepted(
|
||||
task_id=task_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
created=created,
|
||||
)
|
||||
|
||||
async def cancel_run(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> CancelRequested:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
await self._queue.request_cancel(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
requested_by=str(current_user.id),
|
||||
)
|
||||
return CancelRequested(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
accepted=True,
|
||||
)
|
||||
|
||||
async def _append_context_cache_user_message(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
runtime_mode: RuntimeMode,
|
||||
visibility_mask: int,
|
||||
content: str,
|
||||
metadata: AgentChatMessageMetadata | None,
|
||||
) -> None:
|
||||
metadata_payload = (
|
||||
metadata.model_dump(mode="json", exclude_none=True)
|
||||
if isinstance(metadata, AgentChatMessageMetadata)
|
||||
else None
|
||||
)
|
||||
message_payload: dict[str, object] = {
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
|
||||
}
|
||||
if isinstance(metadata_payload, dict):
|
||||
message_payload["metadata"] = metadata_payload
|
||||
|
||||
try:
|
||||
context_cache = create_context_messages_cache()
|
||||
await context_cache.append_message(
|
||||
thread_id=thread_id,
|
||||
runtime_mode=runtime_mode.value,
|
||||
visibility_mask=visibility_mask,
|
||||
message=message_payload,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to append user message to context cache",
|
||||
thread_id=thread_id,
|
||||
runtime_mode=runtime_mode.value,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
async def _resolve_user_message_visibility_mask(
|
||||
self, *, runtime_mode: RuntimeMode
|
||||
) -> int:
|
||||
if runtime_mode == RuntimeMode.CHAT:
|
||||
return bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)) | bit_mask(
|
||||
bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY)
|
||||
)
|
||||
return 0
|
||||
|
||||
async def _prepare_user_message(
|
||||
self,
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[str, AgentChatMessageMetadata | None]:
|
||||
text, content_blocks = extract_latest_user_payload(run_input)
|
||||
|
||||
user_attachments: list[UserMessageAttachment] = []
|
||||
for block in content_blocks:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type != "binary":
|
||||
continue
|
||||
|
||||
url = block.get("url")
|
||||
mime_type = block.get("mimeType")
|
||||
if not isinstance(url, str) or not url:
|
||||
continue
|
||||
if not isinstance(mime_type, str):
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
if self._attachment_storage is None:
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
|
||||
detail="Attachment storage unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
bucket, path = self._validate_binary_signed_url(
|
||||
url=url,
|
||||
thread_id=run_input.thread_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
user_attachments.append(
|
||||
UserMessageAttachment(
|
||||
bucket=bucket,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
if len(user_attachments) > MAX_ATTACHMENTS_PER_MESSAGE:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENTS_TOO_MANY",
|
||||
detail="Too many attachments",
|
||||
params={"max": MAX_ATTACHMENTS_PER_MESSAGE},
|
||||
),
|
||||
)
|
||||
except ApiProblemError:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
parsed = urlparse(url)
|
||||
safe_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
|
||||
logger.warning(
|
||||
"Failed to parse signed URL", url=safe_url, error=str(exc)
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_SIGNED_IMAGE_URL_INVALID",
|
||||
detail="Invalid signed image url",
|
||||
),
|
||||
)
|
||||
|
||||
metadata: AgentChatMessageMetadata | None = None
|
||||
if user_attachments:
|
||||
metadata = AgentChatMessageMetadata(
|
||||
run_id=run_input.run_id,
|
||||
user_message_attachments=user_attachments,
|
||||
)
|
||||
|
||||
return text, metadata
|
||||
|
||||
async def upload_attachment(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
filename: str | None,
|
||||
content_type: str | None,
|
||||
payload: bytes,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, str]:
|
||||
if self._attachment_storage is None:
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
|
||||
detail="Attachment storage unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
if not isinstance(content_type, str):
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_UNSUPPORTED_TYPE",
|
||||
detail="Unsupported attachment type",
|
||||
),
|
||||
)
|
||||
mime_type = content_type.lower()
|
||||
if mime_type not in {"image/png", "image/jpeg", "image/webp"}:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_UNSUPPORTED_TYPE",
|
||||
detail="Unsupported attachment type",
|
||||
),
|
||||
)
|
||||
if not payload:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_EMPTY",
|
||||
detail="Empty attachment",
|
||||
),
|
||||
)
|
||||
if len(payload) > MAX_ATTACHMENT_BYTES:
|
||||
raise ApiProblemError(
|
||||
status_code=413,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_TOO_LARGE",
|
||||
detail="Attachment too large",
|
||||
params={"maxBytes": MAX_ATTACHMENT_BYTES},
|
||||
),
|
||||
)
|
||||
|
||||
created = False
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except ApiProblemError as exc:
|
||||
if exc.status_code != 404:
|
||||
raise
|
||||
created = await self._create_session_if_missing(
|
||||
thread_id=thread_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
else:
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
|
||||
suffix = mime_to_suffix(mime_type)
|
||||
checksum = hashlib.sha1(payload).hexdigest()[:16]
|
||||
filename_seed = filename if isinstance(filename, str) and filename else "upload"
|
||||
filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
|
||||
path = (
|
||||
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
f"{filename_hash}-{checksum}.{suffix}"
|
||||
)
|
||||
bucket_name = config.storage.attachment.bucket
|
||||
try:
|
||||
stored_path = await self._attachment_storage.upload_bytes(
|
||||
bucket=bucket_name,
|
||||
path=path,
|
||||
content=payload,
|
||||
content_type=mime_type,
|
||||
)
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=bucket_name,
|
||||
path=stored_path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
if created:
|
||||
await self._repository.rollback()
|
||||
logger.exception(
|
||||
"Attachment upload failed",
|
||||
extra={
|
||||
"bucket": bucket_name,
|
||||
"path": path,
|
||||
"mime_type": mime_type,
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=502,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_UPLOAD_FAILED",
|
||||
detail="Failed to upload attachment",
|
||||
),
|
||||
)
|
||||
|
||||
if created:
|
||||
await self._repository.commit()
|
||||
|
||||
return {
|
||||
"bucket": bucket_name,
|
||||
"path": stored_path,
|
||||
"mimeType": mime_type,
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def _create_session_if_missing(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> bool:
|
||||
try:
|
||||
await self._repository.create_session_for_user(
|
||||
user_id=str(current_user.id),
|
||||
session_id=thread_id,
|
||||
)
|
||||
except IntegrityError:
|
||||
await self._repository.rollback()
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _enforce_run_preconditions(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> None:
|
||||
await self._points_service.ensure_run_points_available(user_id=current_user.id)
|
||||
|
||||
user_message_count = await self._repository.get_user_message_count(
|
||||
session_id=thread_id
|
||||
)
|
||||
if user_message_count >= MAX_RUNS_PER_SESSION:
|
||||
raise ApiProblemError(
|
||||
status_code=409,
|
||||
detail=problem_payload(
|
||||
code="AGENT_SESSION_RUN_LIMIT_EXCEEDED",
|
||||
detail="Session run limit exceeded",
|
||||
params={"maxRuns": MAX_RUNS_PER_SESSION},
|
||||
),
|
||||
)
|
||||
|
||||
async def create_attachment_signed_url(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
path: str,
|
||||
current_user: CurrentUser,
|
||||
) -> dict[str, str]:
|
||||
if self._attachment_storage is None:
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
|
||||
detail="Attachment storage unavailable",
|
||||
),
|
||||
)
|
||||
normalized_bucket = bucket.strip()
|
||||
if normalized_bucket != config.storage.attachment.bucket:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_BUCKET_INVALID",
|
||||
detail="Invalid attachment bucket",
|
||||
),
|
||||
)
|
||||
|
||||
normalized_path = path.strip()
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/"
|
||||
if not is_safe_attachment_path(
|
||||
normalized_path, expected_prefix=expected_prefix
|
||||
):
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_PATH_SCOPE_INVALID",
|
||||
detail="Invalid attachment path scope",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=normalized_bucket,
|
||||
path=normalized_path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Attachment signed URL generation failed",
|
||||
extra={
|
||||
"bucket": normalized_bucket,
|
||||
"path": normalized_path,
|
||||
"user_id": str(current_user.id),
|
||||
},
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=502,
|
||||
detail=problem_payload(
|
||||
code="AGENT_SIGNED_URL_GENERATION_FAILED",
|
||||
detail="Failed to generate signed URL",
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
"bucket": normalized_bucket,
|
||||
"path": normalized_path,
|
||||
"url": signed_url,
|
||||
}
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
return await self._stream.read(
|
||||
session_id=thread_id,
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
|
||||
async def get_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
before: date | None,
|
||||
current_user: CurrentUser,
|
||||
) -> HistorySnapshotResponse:
|
||||
from schemas.domain.chat_message import AgentChatMessage
|
||||
from v1.agent.utils import convert_message_to_history
|
||||
from v1.agent.schemas import HistoryMessage
|
||||
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
ensure_session_owner(owner_id=owner, current_user=current_user)
|
||||
day_payload = await self._repository.get_history_day(
|
||||
session_id=thread_id,
|
||||
before=before,
|
||||
visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)),
|
||||
)
|
||||
|
||||
messages: list[HistoryMessage] = []
|
||||
if day_payload:
|
||||
raw_messages_obj = day_payload.get("messages")
|
||||
raw_messages = (
|
||||
raw_messages_obj if isinstance(raw_messages_obj, list) else []
|
||||
)
|
||||
for msg_dict in raw_messages:
|
||||
msg = AgentChatMessage.model_validate(msg_dict)
|
||||
if msg.role == "tool":
|
||||
continue
|
||||
|
||||
signed_urls: dict[str, str] = {}
|
||||
attachments = extract_user_message_attachments(msg.metadata)
|
||||
if self._attachment_storage and attachments:
|
||||
expected_prefix = (
|
||||
f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
)
|
||||
for attachment in attachments:
|
||||
if not is_safe_attachment_path(
|
||||
attachment.path,
|
||||
expected_prefix=expected_prefix,
|
||||
):
|
||||
continue
|
||||
signed_url = await self._attachment_storage.create_signed_url(
|
||||
bucket=attachment.bucket,
|
||||
path=attachment.path,
|
||||
expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
|
||||
)
|
||||
key = f"{attachment.bucket}/{attachment.path}"
|
||||
signed_urls[key] = signed_url
|
||||
|
||||
def _get_signed_url(payload: dict[str, str]) -> str:
|
||||
key = f"{payload['bucket']}/{payload['path']}"
|
||||
return signed_urls[key]
|
||||
|
||||
converted = convert_message_to_history(msg, _get_signed_url)
|
||||
messages.append(HistoryMessage.model_validate(converted))
|
||||
|
||||
return HistorySnapshotResponse(
|
||||
scope="history_day",
|
||||
threadId=thread_id,
|
||||
day=str(day_payload.get("day"))
|
||||
if day_payload and day_payload.get("day")
|
||||
else None,
|
||||
hasMore=bool(day_payload.get("hasMore")) if day_payload else False,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
async def get_user_history_snapshot(
|
||||
self,
|
||||
*,
|
||||
current_user: CurrentUser,
|
||||
thread_id: str | None,
|
||||
before: date | None,
|
||||
) -> HistorySnapshotResponse:
|
||||
target_thread_id = thread_id
|
||||
if target_thread_id is None:
|
||||
target_thread_id = await self._repository.get_latest_session_id_for_user(
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
if target_thread_id is None:
|
||||
return HistorySnapshotResponse(
|
||||
scope="history_day",
|
||||
threadId=None,
|
||||
day=None,
|
||||
hasMore=False,
|
||||
messages=[],
|
||||
)
|
||||
return await self.get_history_snapshot(
|
||||
thread_id=target_thread_id,
|
||||
before=before,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
def _validate_binary_signed_url(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
thread_id: str,
|
||||
current_user: CurrentUser,
|
||||
) -> tuple[str, str]:
|
||||
if self._attachment_storage is None:
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
|
||||
detail="Attachment storage unavailable",
|
||||
),
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
expected_host = urlparse(config.supabase.url).netloc
|
||||
if parsed.netloc != expected_host:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="INVALID_BINARY_URL_HOST",
|
||||
detail="Invalid binary url host",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
bucket, path = self._attachment_storage.parse_signed_url(url)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="AGENT_SIGNED_IMAGE_URL_INVALID",
|
||||
detail="Invalid signed image url",
|
||||
),
|
||||
) from exc
|
||||
|
||||
if bucket != config.storage.attachment.bucket:
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="INVALID_BINARY_URL_BUCKET",
|
||||
detail="Invalid binary url bucket",
|
||||
),
|
||||
)
|
||||
|
||||
expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
|
||||
if not is_safe_attachment_path(path, expected_prefix=expected_prefix):
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="INVALID_BINARY_URL_PATH_SCOPE",
|
||||
detail="Invalid binary url path scope",
|
||||
),
|
||||
)
|
||||
return bucket, path
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
System agents 配置加载工具
|
||||
|
||||
从 system_agents.yaml 加载配置并构建 RuntimeConfig
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.agent.runtime_config import (
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MessageContextConfig,
|
||||
RuntimeConfig,
|
||||
)
|
||||
|
||||
|
||||
def _default_system_agents_path() -> Path:
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "database"
|
||||
/ "system_agents.yaml"
|
||||
)
|
||||
|
||||
|
||||
def _load_system_agents_yaml(path: Path | None = None) -> dict[str, object]:
|
||||
target_path = path or _default_system_agents_path()
|
||||
with target_path.open("r", encoding="utf-8") as f:
|
||||
loaded = yaml.safe_load(f) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"Invalid system agents format: {target_path}")
|
||||
return loaded
|
||||
|
||||
|
||||
def _parse_context_messages_config(
|
||||
yaml_config: dict[str, object] | None,
|
||||
) -> MessageContextConfig:
|
||||
if not yaml_config:
|
||||
return MessageContextConfig()
|
||||
raw_mode = yaml_config.get("mode", "day")
|
||||
mode_str = raw_mode if isinstance(raw_mode, str) else "day"
|
||||
raw_count = yaml_config.get("count", 2)
|
||||
count = raw_count if isinstance(raw_count, int) else 2
|
||||
try:
|
||||
source = ContextSource.LATEST_CHAT
|
||||
except ValueError:
|
||||
source = ContextSource.LATEST_CHAT
|
||||
try:
|
||||
window_mode = ContextWindowMode(mode_str)
|
||||
except ValueError:
|
||||
window_mode = ContextWindowMode.DAY
|
||||
return MessageContextConfig(
|
||||
source=source,
|
||||
window_mode=window_mode,
|
||||
window_count=count,
|
||||
)
|
||||
|
||||
|
||||
def build_runtime_config_from_system_agents(
|
||||
yaml_path: Path | None = None,
|
||||
) -> RuntimeConfig:
|
||||
"""
|
||||
从 system_agents.yaml 构建 RuntimeConfig
|
||||
|
||||
仅使用 worker 配置:
|
||||
- worker.context_messages 配置上下文窗口
|
||||
- enabled_tools 固定为空(eryao 不启用自定义工具)
|
||||
"""
|
||||
raw = _load_system_agents_yaml(yaml_path)
|
||||
raw_agents = raw.get("agents", [])
|
||||
agents_list = raw_agents if isinstance(raw_agents, list) else []
|
||||
|
||||
worker_config: SystemAgentLLMConfig | None = None
|
||||
|
||||
for agent in agents_list:
|
||||
if not isinstance(agent, dict):
|
||||
continue
|
||||
agent_type = str(agent.get("agent_type", "")).strip().lower()
|
||||
if agent_type == "worker":
|
||||
config_dict = agent.get("config") or {}
|
||||
try:
|
||||
worker_config = SystemAgentLLMConfig.model_validate(config_dict)
|
||||
except ValidationError:
|
||||
worker_config = SystemAgentLLMConfig()
|
||||
|
||||
context_cfg = _parse_context_messages_config(
|
||||
worker_config.context_messages.model_dump() if worker_config else None
|
||||
)
|
||||
|
||||
return RuntimeConfig(
|
||||
enabled_tools=[],
|
||||
context=context_cfg,
|
||||
)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
历史消息转换工具函数
|
||||
|
||||
将数据库中的原始消息转换为 API 响应的数据结构
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
|
||||
from schemas.domain.chat_message import (
|
||||
AgentChatMessage,
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
|
||||
ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
|
||||
MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
|
||||
MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
|
||||
MAX_ATTACHMENTS_PER_MESSAGE = 3
|
||||
|
||||
|
||||
def convert_message_to_history(
|
||||
message: AgentChatMessage,
|
||||
get_signed_url_fn: Callable[[dict[str, str]], str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
将 AgentChatMessage 转换为 HistoryMessage 格式
|
||||
|
||||
转换规则:
|
||||
- role=user: 读取 metadata.user_message_attachments,转换为 attachments[]
|
||||
- role=assistant: 读取 metadata.agent_output.ui_hints,编译成 ui_schema
|
||||
"""
|
||||
role = message.role
|
||||
content = message.content
|
||||
metadata = message.metadata
|
||||
|
||||
attachments: list[dict[str, str]] = []
|
||||
ui_schema: dict[str, Any] | None = None
|
||||
|
||||
if role == "user":
|
||||
attachments = _convert_user_attachments(metadata, get_signed_url_fn)
|
||||
|
||||
elif role == "assistant":
|
||||
ui_schema = _compile_worker_ui_hints(metadata)
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"id": str(message.id),
|
||||
"seq": message.seq,
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": message.timestamp.isoformat(),
|
||||
}
|
||||
|
||||
if attachments:
|
||||
result["attachments"] = attachments
|
||||
|
||||
if ui_schema:
|
||||
result["ui_schema"] = ui_schema
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _convert_user_attachments(
|
||||
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
|
||||
get_signed_url_fn: Callable[[dict[str, str]], str] | None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""转换用户附件为临时访问 URL 列表"""
|
||||
if not metadata or not get_signed_url_fn:
|
||||
return []
|
||||
|
||||
if isinstance(metadata, AgentChatMessageMetadata):
|
||||
resolved = extract_user_message_attachments(metadata)
|
||||
elif isinstance(metadata, dict):
|
||||
resolved = extract_user_message_attachments(metadata)
|
||||
else:
|
||||
return []
|
||||
|
||||
signed_attachments: list[dict[str, str]] = []
|
||||
for attachment in resolved:
|
||||
try:
|
||||
signed_url = get_signed_url_fn(
|
||||
{"bucket": attachment.bucket, "path": attachment.path}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
signed_attachments.append(
|
||||
{
|
||||
"url": signed_url,
|
||||
"mimeType": attachment.mime_type,
|
||||
}
|
||||
)
|
||||
return signed_attachments
|
||||
|
||||
|
||||
def _compile_worker_ui_hints(
|
||||
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""编译 assistant 消息的 agent ui_hints"""
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
if isinstance(metadata, AgentChatMessageMetadata):
|
||||
agent_output = metadata.agent_output
|
||||
else:
|
||||
agent_output_data = metadata.get("agent_output")
|
||||
if not agent_output_data:
|
||||
return None
|
||||
if isinstance(agent_output_data, dict):
|
||||
raw_ui_schema = agent_output_data.get("ui_schema")
|
||||
if isinstance(raw_ui_schema, dict):
|
||||
return raw_ui_schema
|
||||
from schemas.agent.runtime_models import AgentOutput
|
||||
|
||||
try:
|
||||
agent_output = AgentOutput.model_validate(agent_output_data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not agent_output:
|
||||
return None
|
||||
|
||||
ui_hints = agent_output.ui_hints
|
||||
if not ui_hints:
|
||||
return None
|
||||
|
||||
try:
|
||||
compiled = compile_ui_hints(ui_hints)
|
||||
return compiled
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def mime_to_suffix(mime_type: str) -> str:
|
||||
mapping = {
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpg",
|
||||
"image/webp": "webp",
|
||||
}
|
||||
return mapping.get(mime_type.lower(), "bin")
|
||||
|
||||
|
||||
def is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
return False
|
||||
if normalized.startswith("/"):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return normalized.startswith(expected_prefix)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from schemas.enums import MemoryType
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryRecord:
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
memory_type: MemoryType
|
||||
content: dict
|
||||
|
||||
|
||||
class SQLAlchemyMemoriesRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_user_memory_for_owner(self, *, owner_id: UUID) -> MemoryRecord | None:
|
||||
_ = self._session
|
||||
_ = owner_id
|
||||
return None
|
||||
|
||||
async def get_work_memory_for_owner(self, *, owner_id: UUID) -> MemoryRecord | None:
|
||||
_ = self._session
|
||||
_ = owner_id
|
||||
return None
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
content: dict,
|
||||
) -> MemoryRecord:
|
||||
return MemoryRecord(
|
||||
id=uuid4(),
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
async def update_content(
|
||||
self,
|
||||
memory: MemoryRecord,
|
||||
content: dict | None = None,
|
||||
) -> MemoryRecord:
|
||||
if content is not None:
|
||||
memory.content = content
|
||||
return memory
|
||||
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from schemas.enums import MemoryType
|
||||
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
|
||||
from v1.memories.repository import MemoryRecord, SQLAlchemyMemoriesRepository
|
||||
|
||||
|
||||
class MemoriesService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: SQLAlchemyMemoriesRepository,
|
||||
session: object,
|
||||
current_user: CurrentUser | None,
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
self._current_user = current_user
|
||||
|
||||
def _require_user_id(self):
|
||||
if self._current_user is None:
|
||||
raise ValueError("current user is required")
|
||||
return self._current_user.id
|
||||
|
||||
async def get_all_memories(
|
||||
self,
|
||||
) -> dict[str, UserMemoryContent | WorkProfileContent | None]:
|
||||
owner_id = self._require_user_id()
|
||||
user_memory = await self._repository.get_user_memory_for_owner(
|
||||
owner_id=owner_id
|
||||
)
|
||||
work_memory = await self._repository.get_work_memory_for_owner(
|
||||
owner_id=owner_id
|
||||
)
|
||||
return {
|
||||
"user_memory": UserMemoryContent.model_validate(user_memory.content)
|
||||
if user_memory is not None
|
||||
else None,
|
||||
"work_memory": WorkProfileContent.model_validate(work_memory.content)
|
||||
if work_memory is not None
|
||||
else None,
|
||||
}
|
||||
|
||||
async def get_memory_model(self, *, memory_type: MemoryType) -> MemoryRecord | None:
|
||||
owner_id = self._require_user_id()
|
||||
if memory_type == MemoryType.USER:
|
||||
return await self._repository.get_user_memory_for_owner(owner_id=owner_id)
|
||||
return await self._repository.get_work_memory_for_owner(owner_id=owner_id)
|
||||
|
||||
async def update_user_memory(self, *, content: UserMemoryContent) -> MemoryRecord:
|
||||
owner_id = self._require_user_id()
|
||||
existing = await self._repository.get_user_memory_for_owner(owner_id=owner_id)
|
||||
if existing is not None:
|
||||
return await self._repository.update_content(
|
||||
existing,
|
||||
content.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return await self._repository.create(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.USER,
|
||||
content=content.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
async def update_work_memory(self, *, content: WorkProfileContent) -> MemoryRecord:
|
||||
owner_id = self._require_user_id()
|
||||
existing = await self._repository.get_work_memory_for_owner(owner_id=owner_id)
|
||||
if existing is not None:
|
||||
return await self._repository.update_content(
|
||||
existing,
|
||||
content.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return await self._repository.create(
|
||||
owner_id=owner_id,
|
||||
memory_type=MemoryType.WORK,
|
||||
content=content.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.points_ledger import PointsLedger
|
||||
from models.user_points import UserPoints
|
||||
from schemas.shared.points import ApplyPointsChangeCommand
|
||||
|
||||
|
||||
class PointsRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_or_create_user_points_for_update(
|
||||
self, *, user_id: UUID
|
||||
) -> UserPoints:
|
||||
insert_stmt = (
|
||||
insert(UserPoints)
|
||||
.values(user_id=user_id)
|
||||
.on_conflict_do_nothing(index_elements=[UserPoints.user_id])
|
||||
)
|
||||
await self._session.execute(insert_stmt)
|
||||
|
||||
stmt = select(UserPoints).where(UserPoints.user_id == user_id).with_for_update()
|
||||
return (await self._session.execute(stmt)).scalar_one()
|
||||
|
||||
async def has_ledger_event(self, *, user_id: UUID, event_id: str) -> bool:
|
||||
stmt = select(PointsLedger.id).where(
|
||||
PointsLedger.user_id == user_id,
|
||||
PointsLedger.event_id == event_id,
|
||||
)
|
||||
row = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
async def append_ledger(
|
||||
self,
|
||||
*,
|
||||
command: ApplyPointsChangeCommand,
|
||||
balance_after: int,
|
||||
) -> None:
|
||||
entry = PointsLedger(
|
||||
user_id=command.user_id,
|
||||
direction=command.direction,
|
||||
amount=command.amount,
|
||||
balance_after=balance_after,
|
||||
change_type=command.change_type.value,
|
||||
biz_type=command.biz_type.value if command.biz_type is not None else None,
|
||||
biz_id=command.biz_id,
|
||||
event_id=command.event_id,
|
||||
operator_id=command.operator_id,
|
||||
metadata_json=command.metadata.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
self._session.add(entry)
|
||||
await self._session.flush()
|
||||
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
import hashlib
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from schemas.domain.points import ConsumeLedgerMetadata, PointsChargeSnapshot
|
||||
from schemas.enums import PointsBizType, PointsChangeType, PointsOperatorType
|
||||
from schemas.shared.points import ApplyPointsChangeCommand
|
||||
from v1.points.repository import PointsRepository
|
||||
|
||||
RUN_POINTS_COST = 20
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunChargeResult:
|
||||
charged: bool
|
||||
amount: int
|
||||
balance_after: int
|
||||
event_id: str
|
||||
|
||||
|
||||
class PointsService:
|
||||
def __init__(self, repository: PointsRepository) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def ensure_run_points_available(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
) -> int:
|
||||
account = await self._repository.get_or_create_user_points_for_update(
|
||||
user_id=user_id
|
||||
)
|
||||
balance = int(account.balance)
|
||||
frozen_balance = int(account.frozen_balance)
|
||||
available = balance - frozen_balance
|
||||
if available < RUN_POINTS_COST:
|
||||
raise ApiProblemError(
|
||||
status_code=402,
|
||||
detail=problem_payload(
|
||||
code="POINTS_INSUFFICIENT_BALANCE",
|
||||
detail="Insufficient points for this run",
|
||||
params={
|
||||
"required": RUN_POINTS_COST,
|
||||
"available": max(available, 0),
|
||||
},
|
||||
),
|
||||
)
|
||||
return available
|
||||
|
||||
async def consume_successful_run_points(
|
||||
self,
|
||||
*,
|
||||
user_id: UUID,
|
||||
session_id: UUID,
|
||||
run_id: str,
|
||||
operator_id: UUID | None,
|
||||
) -> RunChargeResult:
|
||||
event_source = f"{session_id}:{run_id}".encode("utf-8")
|
||||
event_hash = hashlib.sha1(event_source).hexdigest()
|
||||
event_id = f"chat.run.success:{event_hash}"
|
||||
if await self._repository.has_ledger_event(user_id=user_id, event_id=event_id):
|
||||
account = await self._repository.get_or_create_user_points_for_update(
|
||||
user_id=user_id
|
||||
)
|
||||
return RunChargeResult(
|
||||
charged=False,
|
||||
amount=0,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
account = await self._repository.get_or_create_user_points_for_update(
|
||||
user_id=user_id
|
||||
)
|
||||
balance = int(account.balance)
|
||||
frozen_balance = int(account.frozen_balance)
|
||||
available = balance - frozen_balance
|
||||
if available < RUN_POINTS_COST:
|
||||
raise ApiProblemError(
|
||||
status_code=402,
|
||||
detail=problem_payload(
|
||||
code="POINTS_INSUFFICIENT_BALANCE",
|
||||
detail="Insufficient points for this run",
|
||||
params={
|
||||
"required": RUN_POINTS_COST,
|
||||
"available": max(available, 0),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
account.balance = balance - RUN_POINTS_COST
|
||||
account.lifetime_spent = int(account.lifetime_spent) + RUN_POINTS_COST
|
||||
account.version = int(account.version) + 1
|
||||
|
||||
metadata = ConsumeLedgerMetadata(
|
||||
operator_type=PointsOperatorType.USER,
|
||||
run_id=run_id,
|
||||
charge=PointsChargeSnapshot(
|
||||
message_id=uuid4(),
|
||||
message_seq=1,
|
||||
model_code="agent_run",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cost=Decimal("0"),
|
||||
),
|
||||
ext={"source": "run_success"},
|
||||
)
|
||||
command = ApplyPointsChangeCommand(
|
||||
user_id=user_id,
|
||||
change_type=PointsChangeType.CONSUME,
|
||||
biz_type=PointsBizType.CHAT,
|
||||
biz_id=session_id,
|
||||
event_id=event_id,
|
||||
amount=RUN_POINTS_COST,
|
||||
direction=-1,
|
||||
operator_id=operator_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
await self._repository.append_ledger(
|
||||
command=command,
|
||||
balance_after=int(account.balance),
|
||||
)
|
||||
return RunChargeResult(
|
||||
charged=True,
|
||||
amount=RUN_POINTS_COST,
|
||||
balance_after=int(account.balance),
|
||||
event_id=event_id,
|
||||
)
|
||||
@@ -2,8 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.auth.router import router as auth_router
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(agent_router)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContactInfo:
|
||||
username: str | None
|
||||
phone: str | None
|
||||
|
||||
|
||||
async def resolve_contacts_by_user_ids(
|
||||
*,
|
||||
user_ids: list[UUID],
|
||||
profiles_by_id: dict[UUID, object],
|
||||
auth_gateway: object,
|
||||
) -> dict[UUID, ContactInfo]:
|
||||
_ = auth_gateway
|
||||
resolved: dict[UUID, ContactInfo] = {}
|
||||
for user_id in user_ids:
|
||||
profile = profiles_by_id.get(user_id)
|
||||
username = getattr(profile, "username", None) if profile is not None else None
|
||||
resolved[user_id] = ContactInfo(username=username, phone=None)
|
||||
return resolved
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from services.base.supabase import supabase_service
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
authorization: str | None = Header(default=None),
|
||||
) -> CurrentUser:
|
||||
if not authorization:
|
||||
raise ApiProblemError(
|
||||
status_code=401,
|
||||
detail=problem_payload(code="AUTH_UNAUTHORIZED", detail="Unauthorized"),
|
||||
)
|
||||
|
||||
scheme, _, token = authorization.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
raise ApiProblemError(
|
||||
status_code=401,
|
||||
detail=problem_payload(code="AUTH_UNAUTHORIZED", detail="Unauthorized"),
|
||||
)
|
||||
|
||||
try:
|
||||
client = supabase_service.get_client()
|
||||
response = await asyncio.to_thread(client.auth.get_user, token)
|
||||
user = getattr(response, "user", None)
|
||||
user_id = getattr(user, "id", None)
|
||||
if not isinstance(user_id, str) or not user_id:
|
||||
raise ValueError("missing user id")
|
||||
return CurrentUser(
|
||||
id=UUID(user_id),
|
||||
email=getattr(user, "email", None),
|
||||
role=getattr(user, "role", None),
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ApiProblemError(
|
||||
status_code=401,
|
||||
detail=problem_payload(code="AUTH_UNAUTHORIZED", detail="Unauthorized"),
|
||||
) from exc
|
||||
|
||||
|
||||
def get_user_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> UserService:
|
||||
_ = session
|
||||
return UserService(current_user=user)
|
||||
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLAlchemyUserRepository:
|
||||
session: object
|
||||
|
||||
async def get_by_user_ids(self, user_ids: list[UUID]) -> dict[UUID, object]:
|
||||
_ = self.session
|
||||
_ = user_ids
|
||||
return {}
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from schemas.shared.user import UserContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserService:
|
||||
current_user: CurrentUser
|
||||
|
||||
async def get_me(self) -> UserContext:
|
||||
user_id = str(self.current_user.id)
|
||||
return UserContext(
|
||||
id=user_id,
|
||||
username=f"user_{user_id[:8]}",
|
||||
email=self.current_user.email,
|
||||
avatar_url=None,
|
||||
bio=None,
|
||||
settings=None,
|
||||
)
|
||||
Reference in New Issue
Block a user