Files
social-app/backend/src/v1/agent/router.py
T
zl-q 80ad5141a6 refactor(agent): restructure visibility masks, task queues, and memory service
Visibility mask refactoring:
- Replace dead UI_REALTIME bit with CONTEXT_ASSEMBLY (bit 1)
- Remove visibility_consumer_bit from SystemAgentLLMConfig and system_agents.yaml
- Simplify _resolve_user_message_visibility_mask: chat->UI_HISTORY|CONTEXT_ASSEMBLY, automation->0
- Simplify _resolve_stage_visibility_mask: memory->UI_HISTORY, router/worker->UI_HISTORY|CONTEXT_ASSEMBLY
- Remove stage_visibility_bit_map from store.py

Task queue renaming:
- Replace default_broker/bulk_broker/critical_broker with worker_agent_broker/worker_automation_broker
- Queue names: 'default'/'bulk'/'critical' -> 'agent'/'automation'
- Rename run_command_task -> run_command_task_agent/run_command_task_automation
- AgentService derives queue from runtime_mode: chat->agent, automation->automation

Architecture cleanup:
- Move context_service.py from runtime/ to agentscope/services/
- Add MemoryService in v1/memory/ following repository/service pattern
- Move consumer_registry.py and pipeline_spec.py from schemas/agent to agentscope/schemas/
- Delete dead code: registry_builder.py, VisibilityBitRef
- Delete superseded plan docs
2026-03-22 20:35:55 +08:00

357 lines
12 KiB
Python

from __future__ import annotations
import asyncio
import os
import re
import tempfile
from collections.abc import AsyncIterator
from datetime import date
from typing import Annotated
from ag_ui.core import RunAgentInput
from core.agentscope.events import to_sse_event
from core.agentscope.schemas.agui_input import (
parse_run_input,
validate_run_request_messages_contract,
)
from core.auth.models import CurrentUser
from core.logging import get_logger
from fastapi import (
APIRouter,
Depends,
File,
Form,
Header,
HTTPException,
Query,
Request,
UploadFile,
status,
)
from fastapi.responses import StreamingResponse
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
AsrTranscribeResponse,
AttachmentReference,
AttachmentSignedUrlResponse,
AttachmentUploadResponse,
HistorySnapshotResponse,
TaskAcceptedResponse,
)
from v1.agent.asr import asr_service
from v1.agent.service import AgentService
from v1.users.dependencies import get_current_user
# Router for the agent endpoints; every route below is mounted under "/agent".
router = APIRouter(prefix="/agent", tags=["agent"])
logger = get_logger("v1.agent.router")

# SSE streaming ---------------------------------------------------------------
# Accepted Last-Event-ID shape: "<digits>-<digits>" (presumably a Redis stream
# entry id -- confirm against AgentService.stream_events).
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
# Concurrent SSE connections allowed per user, and the TTL (15 minutes) of the
# Redis counter key that tracks them.
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
# Event types that end a run; the SSE stream closes after forwarding one.
_TERMINAL_RUN_EVENT_TYPES = {"RUN_FINISHED", "RUN_ERROR"}

# Audio transcription ---------------------------------------------------------
_MAX_TRANSCRIBE_AUDIO_BYTES = 10 * 1024 * 1024  # 10 MiB of audio payload
_TRANSCRIBE_READ_CHUNK_BYTES = 1024 * 1024  # copied to disk 1 MiB at a time
# Slack allowed for multipart framing when comparing Content-Length to the cap.
_MULTIPART_OVERHEAD_BYTES = 64 * 1024
# A RIFF/WAVE signature check needs the first 12 bytes of the file.
_WAV_HEADER_MIN_BYTES = 12
_ALLOWED_AUDIO_CONTENT_TYPES = {"audio/wav", "audio/x-wav", "audio/wave"}

# Attachments -----------------------------------------------------------------
_MAX_ATTACHMENT_UPLOAD_BYTES = 5 * 1024 * 1024  # 5 MiB per uploaded file
def _looks_like_wav_header(header: bytes) -> bool:
    """Return True when *header* begins with a RIFF/WAVE container signature.

    Bytes 0-3 must be ``RIFF`` and bytes 8-11 must be ``WAVE``; anything
    shorter than ``_WAV_HEADER_MIN_BYTES`` is rejected outright.
    """
    has_enough_bytes = len(header) >= _WAV_HEADER_MIN_BYTES
    return (
        has_enough_bytes
        and header[:4] == b"RIFF"
        and header[8:12] == b"WAVE"
    )
