feat: 实现起卦、设置与积分系统

This commit is contained in:
qzl
2026-04-03 16:56:47 +08:00
parent 31594558eb
commit f245eec5f6
170 changed files with 20728 additions and 328 deletions
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+120
View File
@@ -0,0 +1,120 @@
from __future__ import annotations
import asyncio
from typing import Any
import dashscope
from dashscope.audio.asr import Recognition, RecognitionCallback
from core.config.settings import config
from core.logging import get_logger
logger = get_logger(__name__)
class AsrService:
    """Speech-to-text client built on DashScope's ``Recognition`` API.

    The DashScope API key is resolved from configuration on first use and
    then cached on the instance.
    """

    def __init__(self) -> None:
        self._api_key: str | None = None

    def _get_api_key(self) -> str:
        """Return the DashScope API key, reading it from config on first call.

        Raises:
            ValueError: If ``config.llm.provider_keys`` has no usable
                ``dashscope`` entry.
        """
        if self._api_key is None:
            dashscope_key = config.llm.provider_keys.get("dashscope")
            if not dashscope_key:
                raise ValueError(
                    "DASHSCOPE_API_KEY not configured. Set SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE in environment."
                )
            self._api_key = dashscope_key
        return self._api_key

    async def transcribe_file(self, file_path: str, filename: str) -> str:
        """Transcribe the audio file at ``file_path`` and return its text.

        Args:
            file_path: Path passed to the recognizer. The recognizer is
                configured for ``wav`` input at a 16000 Hz sample rate.
            filename: Only used in the completion log record.

        Returns:
            The transcript, or ``""`` when the response carries no sentence
            payload.

        Raises:
            RuntimeError: If the callback reported an error, the response
                status is not 200, or any other failure occurred (the
                original exception is chained).
        """
        try:
            # NOTE: this sets the key on the dashscope module, i.e. globally.
            dashscope.api_key = self._get_api_key()
            # Inside a coroutine the running loop is the one we want;
            # get_running_loop() is the documented way to obtain it.
            loop = asyncio.get_running_loop()

            class SyncCallback(RecognitionCallback):
                error: str | None = None

                def on_error(self, result: Any) -> None:
                    self.error = str(result)

            callback = SyncCallback()
            recognizer = Recognition(
                model="fun-asr-realtime-2026-02-28",
                callback=callback,
                format="wav",
                sample_rate=16000,
            )
            # recognizer.call() is synchronous, so run it on the default
            # executor to keep the event loop free.
            result: Any = await loop.run_in_executor(
                None,
                lambda: recognizer.call(file=file_path),
            )
            if callback.error:
                raise RuntimeError(f"ASR error: {callback.error}")

            status_code = self._extract_field(result, "status_code")
            if status_code != 200:
                message = self._extract_field(result, "message")
                raise RuntimeError(f"ASR transcription failed: {message}")

            sentence = self._extract_sentence_payload(result)
            if sentence is None:
                request_id = self._extract_field(result, "request_id")
                logger.warning(
                    "ASR returned empty result", extra={"request_id": request_id}
                )
                return ""

            # The payload may be a single sentence dict, a list of them, or
            # some other value that is stringified as-is.
            if isinstance(sentence, dict):
                transcription = sentence.get("text", "")
            elif isinstance(sentence, list):
                transcription = " ".join(
                    item.get("text", "") for item in sentence if isinstance(item, dict)
                )
            else:
                transcription = str(sentence) if sentence else ""

            logger.info(
                "ASR transcription completed",
                extra={"filename": filename, "transcript_length": len(transcription)},
            )
            return transcription
        except asyncio.CancelledError:
            raise
        except RuntimeError:
            # Already in the shape callers expect; do not re-wrap.
            raise
        except Exception as exc:
            logger.exception("ASR transcription error")
            raise RuntimeError(f"ASR transcription failed: {exc}") from exc

    def _extract_sentence_payload(self, result: Any) -> Any | None:
        """Locate the ``sentence`` payload in a dict-like or object-like result.

        For dicts, ``result["output"]`` is consulted first (as a dict or as an
        object attribute), then ``result["sentence"]``. For objects, a
        non-``None`` ``get_sentence()`` return wins, otherwise ``output`` is
        inspected the same way. Returns ``None`` when nothing is found.
        """
        if isinstance(result, dict):
            output = result.get("output")
            if isinstance(output, dict):
                return output.get("sentence")
            if output is not None:
                return getattr(output, "sentence", None)
            return result.get("sentence")

        get_sentence = getattr(result, "get_sentence", None)
        if callable(get_sentence):
            sentence = get_sentence()
            if sentence is not None:
                return sentence

        output = getattr(result, "output", None)
        if output is None:
            return None
        if isinstance(output, dict):
            return output.get("sentence")
        return getattr(output, "sentence", None)

    def _extract_field(self, result: Any, field: str) -> Any | None:
        """Read ``field`` from a dict key or object attribute; ``None`` if absent."""
        if isinstance(result, dict):
            return result.get(field)
        return getattr(result, field, None)
asr_service = AsrService()
+153
View File
@@ -0,0 +1,153 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
import json
from typing import Any
from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agentscope.events import RedisStreamBus
from core.agentscope.tools.tool_result_storage import (
create_tool_result_storage,
)
from core.config.settings import config
from core.db import get_db
from services.base.redis import get_or_init_redis_client
from services.base.supabase import supabase_service
from v1.agent.repository import AgentRepository
from v1.agent.service import AgentService
from v1.points.repository import PointsRepository
from v1.points.service import PointsService
# Number of polls made while waiting for a concurrent duplicate enqueue to
# publish its task id.
DEDUP_WAIT_RETRIES = 20
# Sleep between those polls, in seconds.
DEDUP_WAIT_SECONDS = 0.05
# Expiry (seconds) of the dedup key, used for both the in-flight marker and
# the stored task id.
DEDUP_LOCK_SECONDS = 300
# Placeholder stored under the dedup key until the real task id is known.
DEDUP_INFLIGHT_MARKER = "__inflight__"
# Expiry (seconds) of the per-run cancel signal key.
RUN_CANCEL_SIGNAL_TTL_SECONDS = 1800
def _event_stream_block_ms() -> int:
    """Return the block time (ms) passed to the Redis stream bus.

    The configured ``redis_stream_block_ms`` is capped at 100 ms below the
    Redis socket timeout and never drops below 1 ms — presumably so a
    blocking read returns before the socket itself times out.
    """
    requested_ms = int(config.agent_runtime.redis_stream_block_ms)
    socket_timeout_ms = max(int(float(config.redis.socket_timeout) * 1000), 1)
    ceiling_ms = max(socket_timeout_ms - 100, 1)
    capped_ms = min(requested_ms, ceiling_ms)
    return max(1, capped_ms)