async def _acquire_sse_slot(*, user_id: str) -> bool:
    """Try to reserve one concurrent SSE connection slot for ``user_id``.

    The per-user Redis counter ``agent:sse-active:<user_id>`` is incremented.
    If that pushes it past ``_MAX_SSE_CONNECTIONS_PER_USER``, the increment is
    rolled back and False is returned. The key's TTL is set when the first
    slot is taken and re-applied whenever the key is found without one.

    Fails open: any error while talking to Redis is logged and treated as a
    successful acquisition, so a Redis outage never blocks streaming.
    """
    try:
        client = await get_or_init_redis_client()
        slot_key = f"agent:sse-active:{user_id}"
        active = await client.incr(slot_key)
        if active == 1:
            # First connection for this user: start the counter's expiry clock.
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
            return True
        if active > _MAX_SSE_CONNECTIONS_PER_USER:
            # Over the limit: undo our own increment and refuse the slot.
            await client.decr(slot_key)
            return False
        # Redis TTL is negative when the key has no expiry (-1) or is missing
        # (-2); re-arm it so a stale counter cannot live forever.
        if await client.ttl(slot_key) < 0:
            await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
        return True
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot acquire failed",
            user_id=user_id,
            reason=str(exc),
        )
        return True
async def _release_sse_slot(*, user_id: str) -> None:
    """Give back one SSE connection slot previously taken for ``user_id``.

    Decrements the per-user Redis counter. Once it reaches zero (or below),
    the key is deleted; otherwise its TTL is re-applied if missing. Errors are
    logged and swallowed so cleanup never raises into the caller.
    """
    try:
        client = await get_or_init_redis_client()
        slot_key = f"agent:sse-active:{user_id}"
        remaining = await client.decr(slot_key)
        if remaining > 0:
            # Other connections still hold slots; make sure the key expires.
            if await client.ttl(slot_key) < 0:
                await client.expire(slot_key, _SSE_SLOT_TTL_SECONDS)
        else:
            await client.delete(slot_key)
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "SSE slot release failed",
            user_id=user_id,
            reason=str(exc),
        )
    return None
def _is_terminal_run_event(event: dict[str, object]) -> bool:
    """Return True if *event* is a run-ending event (RUN_FINISHED / RUN_ERROR).

    A missing or non-string ``type`` is never terminal.
    """
    event_type = event.get("type")
    # The isinstance guard also keeps unhashable values out of the set lookup.
    if not isinstance(event_type, str):
        return False
    return event_type in _TERMINAL_RUN_EVENT_TYPES
@router.post(
    "/runs",
    response_model=TaskAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_run(
    request: RunAgentInput,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
    """Accept an AG-UI run request and hand it to ``service.enqueue_run``.

    The body is re-parsed through ``parse_run_input`` and checked with
    ``validate_run_request_messages_contract``. A ``ValueError`` from either
    step is reported as HTTP 422. On success, responds 202 with the task,
    thread and run identifiers.
    """
    try:
        run_input = parse_run_input(
            request.model_dump(by_alias=True, exclude_none=True)
        )
        validate_run_request_messages_contract(run_input)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    task = await service.enqueue_run(run_input=run_input, current_user=current_user)
    return TaskAcceptedResponse(
        taskId=task.task_id,
        threadId=task.thread_id,
        runId=task.run_id,
        created=task.created,
    )
@router.get("/runs/{thread_id}/events")
async def stream_events(
    request: Request,
    thread_id: str,
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
    idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
    """Stream a thread's run events to the client as Server-Sent Events.

    Resumes after ``Last-Event-ID`` when supplied. The ID must be at most
    32 characters and match ``_LAST_EVENT_ID_RE``; otherwise the request gets
    HTTP 422. Each connection holds one of the user's concurrent SSE slots
    (HTTP 429 when none are free), released when the stream generator finishes.

    The stream ends when:
    - a terminal run event (RUN_FINISHED / RUN_ERROR) has been forwarded;
    - the client disconnects;
    - a read from the service fails; or
    - ``idle_limit`` consecutive polls yield nothing.

    After each idle poll, a ``: keep-alive`` SSE comment is sent.

    NOTE(review): the slot is released only in the generator's ``finally``.
    If the response is torn down before the generator is first iterated, that
    block never runs and the slot is held until the Redis key's TTL expires.
    Confirm against the Starlette version in use.
    """
    if last_event_id is not None and (
        len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
    ):
        raise HTTPException(status_code=422, detail="Invalid Last-Event-ID")
    user_id = str(current_user.id)
    if not await _acquire_sse_slot(user_id=user_id):
        raise HTTPException(status_code=429, detail="Too many SSE connections")

    # Before Python 3.11, asyncio.TimeoutError (raised e.g. by asyncio.wait_for)
    # is a different class from the builtin TimeoutError. Catch both so an idle
    # read becomes a keep-alive instead of ending the stream. On 3.11+ they are
    # the same class.
    idle_read_errors = (TimeoutError, asyncio.TimeoutError)

    async def _event_iter() -> AsyncIterator[str]:
        cursor = last_event_id
        idle_polls = 0
        try:
            while not await request.is_disconnected() and idle_polls < idle_limit:
                try:
                    rows = await service.stream_events(
                        thread_id=thread_id,
                        last_event_id=cursor,
                        current_user=current_user,
                    )
                except idle_read_errors:
                    rows = None
                except Exception as exc:  # noqa: BLE001
                    logger.warning(
                        "SSE stream read failed",
                        thread_id=thread_id,
                        user_id=user_id,
                        reason=str(exc),
                    )
                    break
                if not rows:
                    # Timed out or nothing new: count the idle poll and send a
                    # comment line so intermediaries keep the connection open.
                    idle_polls += 1
                    yield ": keep-alive\n\n"
                    await asyncio.sleep(0.2)
                    continue
                idle_polls = 0
                for row in rows:
                    row_id = str(row.get("id", ""))
                    event = row.get("event")
                    if not row_id or not isinstance(event, dict):
                        # Malformed row: skip it without advancing the cursor.
                        continue
                    cursor = row_id
                    yield to_sse_event(row_id, event)
                    if _is_terminal_run_event(event):
                        # The run is over; finishing the generator runs the
                        # finally block below, which releases the slot.
                        return
        finally:
            await _release_sse_slot(user_id=user_id)

    return StreamingResponse(
        _event_iter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
@router.get("/history", response_model=HistorySnapshotResponse)
async def get_user_history_snapshot(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str | None = Query(default=None, alias="threadId"),
    before: date | None = Query(default=None),
) -> HistorySnapshotResponse:
    """Return the caller's history snapshot from the agent service.

    The optional ``threadId`` and ``before`` query parameters are forwarded
    unchanged to ``AgentService.get_user_history_snapshot``.
    """
    snapshot = await service.get_user_history_snapshot(
        before=before,
        thread_id=thread_id,
        current_user=current_user,
    )
    return snapshot
@router.post(
    "/attachments",
    response_model=AttachmentUploadResponse,
    status_code=status.HTTP_200_OK,
)
async def upload_attachment(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    thread_id: str = Form(alias="threadId"),
    file: UploadFile = File(),
) -> AttachmentUploadResponse:
    """Store an uploaded file as an attachment for ``threadId``.

    The file's name, content type and bytes are passed to
    ``AgentService.upload_attachment``. The upload handle is always closed.

    Raises:
        HTTPException: 422 when the upload is empty; 413 when it exceeds
            ``_MAX_ATTACHMENT_UPLOAD_BYTES``.
    """
    try:
        # Read at most one byte past the limit: enough to detect an oversized
        # upload without pulling an arbitrarily large file into memory.
        payload = await file.read(_MAX_ATTACHMENT_UPLOAD_BYTES + 1)
        if not payload:
            raise HTTPException(status_code=422, detail="Empty attachment")
        if len(payload) > _MAX_ATTACHMENT_UPLOAD_BYTES:
            raise HTTPException(status_code=413, detail="Attachment too large")
        attachment = await service.upload_attachment(
            thread_id=thread_id,
            filename=file.filename,
            content_type=file.content_type,
            payload=payload,
            current_user=current_user,
        )
    finally:
        await file.close()
    return AttachmentUploadResponse(
        attachment=AttachmentReference.model_validate(attachment),
    )
@router.get(
    "/attachments/signed-url",
    response_model=AttachmentSignedUrlResponse,
    status_code=status.HTTP_200_OK,
)
async def create_attachment_signed_url(
    service: Annotated[AgentService, Depends(get_agent_service)],
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
    bucket: str = Query(min_length=1, max_length=100),
    path: str = Query(min_length=1, max_length=500),
) -> AttachmentSignedUrlResponse:
    """Return a signed URL for the stored object at ``bucket``/``path``.

    The query strings are length-bounded here (bucket 1-100 chars, path
    1-500 chars). Any access checks are presumably done in
    ``AgentService.create_attachment_signed_url``; verify there.
    """
    return AttachmentSignedUrlResponse(
        **(
            await service.create_attachment_signed_url(
                bucket=bucket,
                path=path,
                current_user=current_user,
            )
        )
    )
@router.post(
    "/transcribe",
    response_model=AsrTranscribeResponse,
    status_code=status.HTTP_200_OK,
)
async def transcribe(
    audio: UploadFile,
    request: Request,
    current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AsrTranscribeResponse:
    """Transcribe an uploaded WAV file with the ASR service.

    The upload goes through these checks and steps:
    - media type must be an allowed WAV type;
    - the declared request size must be within the cap;
    - the streamed byte count must be within the cap;
    - the file must start with the RIFF/WAVE signature.

    It is then copied to a temporary ``.wav`` file and passed to
    ``asr_service.transcribe_file``. The upload handle is always closed and
    the temp file always removed.

    Raises:
        HTTPException: 400 for an unsupported format, an empty file, or a file
            over ``_MAX_TRANSCRIBE_AUDIO_BYTES``; 502 when the ASR call raises
            ``RuntimeError``.
    """
    temp_path: str | None = None
    try:
        # Media types are case-insensitive and may carry parameters after ";".
        # Compare only the bare, lowercased type; the magic-byte check below
        # still verifies the actual content.
        media_type = (audio.content_type or "").split(";", 1)[0].strip().lower()
        if media_type not in _ALLOWED_AUDIO_CONTENT_TYPES:
            raise HTTPException(status_code=400, detail="Unsupported audio format")
        # Cheap early rejection from the declared size of the whole multipart
        # request, allowing some slack for multipart framing.
        content_length = request.headers.get("content-length")
        if content_length is not None:
            try:
                declared_length = int(content_length)
            except ValueError:
                declared_length = None
            if (
                declared_length is not None
                and declared_length
                > _MAX_TRANSCRIBE_AUDIO_BYTES + _MULTIPART_OVERHEAD_BYTES
            ):
                raise HTTPException(status_code=400, detail="Audio file too large")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            temp_path = tmp_file.name
            total_bytes = 0
            header = bytearray()
            while True:
                chunk = await audio.read(_TRANSCRIBE_READ_CHUNK_BYTES)
                if not chunk:
                    break
                total_bytes += len(chunk)
                # Enforce the real cap on bytes actually read, since the
                # Content-Length header may be absent or wrong.
                if total_bytes > _MAX_TRANSCRIBE_AUDIO_BYTES:
                    raise HTTPException(status_code=400, detail="Audio file too large")
                # Collect the leading bytes for the signature check, even if
                # they arrive split across chunks.
                if len(header) < _WAV_HEADER_MIN_BYTES:
                    required = _WAV_HEADER_MIN_BYTES - len(header)
                    header.extend(chunk[:required])
                tmp_file.write(chunk)
        if total_bytes == 0:
            raise HTTPException(status_code=400, detail="Empty audio file")
        if not _looks_like_wav_header(bytes(header)):
            raise HTTPException(status_code=400, detail="Unsupported audio format")
        transcript = await asr_service.transcribe_file(
            temp_path, audio.filename or "unknown"
        )
        return AsrTranscribeResponse(transcript=transcript)
    except HTTPException:
        raise
    except RuntimeError as exc:
        logger.warning(
            "ASR transcription failed",
            user_id=str(current_user.id),
            reason=str(exc),
        )
        raise HTTPException(
            status_code=502, detail="ASR service unavailable"
        ) from exc
    finally:
        await audio.close()
        if temp_path:
            try:
                os.unlink(temp_path)
            except OSError:
                pass