class TaskiqQueueClient:
    """Enqueues agent run commands as Taskiq tasks, with Redis-based dedup.

    Also publishes per-run cancel signals as Redis keys.
    """

    def __init__(self) -> None:
        self._redis: Redis | None = None

    async def _get_redis(self) -> Redis:
        """Return the shared Redis client, resolving it on first use."""
        if self._redis is None:
            self._redis = await get_or_init_redis_client()
        return self._redis

    @staticmethod
    def _select_queue_task(command: dict[str, object]) -> Any:
        """Pick the Taskiq task for ``command["queue"]``.

        ``"general"`` (case-insensitive, surrounding whitespace ignored) maps
        to the general task; anything else, including a missing key, maps to
        the agent task.
        """
        # Imported lazily; presumably to avoid an import cycle — TODO confirm.
        from core.agentscope.runtime.tasks import (
            run_command_task_agent,
            run_command_task_general,
        )

        queue = str(command.get("queue", "agent")).strip().lower()
        if queue == "general":
            return run_command_task_general
        return run_command_task_agent

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Enqueue ``command`` and return the task id.

        When ``dedup_key`` is given, a Redis key is claimed with ``SET NX``.
        A caller that loses the claim polls the key for the winner's task id
        and returns it instead of enqueueing again.

        Raises:
            RuntimeError: If a duplicate request is still in flight after
                ``DEDUP_WAIT_RETRIES`` polls.
        """
        redis_client = await self._get_redis()
        redis_key: str | None = None
        if dedup_key:
            redis_key = f"agent:dedup:{dedup_key}"
            locked = await redis_client.set(
                redis_key,
                DEDUP_INFLIGHT_MARKER,
                nx=True,
                ex=DEDUP_LOCK_SECONDS,
            )
            if not locked:
                for _ in range(DEDUP_WAIT_RETRIES):
                    existing = await redis_client.get(redis_key)
                    # A client without decode_responses returns bytes; decode
                    # so the marker comparison works and a str is returned.
                    if isinstance(existing, (bytes, bytearray)):
                        existing = existing.decode("utf-8")
                    if existing and existing != DEDUP_INFLIGHT_MARKER:
                        return existing
                    await asyncio.sleep(DEDUP_WAIT_SECONDS)
                raise RuntimeError("duplicate request is still in progress")

        payload = dict(command)
        queue_task = self._select_queue_task(payload)
        try:
            result = await queue_task.kiq(payload)
            task_id = str(result.task_id)
            if redis_key is not None:
                # Replace the in-flight marker with the real id for waiters.
                await redis_client.set(redis_key, task_id, ex=DEDUP_LOCK_SECONDS)
            return task_id
        except Exception:
            # Release the claim so a retry is not blocked by a stale marker.
            if redis_key is not None:
                await redis_client.delete(redis_key)
            raise

    async def request_cancel(
        self,
        *,
        thread_id: str,
        run_id: str,
        requested_by: str,
    ) -> None:
        """Store a cancel signal for ``thread_id``/``run_id`` in Redis.

        The value is a compact JSON object holding ``requested_by`` and a UTC
        ISO-8601 ``requested_at``; the key expires after
        ``RUN_CANCEL_SIGNAL_TTL_SECONDS``.
        """
        redis_client = await self._get_redis()
        cancel_key = f"agent:cancel:{thread_id}:{run_id}"
        payload = json.dumps(
            {
                "requested_by": requested_by,
                "requested_at": datetime.now(timezone.utc).isoformat(),
            },
            ensure_ascii=True,
            separators=(",", ":"),
        )
        await redis_client.set(
            cancel_key,
            payload,
            ex=RUN_CANCEL_SIGNAL_TTL_SECONDS,
        )
class RedisEventStream:
    """Reads session events through a lazily created ``RedisStreamBus``."""

    def __init__(self) -> None:
        self._bus: RedisStreamBus | None = None

    async def _get_bus(self) -> RedisStreamBus:
        """Return the cached bus, building it from config on first use."""
        if self._bus is not None:
            return self._bus
        redis_client = await get_or_init_redis_client()
        runtime_cfg = config.agent_runtime
        self._bus = RedisStreamBus(
            client=redis_client,
            stream_prefix=runtime_cfg.redis_stream_prefix,
            read_count=runtime_cfg.redis_stream_read_count,
            block_ms=_event_stream_block_ms(),
        )
        return self._bus

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, Any]]:
        """Read rows after ``last_event_id`` and add a ``cursor`` to each.

        ``cursor`` mirrors the row's ``id`` (``None`` when the row has none).
        """
        bus = await self._get_bus()
        fetched = await bus.read(session_id=session_id, last_event_id=last_event_id)
        enriched: list[dict[str, Any]] = []
        for entry in fetched:
            with_cursor = {**entry}
            with_cursor["cursor"] = entry.get("id")
            enriched.append(with_cursor)
        return enriched
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
    """FastAPI dependency that builds an ``AgentService`` for the request.

    The agent and points repositories share the injected database session.
    """
    repository = AgentRepository(
        session,
        tool_result_storage=create_tool_result_storage(),
    )
    queue_client = TaskiqQueueClient()
    event_stream = RedisEventStream()
    points = PointsService(repository=PointsRepository(session))
    return AgentService(
        repository=repository,
        queue=queue_client,
        stream=event_stream,
        points_service=points,
        attachment_storage=supabase_service,
    )
+405
View File
@@ -0,0 +1,405 @@
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from typing import Protocol
from uuid import UUID, uuid4
from sqlalchemy import Select, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from core.http.errors import ApiProblemError
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.system_agents import SystemAgents
from schemas.enums import AgentChatMessageRole
from schemas.domain.chat_message import (
AgentChatMessage as AgentChatMessageSchema,
AgentChatMessageMetadata,
)
class ToolResultPayloadStorage(Protocol):
    """Structural interface for a store that reads JSON payloads by bucket/path."""

    async def read_json(self, *, bucket: str, path: str) -> dict[str, object] | None:
        """Return the JSON object at ``bucket``/``path``; ``None`` is a valid result."""
class AgentRepository:
    """Database access for agent chat sessions, messages and system agent config.

    Identifiers are accepted as strings and validated as UUIDs; an invalid
    value raises ``ApiProblemError`` with status 422.
    """

    def __init__(
        self,
        session: AsyncSession,
        *,
        tool_result_storage: ToolResultPayloadStorage | None = None,
    ) -> None:
        self._session: AsyncSession = session
        self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage

    @staticmethod
    def _parse_uuid(value: str, *, code: str, detail: str) -> UUID:
        """Parse ``value`` as a UUID or raise a 422 ``ApiProblemError``."""
        try:
            return UUID(value)
        except ValueError as exc:
            raise ApiProblemError(
                status_code=422,
                code=code,
                detail=detail,
            ) from exc

    @classmethod
    def _parse_session_id(cls, session_id: str) -> UUID:
        """Validate a session id (``AGENT_SESSION_ID_INVALID`` on failure)."""
        return cls._parse_uuid(
            session_id,
            code="AGENT_SESSION_ID_INVALID",
            detail="Invalid session_id",
        )

    @classmethod
    def _parse_user_id(cls, user_id: str) -> UUID:
        """Validate a user id (``AGENT_USER_ID_INVALID`` on failure)."""
        return cls._parse_uuid(
            user_id,
            code="AGENT_USER_ID_INVALID",
            detail="Invalid user_id",
        )

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owning user id of a session as a string.

        Raises:
            ApiProblemError: 422 for a malformed id, 404 when no session row
                matches.
        """
        session_uuid = self._parse_session_id(session_id)
        # NOTE(review): unlike get_latest_session_id_for_user, this lookup
        # does not filter on deleted_at — confirm that is intended.
        stmt = select(AgentChatSession.user_id).where(
            AgentChatSession.id == session_uuid
        )
        owner_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if owner_id is None:
            raise ApiProblemError(
                status_code=404,
                code="AGENT_SESSION_NOT_FOUND",
                detail="Session not found",
            )
        return str(owner_id)

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Insert a chat session for ``user_id`` and return its id.

        ``session_id`` is used as the primary key when given; otherwise
        ``id=None`` is passed to the model. The row is flushed and refreshed
        but not committed.

        Raises:
            ApiProblemError: 422 if either id is malformed.
        """
        user_uuid = self._parse_user_id(user_id)
        session_uuid = (
            self._parse_session_id(session_id) if session_id is not None else None
        )
        session = AgentChatSession(
            id=session_uuid,
            user_id=user_uuid,
        )
        self._session.add(session)
        await self._session.flush()
        await self._session.refresh(session)
        return str(session.id)

    async def commit(self) -> None:
        """Commit the underlying database session."""
        await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the underlying database session."""
        await self._session.rollback()

    async def delete_session(self, *, session_id: str) -> None:
        """Delete the session row if it exists, then flush (no commit).

        Raises:
            ApiProblemError: 422 for a malformed id.
        """
        session_uuid = self._parse_session_id(session_id)
        session = await self._session.get(AgentChatSession, session_uuid)
        if session is not None:
            await self._session.delete(session)
        await self._session.flush()

    async def persist_user_message(
        self,
        *,
        session_id: str,
        content: str,
        metadata: AgentChatMessageMetadata | None,
        visibility_mask: int,
    ) -> None:
        """Append a user message to a session and update session counters.

        The session row is selected ``FOR UPDATE`` so the sequence number is
        derived from ``message_count`` under a row lock. When the session has
        no title, one is derived from ``content``. Changes are flushed, not
        committed.

        Raises:
            ApiProblemError: 422 for a malformed id, 404 when the session
                does not exist.
        """
        session_uuid = self._parse_session_id(session_id)
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_uuid)
            .with_for_update()
        )
        session_row = (await self._session.execute(stmt)).scalar_one_or_none()
        if session_row is None:
            raise ApiProblemError(
                status_code=404,
                code="AGENT_SESSION_NOT_FOUND",
                detail="Session not found",
            )

        next_seq = int(session_row.message_count or 0) + 1
        if not _has_title(session_row.title):
            session_title = _derive_session_title(content)
            if session_title is not None:
                session_row.title = session_title

        message = AgentChatMessage(
            id=uuid4(),
            session_id=session_uuid,
            seq=next_seq,
            role=AgentChatMessageRole.USER,
            content=content,
            # Negative masks are clamped to 0.
            visibility_mask=max(int(visibility_mask), 0),
            metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
        )
        self._session.add(message)
        session_row.message_count = next_seq
        session_row.last_activity_at = datetime.now(timezone.utc)
        await self._session.flush()

    async def get_user_message_count(self, *, session_id: str) -> int:
        """Count non-deleted messages with role USER in the session.

        Raises:
            ApiProblemError: 422 for a malformed id.
        """
        session_uuid = self._parse_session_id(session_id)
        stmt = (
            select(func.count(AgentChatMessage.id))
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.role == AgentChatMessageRole.USER)
        )
        count = (await self._session.execute(stmt)).scalar_one()
        return int(count)

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
        visibility_mask: int | None = None,
    ) -> dict[str, object] | None:
        """Return one UTC day of messages for a session.

        The day is that of the newest non-deleted, visible message created
        strictly before midnight UTC of ``before`` (or the newest overall
        when ``before`` is ``None``).

        Returns:
            ``None`` when no such message exists, otherwise a dict with
            ``day`` (ISO date), ``hasMore`` (whether older visible messages
            exist) and ``messages`` (snapshots ordered by ``seq``).

        Raises:
            ApiProblemError: 422 for a malformed id.
        """
        session_uuid = self._parse_session_id(session_id)
        before_start = (
            datetime.combine(before, time.min, tzinfo=timezone.utc)
            if before is not None
            else None
        )

        # Step 1: find the newest eligible message to decide which day to load.
        target_created_at_stmt = (
            select(AgentChatMessage.created_at)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .order_by(AgentChatMessage.created_at.desc())
            .limit(1)
        )
        target_created_at_stmt = self._apply_visibility_filter(
            stmt=target_created_at_stmt,
            visibility_mask=visibility_mask,
        )
        if before_start is not None:
            target_created_at_stmt = target_created_at_stmt.where(
                AgentChatMessage.created_at < before_start
            )
        target_created_at = (
            await self._session.execute(target_created_at_stmt)
        ).scalar_one_or_none()
        if target_created_at is None:
            return None

        # Step 2: load every visible message inside that UTC day.
        target_day = target_created_at.astimezone(timezone.utc).date()
        start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
        end = start + timedelta(days=1)
        message_stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.created_at >= start)
            .where(AgentChatMessage.created_at < end)
            .order_by(AgentChatMessage.seq.asc())
        )
        message_stmt = self._apply_visibility_filter(
            stmt=message_stmt,
            visibility_mask=visibility_mask,
        )
        messages = (await self._session.execute(message_stmt)).scalars().all()

        # Step 3: probe for anything older so the caller can paginate.
        has_more_stmt = (
            select(AgentChatMessage.id)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .where(AgentChatMessage.created_at < start)
            .limit(1)
        )
        has_more_stmt = self._apply_visibility_filter(
            stmt=has_more_stmt,
            visibility_mask=visibility_mask,
        )
        has_more = (
            await self._session.execute(has_more_stmt)
        ).scalar_one_or_none() is not None

        snapshot_messages: list[dict[str, object]] = [
            await self._to_snapshot_message(message) for message in messages
        ]
        return {
            "day": target_day.isoformat(),
            "hasMore": has_more,
            "messages": snapshot_messages,
        }

    async def get_recent_messages_by_user_window(
        self,
        *,
        session_id: str,
        user_message_limit: int,
        visibility_mask: int | None = None,
    ) -> list[dict[str, object]]:
        """Return the newest messages spanning the last N user messages.

        Walking from newest to oldest, messages are collected up to and
        including the ``user_message_limit``-th user message (minimum 1),
        then returned oldest-first as snapshots.

        Raises:
            ApiProblemError: 422 for a malformed id.
        """
        session_uuid = self._parse_session_id(session_id)
        safe_user_limit = max(int(user_message_limit), 1)
        # NOTE(review): this fetches every visible message of the session
        # and trims in Python; consider bounding the query for long sessions.
        message_stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_uuid)
            .where(AgentChatMessage.deleted_at.is_(None))
            .order_by(AgentChatMessage.seq.desc())
        )
        message_stmt = self._apply_visibility_filter(
            stmt=message_stmt,
            visibility_mask=visibility_mask,
        )
        messages_desc = (await self._session.execute(message_stmt)).scalars().all()
        if not messages_desc:
            return []

        selected_desc: list[AgentChatMessage] = []
        user_count = 0
        for message in messages_desc:
            selected_desc.append(message)
            role = (
                message.role.value
                if isinstance(message.role, AgentChatMessageRole)
                else str(message.role)
            )
            if role == AgentChatMessageRole.USER.value:
                user_count += 1
                if user_count >= safe_user_limit:
                    break

        return [
            await self._to_snapshot_message(message)
            for message in reversed(selected_desc)
        ]

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the user's most recently active non-deleted session id.

        Returns ``None`` when the user has no such session.

        Raises:
            ApiProblemError: 422 for a malformed id.
        """
        user_uuid = self._parse_user_id(user_id)
        stmt = (
            select(AgentChatSession.id)
            .where(AgentChatSession.user_id == user_uuid)
            .where(AgentChatSession.deleted_at.is_(None))
            .order_by(AgentChatSession.last_activity_at.desc())
            .limit(1)
        )
        latest_id = (await self._session.execute(stmt)).scalar_one_or_none()
        if latest_id is None:
            return None
        return str(latest_id)

    async def get_system_agent_config(
        self, *, agent_type: str
    ) -> dict[str, object] | None:
        """Look up a system agent by its normalized (stripped, lowercased) type.

        Returns ``None`` for a blank type or when no row matches; otherwise a
        dict with ``agent_type``, ``status`` and ``config`` (``{}`` when the
        stored config is not a dict).
        """
        normalized_type = agent_type.strip().lower()
        if not normalized_type:
            return None
        stmt = select(SystemAgents).where(SystemAgents.agent_type == normalized_type)
        row = (await self._session.execute(stmt)).scalar_one_or_none()
        if row is None:
            return None
        config_payload = row.config if isinstance(row.config, dict) else {}
        return {
            "agent_type": normalized_type,
            "status": str(row.status),
            "config": config_payload,
        }

    async def _to_snapshot_message(
        self, message: AgentChatMessage
    ) -> dict[str, object]:
        """Convert an ORM message into its JSON-ready schema dump.

        Missing token counts default to 0 and a missing cost to ``"0"``;
        ``None`` fields are dropped from the result.
        """
        role = (
            message.role.value
            if isinstance(message.role, AgentChatMessageRole)
            else str(message.role)
        )
        payload_model = AgentChatMessageSchema.model_validate(
            {
                "id": str(message.id),
                "seq": int(message.seq),
                "role": role,
                "content": message.content,
                "model_code": message.model_code,
                "tool_name": message.tool_name,
                "input_tokens": int(message.input_tokens or 0),
                "output_tokens": int(message.output_tokens or 0),
                "cost": str(message.cost if message.cost is not None else Decimal("0")),
                "latency_ms": message.latency_ms,
                "metadata": message.metadata_json,
                "timestamp": message.created_at.astimezone(timezone.utc).isoformat(),
            }
        )
        return payload_model.model_dump(mode="json", exclude_none=True)

    def _apply_visibility_filter(
        self,
        *,
        stmt: Select,
        visibility_mask: int | None,
    ) -> Select:
        """Restrict ``stmt`` to messages sharing at least one bit with the mask.

        ``None``, zero and negative masks leave the statement unchanged.
        """
        if visibility_mask is None:
            return stmt
        required_mask = max(int(visibility_mask), 0)
        if required_mask == 0:
            return stmt
        return stmt.where(
            (AgentChatMessage.visibility_mask.op("&")(required_mask)) != 0
        )
def _has_title(title: object) -> bool:
return isinstance(title, str) and bool(title.strip())
def _derive_session_title(content_text: str) -> str | None:
normalized = " ".join(content_text.split())
if not normalized:
return None
return normalized[:80]
+473
View File
@@ -0,0 +1,473 @@
from __future__ import annotations
import asyncio
import os
import re
import tempfile
from collections.abc import AsyncIterator
from datetime import date
from typing import Annotated
from ag_ui.core import RunAgentInput
from core.http.errors import ApiProblemError, problem_payload
from core.agentscope.events import to_sse_event
from core.agentscope.schemas.agui_input import (
parse_run_input,
validate_run_request_messages_contract,
)
from core.auth.models import CurrentUser
from core.logging import get_logger
from redis.exceptions import TimeoutError as RedisTimeoutError
from fastapi import (
APIRouter,
Depends,
File,
Form,
Header,
Query,
Request,
UploadFile,
status,
)
from fastapi.responses import StreamingResponse
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
AsrTranscribeResponse,
AttachmentReference,
AttachmentSignedUrlResponse,
AttachmentUploadResponse,
CancelRunResponse,
HistorySnapshotResponse,
TaskAcceptedResponse,
)
from v1.agent.asr import asr_service
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
router = APIRouter(prefix="/agent", tags=["agent"])
logger = get_logger("v1.agent.router")
# Accepted shape of the Last-Event-ID header: "<digits>-<digits>".
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
# Accepted runId for the events endpoint: 1-128 letters, digits, "_" or "-".
_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,128}$")
# Per-user cap on concurrently counted SSE streams.
_MAX_SSE_CONNECTIONS_PER_USER = 3
# Expiry (seconds) applied to the per-user SSE counter key.
_SSE_SLOT_TTL_SECONDS = 15 * 60
# Event types that end the SSE stream for the requested run.
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}
# Upper bound on uploaded audio bytes for /transcribe.
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024
# Chunk size used when reading the uploaded audio.
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024
# Slack added to the audio limit when checking the request Content-Length.
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
# Upper bound on attachment upload bytes for /attachments.
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024
# Number of leading bytes needed to check the RIFF/WAVE markers.
_WAV_HEADER_MIN_BYTES = 12
# Content types accepted by /transcribe.
_ALLOWED_AUDIO_CONTENT_TYPES = {
    "audio/wav",
    "audio/x-wav",
    "audio/wave",
}
def _looks_like_wav_header(header: bytes) -> bool:
    """Check for ``RIFF`` at bytes 0-3 and ``WAVE`` at bytes 8-11.

    Headers shorter than ``_WAV_HEADER_MIN_BYTES`` never match.
    """
    if len(header) >= _WAV_HEADER_MIN_BYTES:
        riff_tag, wave_tag = header[0:4], header[8:12]
        return riff_tag == b"RIFF" and wave_tag == b"WAVE"
    return False
async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one SSE connection slot for ``user_id``.

    A per-user Redis counter is incremented; when that pushes it above
    ``_MAX_SSE_CONNECTIONS_PER_USER`` the increment is undone and ``False``
    is returned. The counter key is given an expiry when it is new or has
    none. Any failure is logged and treated as success (fail-open).
    """
    try:
        client = await get_or_init_redis_client()
        slot_key = f"agent:sse-active:{user_id}"
        active = await client.incr(slot_key)
        if active != 1 and active > _MAX_SSE_CONNECTIONS_PER_USER:
            await client.decr(slot_key)
            return False
        # A fresh counter always needs an expiry; an existing one only when
        # it has lost it (negative TTL).
        if active == 1 or (await client.ttl(slot_key)) < 0:
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
        return True
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot acquire failed",
            user_id=user_id,
            reason=str(exc),
        )
        return True
async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one SSE connection slot for ``user_id``.

    The per-user counter is decremented; the key is deleted once it reaches
    zero or below, otherwise its expiry is restored if missing. Failures are
    logged and swallowed.
    """
    try:
        client = await get_or_init_redis_client()
        slot_key = f"agent:sse-active:{user_id}"
        remaining = await client.decr(slot_key)
        if remaining > 0:
            if (await client.ttl(slot_key)) < 0:
                await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
        else:
            await client.delete(slot_key)
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot release failed",
            user_id=user_id,
            reason=str(exc),
        )
def _is_terminal_run_event(event: dict[str, object]) -> bool:
    """Return whether ``event["type"]`` is one of the terminal run event types."""
    event_type = event.get("type")
    if not isinstance(event_type, str):
        return False
    return event_type in _TERMINAL_RUN_EVENT_TYPES
def _is_target_run_event(event: dict[str, object], *, target_run_id: str) -> bool:
run_id = event.get("runId")
return isinstance(run_id, str) and run_id == target_run_id
@router.post(
    "/runs",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_run(
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    # Accept an agent run: re-parse the body through parse_run_input, check
    # the messages contract, then hand the run to the service for enqueueing.
    # Either validation failure is reported as a 422 problem response.
    try:
        run_input = parse_run_input(
            request.model_dump(by_alias=True, exclude_none=True)
        )
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(code="AGENT_RUN_INPUT_INVALID", detail=str(exc)),
        ) from exc

    try:
        validate_run_request_messages_contract(run_input)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(code="AGENT_RUN_MESSAGES_INVALID", detail=str(exc)),
        ) from exc

    accepted = await service.enqueue_run(
        run_input=run_input,
        current_user=current_user,
    )
    return TaskAcceptedResponse(
        taskId=accepted.task_id,
        threadId=accepted.thread_id,
        runId=accepted.run_id,
        created=accepted.created,
    )
@router.post(
    "/runs/{thread_id}/cancel",
    response_model=CancelRunResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def cancel_run(
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    run_id: str = Query(
        alias="runId",
        min_length=1,
        max_length=128,
        pattern=r"^[A-Za-z0-9_-]+$",
    ),
) -> CancelRunResponse:
    # Forward a cancel request for thread_id/runId to the service and echo
    # its outcome back to the caller.
    outcome = await service.cancel_run(
        thread_id=thread_id,
        run_id=run_id,
        current_user=current_user,
    )
    response = CancelRunResponse(
        threadId=outcome.thread_id,
        runId=outcome.run_id,
        accepted=outcome.accepted,
    )
    return response
@router.get("/runs/{thread_id}/events")
async def stream_events(
    request: Request,
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    run_id: str | None = Query(default=None, alias="runId"),
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    # Stream the events of one run as Server-Sent Events.
    #
    # runId and Last-Event-ID are validated first (422 on failure), then a
    # per-user SSE slot is reserved (429 when the cap is hit). The stream
    # ends on a terminal event for the run, on client disconnect, after
    # idle_limit consecutive empty polls, or on an unexpected read error.
    run_id_ok = run_id is not None and _RUN_ID_RE.fullmatch(run_id) is not None
    if not run_id_ok:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_INVALID_RUN_ID",
                detail="Invalid runId",
            ),
        )

    if last_event_id is not None:
        too_long = len(last_event_id) > 32
        if too_long or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None:
            raise ApiProblemError(
                status_code=422,
                detail=problem_payload(
                    code="AGENT_INVALID_LAST_EVENT_ID",
                    detail="Invalid Last-Event-ID",
                ),
            )

    owner_id = str(current_user.id)
    if not await _acquire_sse_slot(user_id=owner_id):
        raise ApiProblemError(
            status_code=429,
            detail=problem_payload(
                code="AGENT_SSE_CONNECTION_LIMIT",
                detail="Too many SSE connections",
            ),
        )

    # NOTE(review): the slot is only released from the generator's finally
    # block; confirm it is also released if the response body is never
    # iterated (otherwise the counter relies on its TTL).
    async def _event_iter() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        finished = False
        try:
            while not finished:
                if await request.is_disconnected():
                    break
                if idle_polls >= idle_limit:
                    break

                try:
                    rows = await service.stream_events(
                        thread_id=thread_id,
                        last_event_id=cursor,
                        current_user=current_user,
                    )
                except (TimeoutError, RedisTimeoutError):
                    # A read timeout is handled exactly like an empty poll.
                    rows = []
                except Exception as exc:  # noqa: BLE001
                    logger.warning(
                        "SSE stream read failed",
                        thread_id=thread_id,
                        user_id=owner_id,
                        reason=str(exc),
                    )
                    break

                if not rows:
                    idle_polls += 1
                    yield ": keep-alive\n\n"
                    await asyncio.sleep(0.2)
                    continue

                idle_polls = 0
                for entry in rows:
                    entry_id = str(entry.get("id", ""))
                    payload = entry.get("event")
                    if not entry_id or not isinstance(payload, dict):
                        continue
                    # Advance the cursor even for other runs' events so they
                    # are not read again on the next poll.
                    cursor = entry_id
                    if not _is_target_run_event(payload, target_run_id=run_id):
                        continue
                    yield to_sse_event(entry_id, payload)
                    if _is_terminal_run_event(payload):
                        finished = True
                        break
        finally:
            await _release_sse_slot(user_id=owner_id)

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
@router.get("/history", response_model=HistorySnapshotResponse)
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> HistorySnapshotResponse:
    # Return a history snapshot for the current user, optionally limited to
    # one thread (threadId) and to days before the given date.
    snapshot = await service.get_user_history_snapshot(
        current_user=current_user,
        thread_id=thread_id,
        before=before,
    )
    return snapshot
@router.post(
    "/attachments",
    response_model=AttachmentUploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_attachment(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str = Form(alias="threadId"),
    file: UploadFile = File(),
) -> AttachmentUploadResponse:
    # Upload one attachment for a thread.
    #
    # Responds 422 (AGENT_ATTACHMENT_EMPTY) for an empty file and 413
    # (AGENT_ATTACHMENT_TOO_LARGE) above _MAX_ATTACHMENT_UPLOAD_BYTES;
    # otherwise the bytes are handed to the service and its attachment
    # reference is returned.
    #
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large file into memory.
    payload = await file.read(_MAX_ATTACHMENT_UPLOAD_BYTES + 1)
    if not payload:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_EMPTY",
                detail="Empty attachment",
            ),
        )
    if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
        raise ApiProblemError(
            status_code=413,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_TOO_LARGE",
                detail="Attachment too large",
                params={"maxBytes": _MAX_ATTACHMENT_UPLOAD_BYTES},
            ),
        )

    attachment = await service.upload_attachment(
        thread_id=thread_id,
        filename=file.filename,
        content_type=file.content_type,
        payload=payload,
        current_user=current_user,
    )
    return AttachmentUploadResponse(
        attachment=AttachmentReference.model_validate(attachment),
    )
@router.get(
    "/attachments/signed-url",
    response_model=AttachmentSignedUrlResponse,
    status_code=status.HTTP_200_OK,
)
async def create_attachment_signed_url(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    bucket: str = Query(min_length=1, max_length=100),
    path: str = Query(min_length=1, max_length=500),
) -> AttachmentSignedUrlResponse:
    # Ask the service for a signed URL for bucket/path and wrap the mapping
    # it returns in the response model.
    signed_fields = await service.create_attachment_signed_url(
        bucket=bucket,
        path=path,
        current_user=current_user,
    )
    return AttachmentSignedUrlResponse(**signed_fields)
@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    request: Request,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AsrTranscribeResponse:
    # Transcribe an uploaded WAV file.
    #
    # The upload is checked by content type, declared Content-Length, actual
    # size and RIFF/WAVE header (all reported as 400 problems), spooled to a
    # temporary file, and passed to the ASR service. A RuntimeError from the
    # ASR service becomes a 502 AGENT_ASR_UNAVAILABLE problem. The temporary
    # file is always removed.
    temp_path: str | None = None
    try:
        if audio.content_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
            raise ApiProblemError(
                status_code=400,
                detail=problem_payload(
                    code="AGENT_AUDIO_UNSUPPORTED_FORMAT",
                    detail="Unsupported audio format",
                ),
            )

        # Cheap early rejection based on the declared request size; an
        # unparsable header is ignored and the streamed check below applies.
        content_length = request.headers.get("content-length")
        if content_length is not None:
            try:
                declared_length = int(content_length)
            except ValueError:
                declared_length = None
            if (
                declared_length is not None
                and declared_length
                > _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
            ):
                raise ApiProblemError(
                    status_code=400,
                    detail=problem_payload(
                        code="AGENT_AUDIO_TOO_LARGE",
                        detail="Audio file too large",
                        params={"maxBytes": _MAX_TRANSCRIBE_AUDIO_BYTES},
                    ),
                )

        total_bytes = 0
        header = bytearray()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            temp_path = tmp_file.name
            while True:
                chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
                if not chunk:
                    break
                total_bytes += len(chunk)
                if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
                    raise ApiProblemError(
                        status_code=400,
                        detail=problem_payload(
                            code="AGENT_AUDIO_TOO_LARGE",
                            detail="Audio file too large",
                            params={"maxBytes": _MAX_TRANSCRIBE_AUDIO_BYTES},
                        ),
                    )
                # Collect the leading bytes for the header check, possibly
                # across several short chunks.
                if len(header) < _WAV_HEADER_MIN_BYTES:
                    required = _WAV_HEADER_MIN_BYTES - len(header)
                    header.extend(chunk[:required])
                tmp_file.write(chunk)

        # The temp file is closed (and therefore flushed) at this point, so
        # the recognizer reads the complete audio from disk.
        if total_bytes == 0:
            raise ApiProblemError(
                status_code=400,
                detail=problem_payload(
                    code="AGENT_AUDIO_EMPTY",
                    detail="Empty audio file",
                ),
            )
        if not _looks_like_wav_header(bytes(header)):
            raise ApiProblemError(
                status_code=400,
                detail=problem_payload(
                    code="AGENT_AUDIO_UNSUPPORTED_FORMAT",
                    detail="Unsupported audio format",
                ),
            )

        transcript = await asr_service.transcribe_file(
            temp_path, audio.filename or "unknown"
        )
        return AsrTranscribeResponse(transcript=transcript)
    except ApiProblemError:
        raise
    except RuntimeError as exc:
        # Keep the original failure as the cause for debugging.
        raise ApiProblemError(
            status_code=502,
            detail=problem_payload(
                code="AGENT_ASR_UNAVAILABLE",
                detail="ASR service unavailable",
            ),
        ) from exc
    finally:
        try:
            await audio.close()
        finally:
            # Remove the temp file even if closing the upload failed.
            if temp_path:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
+206
View File
@@ -0,0 +1,206 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date
from typing import Any, Literal, Protocol
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from schemas.agent.ui_schema import UiSchemaRenderer
class AgentRepositoryLike(Protocol):
    """Structural interface of the repository the agent service depends on."""

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the owner identifier of the given session."""

    async def create_session_for_user(
        self, *, user_id: str, session_id: str | None = None
    ) -> str:
        """Create a session for the user and return its id."""

    async def commit(self) -> None:
        """Commit pending changes."""

    async def rollback(self) -> None:
        """Roll back pending changes."""

    async def get_history_day(
        self,
        *,
        session_id: str,
        before: date | None,
        visibility_mask: int | None = None,
    ) -> dict[str, object] | None:
        """Return one day of history for the session, or ``None``."""

    async def get_latest_session_id_for_user(self, *, user_id: str) -> str | None:
        """Return the user's latest session id, or ``None``."""

    async def persist_user_message(
        self,
        *,
        session_id: str,
        content: str,
        metadata: Any,
        visibility_mask: int,
    ) -> None:
        """Store a user message in the session."""

    async def get_user_message_count(self, *, session_id: str) -> int:
        """Return the number of user messages in the session."""

    async def get_system_agent_config(
        self, *, agent_type: str
    ) -> dict[str, object] | None:
        """Return the system agent configuration for ``agent_type``, or ``None``."""
class QueueClientLike(Protocol):
    """Structural interface of the task queue client used by the agent service."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Enqueue ``command`` and return a task id string."""

    async def request_cancel(
        self,
        *,
        thread_id: str,
        run_id: str,
        requested_by: str,
    ) -> None:
        """Request cancellation of the given run."""
class EventStreamLike(Protocol):
    """Structural interface of the event stream reader used by the agent service."""

    async def read(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
    ) -> list[dict[str, object]]:
        """Return event rows for the session, given the last seen event id."""
class PointsServiceLike(Protocol):
    """Structural interface of the points service used by the agent service."""

    async def ensure_run_points_available(
        self,
        *,
        user_id: UUID,
    ) -> int:
        """Check run points for the user; the int's meaning is implementation-defined."""

    async def consume_successful_run_points(
        self,
        *,
        user_id: UUID,
        session_id: UUID,
        run_id: str,
        operator_id: UUID | None,
    ) -> Any:
        """Consume points for a successful run; the result type is implementation-defined."""
class AttachmentStorageLike(Protocol):
    """Structural interface of the attachment object store used by the agent service."""

    async def upload_bytes(
        self,
        *,
        bucket: str,
        path: str,
        content: bytes,
        content_type: str,
    ) -> str:
        """Upload ``content`` to ``bucket``/``path`` and return a string result."""

    async def download_bytes(self, *, bucket: str, path: str) -> bytes:
        """Return the bytes stored at ``bucket``/``path``."""

    async def create_signed_url(
        self,
        *,
        bucket: str,
        path: str,
        expires_in_seconds: int,
    ) -> str:
        """Return a signed URL for ``bucket``/``path`` with the given lifetime."""

    def parse_signed_url(self, url: str) -> tuple[str, str]:
        """Split a signed URL into a pair of strings (synchronous)."""
@dataclass(frozen=True)
class TaskAccepted:
    """Immutable result of enqueueing a run."""

    # Identifier of the queued task.
    task_id: str
    # Thread the run belongs to.
    thread_id: str
    # Identifier of the run.
    run_id: str
    # Flag reported back to the client as "created".
    created: bool
@dataclass(frozen=True)
class CancelRequested:
    """Immutable result of a run cancel request."""

    # Thread the run belongs to.
    thread_id: str
    # Identifier of the run.
    run_id: str
    # Flag reported back to the client as "accepted".
    accepted: bool
class TaskAcceptedResponse(BaseModel):
    # Response body for an accepted run; fields serialize under their
    # camelCase aliases and may be populated by field name or alias.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    task_id: str = Field(alias="taskId")
    thread_id: str = Field(alias="threadId")
    run_id: str = Field(alias="runId")
    created: bool
class CancelRunResponse(BaseModel):
    """API response model mirroring ``CancelRequested`` with camelCase aliases."""

    # Fields may be populated by name or alias; output uses the aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    thread_id: str = Field(alias="threadId")
    run_id: str = Field(alias="runId")
    accepted: bool
class AsrTranscribeResponse(BaseModel):
    """API response model carrying the text recognised from an audio upload."""

    transcript: str = Field(description="Transcribed text from audio")
class AttachmentReference(BaseModel):
    """Location, MIME type and access URL of one stored attachment."""

    # Fields may be populated by name or alias; output uses the aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    bucket: str
    path: str
    mime_type: str = Field(alias="mimeType")
    url: str
class AttachmentUploadResponse(BaseModel):
    """API response model wrapping the reference of an uploaded attachment."""

    attachment: AttachmentReference
class AttachmentSignedUrlResponse(BaseModel):
    """API response model for a signed URL issued for ``bucket``/``path``."""

    bucket: str
    path: str
    url: str
class HistoryMessageAttachment(BaseModel):
    """One attachment of a history message: its MIME type and an access URL."""

    # Fields may be populated by name or alias; output uses the aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    mime_type: str = Field(alias="mimeType")
    url: str
class HistoryMessage(BaseModel):
    """History message schema for /history endpoint response.

    Only ``user`` and ``assistant`` roles are representable; attachments
    default to an empty list and ``ui_schema`` to ``None``.
    """

    # Fields may be populated by name or alias; output uses the aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    id: str = Field(description="Message UUID")
    seq: int = Field(description="Message sequence number")
    role: Literal["user", "assistant"] = Field(
        description="Message role: user | assistant"
    )
    content: str = Field(description="Message text content")
    attachments: list[HistoryMessageAttachment] = Field(
        default_factory=list,
        description="Temporary signed URLs for user-attached images",
    )
    ui_schema: UiSchemaRenderer | None = Field(
        default=None,
        description="Compiled UI schema from worker ui_hints for frontend rendering",
    )
    timestamp: str = Field(description="Message creation timestamp in ISO-8601 format")
class HistorySnapshotResponse(BaseModel):
    """Response schema for GET /api/v1/agent/history

    Carries one page of history: the thread it belongs to, the day the page
    covers, whether more pages exist, and the messages themselves.
    Every field has a default, so an empty snapshot is constructible.
    """

    # Fields may be populated by name or alias; output uses the aliases.
    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    scope: str = Field(default="history_day")
    thread_id: str | None = Field(default=None, alias="threadId")
    day: str | None = None
    has_more: bool = Field(default=False, alias="hasMore")
    messages: list[HistoryMessage] = Field(default_factory=list)
+706
View File
@@ -0,0 +1,706 @@
from __future__ import annotations
from datetime import date, datetime, timezone
import hashlib
from urllib.parse import urlparse
from ag_ui.core import RunAgentInput
from sqlalchemy.exc import IntegrityError
from core.http.errors import ApiProblemError, problem_payload
from core.auth.models import CurrentUser
from core.agentscope.caches.context_messages_cache import (
create_context_messages_cache,
)
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
from schemas.agent.forwarded_props import (
parse_forwarded_props_runtime_mode,
RuntimeMode,
)
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
from schemas.agent.runtime_config import RuntimeConfig
from schemas.domain.chat_message import (
AgentChatMessageMetadata,
UserMessageAttachment,
extract_user_message_attachments,
)
from v1.agent.schemas import (
AgentRepositoryLike,
AttachmentStorageLike,
CancelRequested,
EventStreamLike,
HistorySnapshotResponse,
PointsServiceLike,
QueueClientLike,
TaskAccepted,
)
from v1.agent.utils import (
MAX_ATTACHMENT_BYTES,
MAX_ATTACHMENTS_PER_MESSAGE,
is_safe_attachment_path,
mime_to_suffix,
)
# Module-level logger for the agent service.
logger = get_logger(__name__)
# Upper bound on user messages per session: a new run is rejected with
# AGENT_SESSION_RUN_LIMIT_EXCEEDED once the stored count reaches this value.
MAX_RUNS_PER_SESSION = 4
def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
    """Reject the request unless the session belongs to ``current_user``.

    Args:
        owner_id: Owner id stored for the session, as a string.
        current_user: Authenticated caller; its ``id`` is compared after
            conversion to ``str``.

    Raises:
        ApiProblemError: 403 ``AGENT_FORBIDDEN`` when the ids differ.
    """
    if owner_id == str(current_user.id):
        return
    forbidden = problem_payload(code="AGENT_FORBIDDEN", detail="Forbidden")
    raise ApiProblemError(status_code=403, detail=forbidden)
class AgentService:
    """Application service behind the agent endpoints.

    Works entirely through injected collaborators: a repository, a task
    queue, an event stream, a points service and an optional attachment
    storage.
    """

    _repository: AgentRepositoryLike
    _queue: QueueClientLike
    _stream: EventStreamLike
    _points_service: PointsServiceLike
    # ``None`` means attachment features are unavailable; the methods that
    # need storage answer with a 503 problem in that case.
    _attachment_storage: AttachmentStorageLike | None
    # Lifetime, in seconds, requested for every signed attachment URL.
    _SIGNED_URL_EXPIRES_IN_SECONDS = 3600

    def __init__(
        self,
        *,
        repository: AgentRepositoryLike,
        queue: QueueClientLike,
        stream: EventStreamLike,
        points_service: PointsServiceLike,
        attachment_storage: AttachmentStorageLike | None = None,
    ) -> None:
        """Store the injected collaborators; performs no I/O.

        Args:
            repository: Session and message persistence.
            queue: Task queue used to enqueue and cancel runs.
            stream: Source of per-session events.
            points_service: Points checks performed before a run.
            attachment_storage: Optional object storage for attachments.
        """
        self._repository = repository
        self._queue = queue
        self._stream = stream
        self._points_service = points_service
        self._attachment_storage = attachment_storage
async def enqueue_run(
    self,
    *,
    run_input: RunAgentInput,
    current_user: CurrentUser,
    runtime_config: RuntimeConfig | None = None,
) -> TaskAccepted:
    """Validate a run request, persist its user message and enqueue the run.

    Steps, in order: parse the runtime mode from ``forwarded_props``; resolve
    the runtime config; resolve or create the session and check ownership;
    enforce the run preconditions; persist the user message and commit;
    mirror the message into the context cache; enqueue the run command.

    Args:
        run_input: AG-UI run input; supplies ``thread_id`` and ``run_id``.
        current_user: Authenticated caller; must own the session.
        runtime_config: Optional override; when ``None`` it is built from the
            system agents configuration.

    Returns:
        ``TaskAccepted`` carrying the queue's task id; ``created`` is true
        when this call created the session.

    Raises:
        ApiProblemError: 422 ``AGENT_PAYLOAD_INVALID`` when the forwarded
            props cannot be parsed; other problems propagate from the
            ownership check, the preconditions and message preparation.
    """
    created = False
    thread_id = run_input.thread_id
    run_id = run_input.run_id
    # ``forwarded_props`` is read defensively: absent attribute -> None.
    forwarded_props = getattr(run_input, "forwarded_props", None)
    try:
        runtime_mode = parse_forwarded_props_runtime_mode(forwarded_props)
    except ValueError as exc:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(code="AGENT_PAYLOAD_INVALID", detail=str(exc)),
        ) from exc
    if runtime_config is None:
        # Imported at call time, only when no override was supplied.
        from v1.agent.system_agents_config import (
            build_runtime_config_from_system_agents,
        )

        runtime_config = build_runtime_config_from_system_agents()
    try:
        owner = await self._repository.get_session_owner(session_id=thread_id)
    except ApiProblemError as exc:
        # Only a 404 is treated as "session does not exist yet".
        if exc.status_code != 404:
            raise
        created = await self._create_session_if_missing(
            thread_id=thread_id,
            current_user=current_user,
        )
    else:
        ensure_session_owner(owner_id=owner, current_user=current_user)
    try:
        await self._enforce_run_preconditions(
            thread_id=thread_id,
            current_user=current_user,
        )
    except ApiProblemError:
        # Undo the session created above when the run is rejected.
        if created:
            await self._repository.rollback()
        raise
    user_message_text, user_message_metadata = await self._prepare_user_message(
        run_input=run_input,
        current_user=current_user,
    )
    visibility_mask = await self._resolve_user_message_visibility_mask(
        runtime_mode=runtime_mode
    )
    await self._repository.persist_user_message(
        session_id=thread_id,
        content=user_message_text,
        metadata=user_message_metadata,
        visibility_mask=visibility_mask,
    )
    # The user message is committed before the task is enqueued; a failure
    # in the steps below is not rolled back by this method.
    await self._repository.commit()
    await self._append_context_cache_user_message(
        thread_id=thread_id,
        runtime_mode=runtime_mode,
        visibility_mask=visibility_mask,
        content=user_message_text,
        metadata=user_message_metadata,
    )
    task_id = await self._queue.enqueue(
        command={
            "command": "run",
            "owner_id": str(current_user.id),
            "owner_email": current_user.email,
            "run_input": run_input.model_dump(
                mode="json", by_alias=True, exclude_none=True
            ),
            "runtime_config": runtime_config.model_dump(
                mode="json", by_alias=True, exclude_none=True
            ),
            "queue": "agent",
        },
        dedup_key=None,
    )
    return TaskAccepted(
        task_id=task_id,
        thread_id=thread_id,
        run_id=run_id,
        created=created,
    )
async def cancel_run(
    self,
    *,
    thread_id: str,
    run_id: str,
    current_user: CurrentUser,
) -> CancelRequested:
    """Ask the queue to cancel a run after verifying session ownership.

    Args:
        thread_id: Session the run belongs to.
        run_id: Run to cancel.
        current_user: Authenticated caller; must own the session.

    Returns:
        ``CancelRequested`` with ``accepted=True`` once the queue call
        returns.

    Raises:
        ApiProblemError: Propagated from the owner lookup or the ownership
            check.
    """
    session_owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=session_owner, current_user=current_user)

    cancel_request = {
        "thread_id": thread_id,
        "run_id": run_id,
        "requested_by": str(current_user.id),
    }
    await self._queue.request_cancel(**cancel_request)

    return CancelRequested(thread_id=thread_id, run_id=run_id, accepted=True)
async def _append_context_cache_user_message(
    self,
    *,
    thread_id: str,
    runtime_mode: RuntimeMode,
    visibility_mask: int,
    content: str,
    metadata: AgentChatMessageMetadata | None,
) -> None:
    """Best-effort append of the user message to the context-messages cache.

    Any exception raised while creating the cache or appending is logged as
    a warning and swallowed, so a cache failure never fails the request.

    Args:
        thread_id: Session the message belongs to.
        runtime_mode: Mode of the run; its ``value`` is sent to the cache.
        visibility_mask: Visibility bits stored with the message.
        content: Message text.
        metadata: Optional message metadata, dumped to JSON-compatible form.
    """
    dumped_metadata = None
    if isinstance(metadata, AgentChatMessageMetadata):
        dumped_metadata = metadata.model_dump(mode="json", exclude_none=True)

    cached_message: dict[str, object] = {
        "role": "user",
        "content": content,
        "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"),
    }
    if isinstance(dumped_metadata, dict):
        cached_message["metadata"] = dumped_metadata

    try:
        cache = create_context_messages_cache()
        await cache.append_message(
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            visibility_mask=visibility_mask,
            message=cached_message,
        )
    except Exception as exc:
        logger.warning(
            "Failed to append user message to context cache",
            thread_id=thread_id,
            runtime_mode=runtime_mode.value,
            error=str(exc),
        )
async def _resolve_user_message_visibility_mask(
    self, *, runtime_mode: RuntimeMode
) -> int:
    """Return the visibility mask to store with a user message.

    Chat-mode messages get the ``UI_HISTORY`` and ``CONTEXT_ASSEMBLY`` bits;
    every other mode gets an empty mask (``0``).
    """
    if runtime_mode != RuntimeMode.CHAT:
        return 0
    ui_history_bit = bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY))
    context_assembly_bit = bit_mask(bit=int(SystemVisibilityBit.CONTEXT_ASSEMBLY))
    return ui_history_bit | context_assembly_bit
async def _prepare_user_message(
    self,
    *,
    run_input: RunAgentInput,
    current_user: CurrentUser,
) -> tuple[str, AgentChatMessageMetadata | None]:
    """Extract the latest user text and its validated attachments.

    Only ``binary`` content blocks with a non-empty string ``url`` are
    considered; each URL must be a signed URL inside the caller's upload
    scope for this thread.

    Args:
        run_input: AG-UI run input holding the user payload.
        current_user: Authenticated caller; scopes the allowed paths.

    Returns:
        ``(text, metadata)``; ``metadata`` is ``None`` when the message has
        no attachments.

    Raises:
        ApiProblemError: 503 when a binary block is present but attachment
            storage is not configured; 422 ``AGENT_ATTACHMENTS_TOO_MANY``
            above the per-message limit; 422
            ``AGENT_SIGNED_IMAGE_URL_INVALID`` for any unexpected failure
            while handling a URL; plus the validation problems raised by
            ``_validate_binary_signed_url``.
    """
    text, content_blocks = extract_latest_user_payload(run_input)
    user_attachments: list[UserMessageAttachment] = []
    for block in content_blocks:
        if not isinstance(block, dict):
            continue
        if block.get("type") != "binary":
            continue
        url = block.get("url")
        mime_type = block.get("mimeType")
        if not isinstance(url, str) or not url:
            continue
        if not isinstance(mime_type, str):
            mime_type = "application/octet-stream"
        # Checked per binary block (not up front) so that messages without
        # attachments still work when storage is not configured.
        if self._attachment_storage is None:
            raise ApiProblemError(
                status_code=503,
                detail=problem_payload(
                    code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
                    detail="Attachment storage unavailable",
                ),
            )
        try:
            bucket, path = self._validate_binary_signed_url(
                url=url,
                thread_id=run_input.thread_id,
                current_user=current_user,
            )
            user_attachments.append(
                UserMessageAttachment(
                    bucket=bucket,
                    path=path,
                    mime_type=mime_type,
                )
            )
            if len(user_attachments) > MAX_ATTACHMENTS_PER_MESSAGE:
                raise ApiProblemError(
                    status_code=422,
                    detail=problem_payload(
                        code="AGENT_ATTACHMENTS_TOO_MANY",
                        detail="Too many attachments",
                        params={"max": MAX_ATTACHMENTS_PER_MESSAGE},
                    ),
                )
        except ApiProblemError:
            raise
        except Exception as exc:  # noqa: BLE001
            # Log without the query string, which carries the signature.
            parsed = urlparse(url)
            safe_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
            logger.warning(
                "Failed to parse signed URL", url=safe_url, error=str(exc)
            )
            # Chain the original error so the cause stays explicit.
            raise ApiProblemError(
                status_code=422,
                detail=problem_payload(
                    code="AGENT_SIGNED_IMAGE_URL_INVALID",
                    detail="Invalid signed image url",
                ),
            ) from exc
    metadata: AgentChatMessageMetadata | None = None
    if user_attachments:
        metadata = AgentChatMessageMetadata(
            run_id=run_input.run_id,
            user_message_attachments=user_attachments,
        )
    return text, metadata
async def upload_attachment(
    self,
    *,
    thread_id: str,
    filename: str | None,
    content_type: str | None,
    payload: bytes,
    current_user: CurrentUser,
) -> dict[str, str]:
    """Validate and store an image attachment, returning a signed URL for it.

    The session is created on demand; a session created here is committed
    only after the upload succeeded and rolled back when it failed.

    Args:
        thread_id: Session the attachment is uploaded for.
        filename: Client file name; only used to derive part of the path.
        content_type: Declared MIME type; must be an allowed image type.
        payload: Raw file bytes.
        current_user: Authenticated caller; must own the session.

    Returns:
        Dict with ``bucket``, ``path``, ``mimeType`` and ``url``.

    Raises:
        ApiProblemError: 503 when storage is not configured; 422 for an
            unsupported type or empty payload; 413 when the payload exceeds
            ``MAX_ATTACHMENT_BYTES``; 502 when the upload or URL signing
            fails; plus ownership problems.
    """
    # Function-scope import (the module already imports from v1.agent.utils):
    # reuse the shared allow-list instead of a second hard-coded copy.
    from v1.agent.utils import ALLOWED_ATTACHMENT_MIME_TYPES

    if self._attachment_storage is None:
        raise ApiProblemError(
            status_code=503,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
                detail="Attachment storage unavailable",
            ),
        )
    mime_type = content_type.lower() if isinstance(content_type, str) else None
    if mime_type is None or mime_type not in ALLOWED_ATTACHMENT_MIME_TYPES:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_UNSUPPORTED_TYPE",
                detail="Unsupported attachment type",
            ),
        )
    if not payload:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_EMPTY",
                detail="Empty attachment",
            ),
        )
    if len(payload) > MAX_ATTACHMENT_BYTES:
        raise ApiProblemError(
            status_code=413,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_TOO_LARGE",
                detail="Attachment too large",
                params={"maxBytes": MAX_ATTACHMENT_BYTES},
            ),
        )
    created = False
    try:
        owner = await self._repository.get_session_owner(session_id=thread_id)
    except ApiProblemError as exc:
        # Only a 404 is treated as "session does not exist yet".
        if exc.status_code != 404:
            raise
        created = await self._create_session_if_missing(
            thread_id=thread_id,
            current_user=current_user,
        )
    else:
        ensure_session_owner(owner_id=owner, current_user=current_user)
    suffix = mime_to_suffix(mime_type)
    # SHA-1 is used here only to derive short, stable path components.
    checksum = hashlib.sha1(payload).hexdigest()[:16]
    filename_seed = filename if isinstance(filename, str) and filename else "upload"
    filename_hash = hashlib.sha1(filename_seed.encode("utf-8")).hexdigest()[:8]
    path = (
        f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
        f"{filename_hash}-{checksum}.{suffix}"
    )
    bucket_name = config.storage.attachment.bucket
    try:
        stored_path = await self._attachment_storage.upload_bytes(
            bucket=bucket_name,
            path=path,
            content=payload,
            content_type=mime_type,
        )
        signed_url = await self._attachment_storage.create_signed_url(
            bucket=bucket_name,
            path=stored_path,
            expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
        )
    except Exception as exc:  # noqa: BLE001
        # Do not keep a session that was created only for a failed upload.
        if created:
            await self._repository.rollback()
        logger.exception(
            "Attachment upload failed",
            extra={
                "bucket": bucket_name,
                "path": path,
                "mime_type": mime_type,
                "thread_id": thread_id,
            },
        )
        raise ApiProblemError(
            status_code=502,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_UPLOAD_FAILED",
                detail="Failed to upload attachment",
            ),
        ) from exc
    if created:
        await self._repository.commit()
    return {
        "bucket": bucket_name,
        "path": stored_path,
        "mimeType": mime_type,
        "url": signed_url,
    }
async def _create_session_if_missing(
    self,
    *,
    thread_id: str,
    current_user: CurrentUser,
) -> bool:
    """Create the session for the caller, tolerating a concurrent creation.

    Returns:
        ``True`` when this call created the session. ``False`` when the
        insert hit an ``IntegrityError``; in that case the transaction is
        rolled back and ownership of the existing session is verified.

    Raises:
        ApiProblemError: When the existing session is owned by someone else.
    """
    caller_id = str(current_user.id)
    try:
        await self._repository.create_session_for_user(
            user_id=caller_id,
            session_id=thread_id,
        )
    except IntegrityError:
        await self._repository.rollback()
        existing_owner = await self._repository.get_session_owner(
            session_id=thread_id
        )
        ensure_session_owner(owner_id=existing_owner, current_user=current_user)
        return False
    else:
        return True
async def _enforce_run_preconditions(
    self,
    *,
    thread_id: str,
    current_user: CurrentUser,
) -> None:
    """Check points availability and the per-session run limit.

    The points check runs first; the stored user-message count is then
    compared against ``MAX_RUNS_PER_SESSION``.

    Raises:
        ApiProblemError: 409 ``AGENT_SESSION_RUN_LIMIT_EXCEEDED`` when the
            count has reached the limit; problems raised by the points
            service propagate unchanged.
    """
    await self._points_service.ensure_run_points_available(user_id=current_user.id)

    runs_so_far = await self._repository.get_user_message_count(
        session_id=thread_id
    )
    if runs_so_far < MAX_RUNS_PER_SESSION:
        return

    limit_problem = problem_payload(
        code="AGENT_SESSION_RUN_LIMIT_EXCEEDED",
        detail="Session run limit exceeded",
        params={"maxRuns": MAX_RUNS_PER_SESSION},
    )
    raise ApiProblemError(status_code=409, detail=limit_problem)
async def create_attachment_signed_url(
    self,
    *,
    bucket: str,
    path: str,
    current_user: CurrentUser,
) -> dict[str, str]:
    """Issue a fresh signed URL for an attachment in the caller's scope.

    Both ``bucket`` and ``path`` are stripped of surrounding whitespace
    before validation; the stripped values are used and returned.

    Returns:
        Dict with ``bucket``, ``path`` and ``url``.

    Raises:
        ApiProblemError: 503 when storage is not configured; 422 when the
            bucket is not the configured attachment bucket or the path lies
            outside ``agent-inputs/<user id>/``; 502 when signing fails.
    """
    storage = self._attachment_storage
    if storage is None:
        unavailable = problem_payload(
            code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
            detail="Attachment storage unavailable",
        )
        raise ApiProblemError(status_code=503, detail=unavailable)

    bucket_name = bucket.strip()
    if bucket_name != config.storage.attachment.bucket:
        bad_bucket = problem_payload(
            code="AGENT_ATTACHMENT_BUCKET_INVALID",
            detail="Invalid attachment bucket",
        )
        raise ApiProblemError(status_code=422, detail=bad_bucket)

    object_path = path.strip()
    user_scope = f"agent-inputs/{current_user.id}/"
    path_in_scope = is_safe_attachment_path(object_path, expected_prefix=user_scope)
    if not path_in_scope:
        bad_scope = problem_payload(
            code="AGENT_ATTACHMENT_PATH_SCOPE_INVALID",
            detail="Invalid attachment path scope",
        )
        raise ApiProblemError(status_code=422, detail=bad_scope)

    try:
        url = await storage.create_signed_url(
            bucket=bucket_name,
            path=object_path,
            expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
        )
    except Exception:  # noqa: BLE001
        log_context = {
            "bucket": bucket_name,
            "path": object_path,
            "user_id": str(current_user.id),
        }
        logger.exception(
            "Attachment signed URL generation failed",
            extra=log_context,
        )
        signing_failed = problem_payload(
            code="AGENT_SIGNED_URL_GENERATION_FAILED",
            detail="Failed to generate signed URL",
        )
        raise ApiProblemError(status_code=502, detail=signing_failed)

    return {"bucket": bucket_name, "path": object_path, "url": url}
async def stream_events(
    self,
    *,
    thread_id: str,
    last_event_id: str | None,
    current_user: CurrentUser,
) -> list[dict[str, object]]:
    """Read events of a session after verifying the caller owns it.

    Args:
        thread_id: Session whose events are read.
        last_event_id: Passed through unchanged to the event stream.
        current_user: Authenticated caller; must own the session.

    Raises:
        ApiProblemError: Propagated from the owner lookup or the ownership
            check.
    """
    session_owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=session_owner, current_user=current_user)
    events = await self._stream.read(
        session_id=thread_id,
        last_event_id=last_event_id,
    )
    return events
async def get_history_snapshot(
    self,
    *,
    thread_id: str,
    before: date | None,
    current_user: CurrentUser,
) -> HistorySnapshotResponse:
    """Return one day of UI-visible history for a session the caller owns.

    Tool messages are dropped. For each remaining message, attachments
    inside the caller's upload scope for this thread get a fresh signed
    URL. Signing is best-effort: an attachment that is out of scope, or
    whose URL cannot be signed, is omitted from the message instead of
    failing the whole page.

    Args:
        thread_id: Session to read.
        before: Passed through to the repository's day lookup.
        current_user: Authenticated caller; must own the session.

    Raises:
        ApiProblemError: Propagated from the owner lookup or the ownership
            check.
    """
    from schemas.domain.chat_message import AgentChatMessage
    from v1.agent.utils import convert_message_to_history
    from v1.agent.schemas import HistoryMessage

    owner = await self._repository.get_session_owner(session_id=thread_id)
    ensure_session_owner(owner_id=owner, current_user=current_user)
    day_payload = await self._repository.get_history_day(
        session_id=thread_id,
        before=before,
        visibility_mask=bit_mask(bit=int(SystemVisibilityBit.UI_HISTORY)),
    )
    messages: list[HistoryMessage] = []
    raw_messages_obj = day_payload.get("messages") if day_payload else None
    raw_messages = raw_messages_obj if isinstance(raw_messages_obj, list) else []
    # Loop-invariant: only paths inside this user's uploads for this thread
    # are ever signed.
    expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
    for msg_dict in raw_messages:
        msg = AgentChatMessage.model_validate(msg_dict)
        if msg.role == "tool":
            continue
        signed_urls: dict[str, str] = {}
        attachments = extract_user_message_attachments(msg.metadata)
        if self._attachment_storage and attachments:
            for attachment in attachments:
                if not is_safe_attachment_path(
                    attachment.path,
                    expected_prefix=expected_prefix,
                ):
                    continue
                try:
                    signed_url = await self._attachment_storage.create_signed_url(
                        bucket=attachment.bucket,
                        path=attachment.path,
                        expires_in_seconds=self._SIGNED_URL_EXPIRES_IN_SECONDS,
                    )
                except Exception as exc:  # noqa: BLE001
                    # One unsignable attachment must not break the page; the
                    # converter skips attachments that have no URL.
                    logger.warning(
                        "Failed to sign history attachment URL",
                        thread_id=thread_id,
                        path=attachment.path,
                        error=str(exc),
                    )
                    continue
                signed_urls[f"{attachment.bucket}/{attachment.path}"] = signed_url

        def _get_signed_url(
            payload: dict[str, str],
            _urls: dict[str, str] = signed_urls,
        ) -> str:
            # Raises KeyError for attachments that were not signed above;
            # the converter treats that as "skip this attachment".
            return _urls[f"{payload['bucket']}/{payload['path']}"]

        converted = convert_message_to_history(msg, _get_signed_url)
        messages.append(HistoryMessage.model_validate(converted))
    return HistorySnapshotResponse(
        scope="history_day",
        threadId=thread_id,
        day=str(day_payload.get("day"))
        if day_payload and day_payload.get("day")
        else None,
        hasMore=bool(day_payload.get("hasMore")) if day_payload else False,
        messages=messages,
    )
async def get_user_history_snapshot(
    self,
    *,
    current_user: CurrentUser,
    thread_id: str | None,
    before: date | None,
) -> HistorySnapshotResponse:
    """Return a history snapshot, defaulting to the user's latest session.

    When ``thread_id`` is ``None`` the repository is asked for the caller's
    latest session id; if there is none, an empty snapshot is returned.
    Otherwise the call is delegated to ``get_history_snapshot``.
    """
    resolved_thread_id = thread_id
    if resolved_thread_id is None:
        resolved_thread_id = await self._repository.get_latest_session_id_for_user(
            user_id=str(current_user.id)
        )

    if resolved_thread_id is not None:
        return await self.get_history_snapshot(
            thread_id=resolved_thread_id,
            before=before,
            current_user=current_user,
        )

    return HistorySnapshotResponse(
        scope="history_day",
        threadId=None,
        day=None,
        hasMore=False,
        messages=[],
    )
def _validate_binary_signed_url(
    self,
    *,
    url: str,
    thread_id: str,
    current_user: CurrentUser,
) -> tuple[str, str]:
    """Validate a client-supplied signed URL and return ``(bucket, path)``.

    Checks run in a fixed order, and the first failure decides the error
    code: storage configured, host, URL parseable by the storage, bucket,
    path scope.

    Args:
        url: Signed URL taken from a binary content block.
        thread_id: Session the attachment must belong to.
        current_user: Authenticated caller; scopes the allowed path prefix.

    Raises:
        ApiProblemError: 503 when storage is not configured; 422 with
            ``INVALID_BINARY_URL_HOST``, ``AGENT_SIGNED_IMAGE_URL_INVALID``,
            ``INVALID_BINARY_URL_BUCKET`` or
            ``INVALID_BINARY_URL_PATH_SCOPE`` for the respective check.
    """
    if self._attachment_storage is None:
        raise ApiProblemError(
            status_code=503,
            detail=problem_payload(
                code="AGENT_ATTACHMENT_STORAGE_UNAVAILABLE",
                detail="Attachment storage unavailable",
            ),
        )
    parsed = urlparse(url)
    # Only the network location is compared with the configured Supabase
    # URL; the scheme is not part of this check.
    expected_host = urlparse(config.supabase.url).netloc
    if parsed.netloc != expected_host:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="INVALID_BINARY_URL_HOST",
                detail="Invalid binary url host",
            ),
        )
    try:
        bucket, path = self._attachment_storage.parse_signed_url(url)
    except Exception as exc:  # noqa: BLE001
        # Any parser failure is reported as an invalid signed URL.
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="AGENT_SIGNED_IMAGE_URL_INVALID",
                detail="Invalid signed image url",
            ),
        ) from exc
    if bucket != config.storage.attachment.bucket:
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="INVALID_BINARY_URL_BUCKET",
                detail="Invalid binary url bucket",
            ),
        )
    # The path must lie inside this user's uploads for this very thread.
    expected_prefix = f"agent-inputs/{current_user.id}/{thread_id}/uploads/"
    if not is_safe_attachment_path(path, expected_prefix=expected_prefix):
        raise ApiProblemError(
            status_code=422,
            detail=problem_payload(
                code="INVALID_BINARY_URL_PATH_SCOPE",
                detail="Invalid binary url path scope",
            ),
        )
    return bucket, path
@@ -0,0 +1,99 @@
"""
System agents 配置加载工具
从 system_agents.yaml 加载配置并构建 RuntimeConfig
"""
from pathlib import Path
import yaml
from pydantic import ValidationError
from schemas.agent.system_agent import SystemAgentLLMConfig
from schemas.agent.runtime_config import (
ContextSource,
ContextWindowMode,
MessageContextConfig,
RuntimeConfig,
)
def _default_system_agents_path() -> Path:
    """Return the default location of ``system_agents.yaml``.

    The path is ``core/config/static/database/system_agents.yaml`` below the
    directory two levels above this module's own directory.
    """
    base_dir = Path(__file__).resolve().parents[2]
    return base_dir.joinpath(
        "core", "config", "static", "database", "system_agents.yaml"
    )
def _load_system_agents_yaml(path: Path | None = None) -> dict[str, object]:
    """Load the system agents YAML document as a dict.

    Args:
        path: File to read; falls back to the default location when falsy.

    Returns:
        The parsed mapping; an empty or falsy document yields ``{}``.

    Raises:
        ValueError: When the document's top level is not a mapping.
    """
    source_path = path if path else _default_system_agents_path()
    with source_path.open("r", encoding="utf-8") as handle:
        document = yaml.safe_load(handle)
    if not document:
        document = {}
    if isinstance(document, dict):
        return document
    raise ValueError(f"Invalid system agents format: {source_path}")
def _parse_context_messages_config(
    yaml_config: dict[str, object] | None,
) -> MessageContextConfig:
    """Build a ``MessageContextConfig`` from a ``context_messages`` mapping.

    Args:
        yaml_config: Mapping with optional ``mode`` and ``count`` keys, or
            ``None`` / empty to get the model's defaults.

    Returns:
        A config whose source is always ``ContextSource.LATEST_CHAT``.
        An unknown ``mode`` falls back to ``ContextWindowMode.DAY`` and a
        non-integer ``count`` falls back to ``2``.
    """
    if not yaml_config:
        return MessageContextConfig()

    raw_mode = yaml_config.get("mode", "day")
    # A dumped pydantic model may carry an enum member here rather than a
    # plain string; unwrap it so a configured mode is not silently replaced.
    # NOTE(review): whether the upstream field is an enum is not visible
    # here — plain strings behave exactly as before.
    mode_value = getattr(raw_mode, "value", raw_mode)
    mode_str = mode_value if isinstance(mode_value, str) else "day"

    raw_count = yaml_config.get("count", 2)
    # ``bool`` is a subclass of ``int``; ``count: true`` is not a window size.
    count_is_int = isinstance(raw_count, int) and not isinstance(raw_count, bool)
    count = raw_count if count_is_int else 2

    try:
        window_mode = ContextWindowMode(mode_str)
    except ValueError:
        window_mode = ContextWindowMode.DAY

    return MessageContextConfig(
        source=ContextSource.LATEST_CHAT,
        window_mode=window_mode,
        window_count=count,
    )
def build_runtime_config_from_system_agents(
    yaml_path: Path | None = None,
) -> RuntimeConfig:
    """Build a ``RuntimeConfig`` from ``system_agents.yaml``.

    Only the worker configuration is used:
    - ``worker.context_messages`` configures the context window
    - ``enabled_tools`` is always empty (eryao enables no custom tools)

    When several ``worker`` entries exist, the last one wins. A worker
    config that fails validation is replaced by a default
    ``SystemAgentLLMConfig()``.
    """
    document = _load_system_agents_yaml(yaml_path)
    entries = document.get("agents", [])
    if not isinstance(entries, list):
        entries = []

    worker_entries = (
        entry
        for entry in entries
        if isinstance(entry, dict)
        and str(entry.get("agent_type", "")).strip().lower() == "worker"
    )

    worker_config: SystemAgentLLMConfig | None = None
    for entry in worker_entries:
        raw_worker_config = entry.get("config") or {}
        try:
            worker_config = SystemAgentLLMConfig.model_validate(raw_worker_config)
        except ValidationError:
            # NOTE(review): the validation error is dropped without a trace.
            worker_config = SystemAgentLLMConfig()

    if worker_config:
        context_source = worker_config.context_messages.model_dump()
    else:
        context_source = None

    return RuntimeConfig(
        enabled_tools=[],
        context=_parse_context_messages_config(context_source),
    )
+151
View File
@@ -0,0 +1,151 @@
"""
历史消息转换工具函数
将数据库中的原始消息转换为 API 响应的数据结构
"""
from collections.abc import Callable
from typing import Any
from core.agentscope.runtime.ui_compiler import compile as compile_ui_hints
from schemas.domain.chat_message import (
AgentChatMessage,
AgentChatMessageMetadata,
extract_user_message_attachments,
)
# MIME types accepted for attachments (lower-case).
ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
# Size limit for a single attachment: 5 MiB.
MAX_ATTACHMENT_BYTES = 5 * 1024 * 1024
# Combined size limit: 12 MiB.
# NOTE(review): no use of this constant is visible in this module.
MAX_TOTAL_ATTACHMENT_BYTES = 12 * 1024 * 1024
# Maximum number of attachments on one user message.
MAX_ATTACHMENTS_PER_MESSAGE = 3
def convert_message_to_history(
    message: AgentChatMessage,
    get_signed_url_fn: Callable[[dict[str, str]], str] | None = None,
) -> dict[str, Any]:
    """Convert an ``AgentChatMessage`` into the ``HistoryMessage`` dict shape.

    Conversion rules:
    - role ``user``: ``metadata.user_message_attachments`` become
      ``attachments``
    - role ``assistant``: ``metadata.agent_output.ui_hints`` are compiled
      into ``ui_schema``

    The ``attachments`` and ``ui_schema`` keys are present only when their
    value is non-empty.
    """
    role = message.role
    content = message.content
    metadata = message.metadata

    attachments: list[dict[str, str]] = (
        _convert_user_attachments(metadata, get_signed_url_fn)
        if role == "user"
        else []
    )
    ui_schema: dict[str, Any] | None = (
        _compile_worker_ui_hints(metadata) if role == "assistant" else None
    )

    history: dict[str, Any] = {
        "id": str(message.id),
        "seq": message.seq,
        "role": role,
        "content": content,
        "timestamp": message.timestamp.isoformat(),
    }
    optional_parts = (("attachments", attachments), ("ui_schema", ui_schema))
    history.update({key: value for key, value in optional_parts if value})
    return history
def _convert_user_attachments(
    metadata: AgentChatMessageMetadata | dict[str, Any] | None,
    get_signed_url_fn: Callable[[dict[str, str]], str] | None,
) -> list[dict[str, str]]:
    """Turn a user message's attachments into temporary access URL entries.

    Returns an empty list when there is no metadata, no URL resolver, or the
    metadata is neither an ``AgentChatMessageMetadata`` nor a dict. An
    attachment whose URL lookup raises is skipped.
    """
    if not (metadata and get_signed_url_fn):
        return []
    if not isinstance(metadata, (AgentChatMessageMetadata, dict)):
        return []

    entries: list[dict[str, str]] = []
    for item in extract_user_message_attachments(metadata):
        try:
            url = get_signed_url_fn({"bucket": item.bucket, "path": item.path})
        except Exception:
            # No URL available for this attachment: leave it out.
            continue
        entries.append({"url": url, "mimeType": item.mime_type})
    return entries
def _compile_worker_ui_hints(
    metadata: AgentChatMessageMetadata | dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Compile the agent ``ui_hints`` of an assistant message.

    Returns ``None`` whenever there is nothing to compile or any step fails:
    missing metadata, missing or invalid ``agent_output``, empty
    ``ui_hints``, or an exception from the compiler.
    """
    if not metadata:
        return None
    if isinstance(metadata, AgentChatMessageMetadata):
        agent_output = metadata.agent_output
    else:
        agent_output_data = metadata.get("agent_output")
        if not agent_output_data:
            return None
        if isinstance(agent_output_data, dict):
            # A dict-valued ``ui_schema`` already stored on the output is
            # returned as-is, without recompiling.
            raw_ui_schema = agent_output_data.get("ui_schema")
            if isinstance(raw_ui_schema, dict):
                return raw_ui_schema
        # Imported at call time, only on this dict-metadata path.
        from schemas.agent.runtime_models import AgentOutput

        try:
            agent_output = AgentOutput.model_validate(agent_output_data)
        except Exception:
            return None
    if not agent_output:
        return None
    ui_hints = agent_output.ui_hints
    if not ui_hints:
        return None
    try:
        compiled = compile_ui_hints(ui_hints)
        return compiled
    except Exception:
        # Compilation is best-effort: a failure yields no ui_schema.
        return None
def mime_to_suffix(mime_type: str) -> str:
    """Return the file suffix for a supported image MIME type.

    The comparison is case-insensitive; any other type maps to ``"bin"``.
    """
    lowered = mime_type.lower()
    known_suffixes = (
        ("image/png", "png"),
        ("image/jpeg", "jpg"),
        ("image/webp", "webp"),
    )
    for known_type, suffix in known_suffixes:
        if lowered == known_type:
            return suffix
    return "bin"
def is_safe_attachment_path(path: str, *, expected_prefix: str) -> bool:
    """Tell whether ``path`` is a relative path inside ``expected_prefix``.

    The path is stripped of surrounding whitespace first. It is rejected
    when empty, when it starts with ``/``, when it contains ``..`` anywhere,
    or when it does not start with ``expected_prefix``.
    """
    candidate = path.strip()
    is_relative = bool(candidate) and not candidate.startswith("/")
    has_no_parent_ref = ".." not in candidate
    return (
        is_relative
        and has_no_parent_ref
        and candidate.startswith(expected_prefix)
    )
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+52
View File
@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID, uuid4
from schemas.enums import MemoryType
@dataclass
class MemoryRecord:
    """Mutable in-memory representation of one memory entry."""

    # Identifier of the record.
    id: UUID
    # User the memory belongs to.
    owner_id: UUID
    # Kind of memory this record holds.
    memory_type: MemoryType
    # Payload as a plain dict; the service stores JSON-mode model dumps here.
    content: dict
class SQLAlchemyMemoriesRepository:
    """Memories repository that currently performs no database access.

    The lookups always return ``None``, ``create`` only builds a record in
    memory, and ``update_content`` mutates the record it is given.
    """

    def __init__(self, session: object) -> None:
        """Keep a reference to the given session object."""
        self._session = session

    async def get_user_memory_for_owner(self, *, owner_id: UUID) -> MemoryRecord | None:
        """Return the owner's user memory; always ``None`` for now."""
        unused = (self._session, owner_id)
        del unused
        return None

    async def get_work_memory_for_owner(self, *, owner_id: UUID) -> MemoryRecord | None:
        """Return the owner's work memory; always ``None`` for now."""
        unused = (self._session, owner_id)
        del unused
        return None

    async def create(
        self,
        *,
        owner_id: UUID,
        memory_type: MemoryType,
        content: dict,
    ) -> MemoryRecord:
        """Build a new ``MemoryRecord`` with a random id; nothing is stored."""
        record_id = uuid4()
        record = MemoryRecord(
            id=record_id,
            owner_id=owner_id,
            memory_type=memory_type,
            content=content,
        )
        return record

    async def update_content(
        self,
        memory: MemoryRecord,
        content: dict | None = None,
    ) -> MemoryRecord:
        """Replace ``memory.content`` unless ``content`` is ``None``.

        Returns the same ``memory`` object that was passed in.
        """
        if content is None:
            return memory
        memory.content = content
        return memory
+76
View File
@@ -0,0 +1,76 @@
from __future__ import annotations
from core.auth.models import CurrentUser
from schemas.enums import MemoryType
from schemas.domain.memory_content import UserMemoryContent, WorkProfileContent
from v1.memories.repository import MemoryRecord, SQLAlchemyMemoriesRepository
class MemoriesService:
    """Read and upsert the current user's memories through a repository."""

    def __init__(
        self,
        repository: SQLAlchemyMemoriesRepository,
        session: object,
        current_user: CurrentUser | None,
    ) -> None:
        """Keep the repository, session and (optional) current user."""
        self._repository = repository
        self._session = session
        self._current_user = current_user

    def _require_user_id(self):
        """Return the current user's id; raise ``ValueError`` without a user."""
        user = self._current_user
        if user is not None:
            return user.id
        raise ValueError("current user is required")

    async def get_all_memories(
        self,
    ) -> dict[str, UserMemoryContent | WorkProfileContent | None]:
        """Return both memories, validated into content models, or ``None``.

        The result always has the keys ``user_memory`` and ``work_memory``.
        """
        owner_id = self._require_user_id()
        user_record = await self._repository.get_user_memory_for_owner(
            owner_id=owner_id
        )
        work_record = await self._repository.get_work_memory_for_owner(
            owner_id=owner_id
        )

        user_content = None
        if user_record is not None:
            user_content = UserMemoryContent.model_validate(user_record.content)
        work_content = None
        if work_record is not None:
            work_content = WorkProfileContent.model_validate(work_record.content)

        return {"user_memory": user_content, "work_memory": work_content}

    async def get_memory_model(self, *, memory_type: MemoryType) -> MemoryRecord | None:
        """Return the raw record: user memory for ``USER``, else work memory."""
        owner_id = self._require_user_id()
        if memory_type != MemoryType.USER:
            return await self._repository.get_work_memory_for_owner(owner_id=owner_id)
        return await self._repository.get_user_memory_for_owner(owner_id=owner_id)

    async def update_user_memory(self, *, content: UserMemoryContent) -> MemoryRecord:
        """Upsert the user memory with a JSON-mode dump of ``content``."""
        owner_id = self._require_user_id()
        current = await self._repository.get_user_memory_for_owner(owner_id=owner_id)
        payload = content.model_dump(mode="json", exclude_none=True)
        if current is None:
            return await self._repository.create(
                owner_id=owner_id,
                memory_type=MemoryType.USER,
                content=payload,
            )
        return await self._repository.update_content(current, payload)

    async def update_work_memory(self, *, content: WorkProfileContent) -> MemoryRecord:
        """Upsert the work memory with a JSON-mode dump of ``content``."""
        owner_id = self._require_user_id()
        current = await self._repository.get_work_memory_for_owner(owner_id=owner_id)
        payload = content.model_dump(mode="json", exclude_none=True)
        if current is None:
            return await self._repository.create(
                owner_id=owner_id,
                memory_type=MemoryType.WORK,
                content=payload,
            )
        return await self._repository.update_content(current, payload)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+58
View File
@@ -0,0 +1,58 @@
from __future__ import annotations
from uuid import UUID
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.points_ledger import PointsLedger
from models.user_points import UserPoints
from schemas.shared.points import ApplyPointsChangeCommand
class PointsRepository:
    """Data access for user point accounts and the points ledger."""

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to an async SQLAlchemy session."""
        self._session = session

    async def get_or_create_user_points_for_update(
        self, *, user_id: UUID
    ) -> UserPoints:
        """Return the user's points row selected ``FOR UPDATE``.

        An insert that does nothing on a ``user_id`` conflict runs first, so
        the row exists by the time it is selected.
        """
        ensure_row = insert(UserPoints).values(user_id=user_id)
        ensure_row = ensure_row.on_conflict_do_nothing(
            index_elements=[UserPoints.user_id]
        )
        await self._session.execute(ensure_row)

        locked_select = select(UserPoints).where(UserPoints.user_id == user_id)
        locked_select = locked_select.with_for_update()
        result = await self._session.execute(locked_select)
        return result.scalar_one()

    async def has_ledger_event(self, *, user_id: UUID, event_id: str) -> bool:
        """Tell whether a ledger entry exists for ``user_id`` and ``event_id``."""
        lookup = select(PointsLedger.id).where(
            PointsLedger.user_id == user_id,
            PointsLedger.event_id == event_id,
        )
        result = await self._session.execute(lookup)
        return result.scalar_one_or_none() is not None

    async def append_ledger(
        self,
        *,
        command: ApplyPointsChangeCommand,
        balance_after: int,
    ) -> None:
        """Add a ledger entry built from ``command`` and flush the session."""
        biz_type_value = None
        if command.biz_type is not None:
            biz_type_value = command.biz_type.value
        ledger_row = PointsLedger(
            user_id=command.user_id,
            direction=command.direction,
            amount=command.amount,
            balance_after=balance_after,
            change_type=command.change_type.value,
            biz_type=biz_type_value,
            biz_id=command.biz_id,
            event_id=command.event_id,
            operator_id=command.operator_id,
            metadata_json=command.metadata.model_dump(mode="json", exclude_none=True),
        )
        self._session.add(ledger_row)
        await self._session.flush()
+132
View File
@@ -0,0 +1,132 @@
from __future__ import annotations
from dataclasses import dataclass
from decimal import Decimal
import hashlib
from uuid import UUID, uuid4
from core.http.errors import ApiProblemError, problem_payload
from schemas.domain.points import ConsumeLedgerMetadata, PointsChargeSnapshot
from schemas.enums import PointsBizType, PointsChangeType, PointsOperatorType
from schemas.shared.points import ApplyPointsChangeCommand
from v1.points.repository import PointsRepository
# Points charged for one successful agent run; also the minimum available
# balance (balance minus frozen balance) required to start a run.
RUN_POINTS_COST = 20
@dataclass(frozen=True)
class RunChargeResult:
    """Immutable outcome of charging points for a run."""

    # False when the event was already in the ledger and nothing was charged.
    charged: bool
    # Points deducted by this call (0 when nothing was charged).
    amount: int
    # Account balance after the call.
    balance_after: int
    # Ledger event id derived from the session id and run id.
    event_id: str
class PointsService:
    """Points checks and charges for agent runs."""

    def __init__(self, repository: PointsRepository) -> None:
        """Bind the service to a points repository."""
        self._repository = repository

    @staticmethod
    def _require_run_balance(account: object) -> int:
        """Return the account's available points or raise a 402 problem.

        Available points are ``balance - frozen_balance``; fewer than
        ``RUN_POINTS_COST`` raises ``POINTS_INSUFFICIENT_BALANCE``.
        """
        available = int(account.balance) - int(account.frozen_balance)
        if available < RUN_POINTS_COST:
            raise ApiProblemError(
                status_code=402,
                detail=problem_payload(
                    code="POINTS_INSUFFICIENT_BALANCE",
                    detail="Insufficient points for this run",
                    params={
                        "required": RUN_POINTS_COST,
                        "available": max(available, 0),
                    },
                ),
            )
        return available

    async def ensure_run_points_available(
        self,
        *,
        user_id: UUID,
    ) -> int:
        """Return the user's available points, requiring enough for one run.

        Raises:
            ApiProblemError: 402 ``POINTS_INSUFFICIENT_BALANCE`` when fewer
                than ``RUN_POINTS_COST`` points are available.
        """
        account = await self._repository.get_or_create_user_points_for_update(
            user_id=user_id
        )
        return self._require_run_balance(account)

    async def consume_successful_run_points(
        self,
        *,
        user_id: UUID,
        session_id: UUID,
        run_id: str,
        operator_id: UUID | None,
    ) -> RunChargeResult:
        """Charge ``RUN_POINTS_COST`` for a successful run, at most once.

        The ledger event id is derived from ``session_id`` and ``run_id``;
        when an entry with that id already exists nothing is charged.

        Args:
            user_id: Account to charge.
            session_id: Session of the run; stored as the ledger ``biz_id``.
            run_id: Run being charged.
            operator_id: Optional operator recorded on the ledger entry.

        Returns:
            ``RunChargeResult`` with ``charged=False`` for a repeated event,
            otherwise the charged amount and the new balance.

        Raises:
            ApiProblemError: 402 ``POINTS_INSUFFICIENT_BALANCE`` when the
                available balance is below the run cost.
        """
        event_source = f"{session_id}:{run_id}".encode("utf-8")
        # SHA-1 only shortens the key; changing it would change event ids.
        event_hash = hashlib.sha1(event_source).hexdigest()
        event_id = f"chat.run.success:{event_hash}"

        # Take the account row (selected FOR UPDATE by the repository)
        # BEFORE the idempotency lookup. Checking first and locking second
        # lets two concurrent calls for the same run both miss the ledger
        # entry and both go on to debit the account.
        account = await self._repository.get_or_create_user_points_for_update(
            user_id=user_id
        )
        if await self._repository.has_ledger_event(user_id=user_id, event_id=event_id):
            return RunChargeResult(
                charged=False,
                amount=0,
                balance_after=int(account.balance),
                event_id=event_id,
            )

        self._require_run_balance(account)

        account.balance = int(account.balance) - RUN_POINTS_COST
        account.lifetime_spent = int(account.lifetime_spent) + RUN_POINTS_COST
        account.version = int(account.version) + 1

        # NOTE(review): the charge snapshot is filled with placeholder
        # values (random message id, seq 1, zero tokens and cost) —
        # presumably because a run-level charge has no single message.
        metadata = ConsumeLedgerMetadata(
            operator_type=PointsOperatorType.USER,
            run_id=run_id,
            charge=PointsChargeSnapshot(
                message_id=uuid4(),
                message_seq=1,
                model_code="agent_run",
                input_tokens=0,
                output_tokens=0,
                cost=Decimal("0"),
            ),
            ext={"source": "run_success"},
        )
        command = ApplyPointsChangeCommand(
            user_id=user_id,
            change_type=PointsChangeType.CONSUME,
            biz_type=PointsBizType.CHAT,
            biz_id=session_id,
            event_id=event_id,
            amount=RUN_POINTS_COST,
            direction=-1,
            operator_id=operator_id,
            metadata=metadata,
        )
        await self._repository.append_ledger(
            command=command,
            balance_after=int(account.balance),
        )
        return RunChargeResult(
            charged=True,
            amount=RUN_POINTS_COST,
            balance_after=int(account.balance),
            event_id=event_id,
        )
+2
View File
@@ -2,8 +2,10 @@ from __future__ import annotations
from fastapi import APIRouter
from v1.agent.router import router as agent_router
from v1.auth.router import router as auth_router
# Root router for the v1 API; included routers are mounted under /api/v1.
router = APIRouter(prefix="/api/v1")
# Feature routers are registered in this order: auth, then agent.
router.include_router(auth_router)
router.include_router(agent_router)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+25
View File
@@ -0,0 +1,25 @@
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
@dataclass
class ContactInfo:
    """Contact details resolved for a single user."""

    # User name, or None when it is not known.
    username: str | None
    # Phone number, or None when it is not known.
    phone: str | None
async def resolve_contacts_by_user_ids(
    *,
    user_ids: list[UUID],
    profiles_by_id: dict[UUID, object],
    auth_gateway: object,
) -> dict[UUID, ContactInfo]:
    """Build a ``ContactInfo`` for every requested user id.

    Args:
        user_ids: Ids to resolve; each one gets an entry in the result.
        profiles_by_id: Already-loaded profiles keyed by user id. A profile's
            ``username`` attribute is used when present.
        auth_gateway: Accepted but not consulted by this implementation.

    Returns:
        Mapping of user id to ``ContactInfo``. ``phone`` is always ``None``;
        ``username`` is ``None`` when no profile or attribute is found.
    """
    _ = auth_gateway

    def _username_for(uid: UUID) -> str | None:
        # Missing profile and missing attribute both resolve to None.
        matched = profiles_by_id.get(uid)
        if matched is None:
            return None
        return getattr(matched, "username", None)

    return {
        uid: ContactInfo(username=_username_for(uid), phone=None)
        for uid in user_ids
    }
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
import asyncio
from typing import Annotated
from uuid import UUID
from fastapi import Depends, Header
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from core.http.errors import ApiProblemError, problem_payload
from services.base.supabase import supabase_service
from v1.users.service import UserService
async def get_current_user(
    authorization: str | None = Header(default=None),
) -> CurrentUser:
    """Resolve the caller from an ``Authorization: Bearer <token>`` header.

    The token is passed to the Supabase client's ``auth.get_user`` on a
    worker thread, and the returned user's id, email and role are copied
    into a ``CurrentUser``.

    Raises:
        ApiProblemError: 401 ``AUTH_UNAUTHORIZED`` when the header is missing
            or not a bearer credential, or when any step of the token lookup
            fails (the underlying exception is chained as the cause).
    """

    def _unauthorized() -> ApiProblemError:
        return ApiProblemError(
            status_code=401,
            detail=problem_payload(code="AUTH_UNAUTHORIZED", detail="Unauthorized"),
        )

    # A missing or empty header yields an empty scheme and is rejected here.
    scheme, _, token = (authorization or "").partition(" ")
    if scheme.lower() != "bearer" or not token:
        raise _unauthorized()

    try:
        supabase_client = supabase_service.get_client()
        # Run the client call off the event loop thread.
        auth_response = await asyncio.to_thread(supabase_client.auth.get_user, token)
        supabase_user = getattr(auth_response, "user", None)
        raw_id = getattr(supabase_user, "id", None)
        if not (isinstance(raw_id, str) and raw_id):
            raise ValueError("missing user id")
        return CurrentUser(
            id=UUID(raw_id),
            email=getattr(supabase_user, "email", None),
            role=getattr(supabase_user, "role", None),
        )
    except Exception as exc:
        # Every failure, including the ValueError above, is reported as 401.
        raise _unauthorized() from exc
def get_user_service(
    session: Annotated[AsyncSession, Depends(get_db)],
    user: Annotated[CurrentUser, Depends(get_current_user)],
) -> UserService:
    """Dependency provider that builds a ``UserService`` for the current user.

    Args:
        session: Database session dependency; declared but not used here.
        user: The authenticated caller from ``get_current_user``.
    """
    # The session is requested as a dependency but the service ignores it.
    del session
    service = UserService(current_user=user)
    return service
+14
View File
@@ -0,0 +1,14 @@
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
@dataclass
class SQLAlchemyUserRepository:
    """User repository holding a session; the lookup is currently a stub."""

    # Session handle; read but never queried by the current implementation.
    session: object

    async def get_by_user_ids(self, user_ids: list[UUID]) -> dict[UUID, object]:
        """Return profiles keyed by user id; this stub always returns ``{}``."""
        # Touch both inputs so they count as used until a query is added.
        _unused = (self.session, user_ids)
        no_profiles: dict[UUID, object] = {}
        return no_profiles
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
from dataclasses import dataclass
from core.auth.models import CurrentUser
from schemas.shared.user import UserContext
@dataclass
class UserService:
    """Operations performed on behalf of the authenticated user."""

    # Authenticated caller this service acts for.
    current_user: CurrentUser

    async def get_me(self) -> UserContext:
        """Return a ``UserContext`` describing the current user.

        The username is synthesized as ``user_`` followed by the first eight
        characters of the user id; avatar, bio and settings are left unset.
        """
        uid_text = str(self.current_user.id)
        synthesized_name = "user_" + uid_text[:8]
        return UserContext(
            id=uid_text,
            username=synthesized_name,
            email=self.current_user.email,
            avatar_url=None,
            bio=None,
            settings=None,
        )