feat(agent-chat): complete core workflow and strengthen auth rate limiting

This commit is contained in:
qzl
2026-02-25 16:51:12 +08:00
parent 53c72e48e6
commit cd40b2b4f4
62 changed files with 3441 additions and 3 deletions
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,20 @@
from __future__ import annotations
from typing import Any
from core.agent_chat.event_bridge import map_internal_event
class AguiAdapter:
def to_command(self, payload: dict[str, Any]) -> dict[str, Any]:
message = payload.get("message")
if not isinstance(message, str) or not message.strip():
raise ValueError("message is required")
return {
"message": message,
"session_id": payload.get("session_id"),
}
def to_protocol_event(self, event: dict[str, Any]) -> dict[str, Any]:
return map_internal_event(event)
@@ -0,0 +1,69 @@
from __future__ import annotations
from decimal import Decimal
from typing import Any, Mapping
def _to_non_negative_int(value: Any, *, field: str) -> int:
if isinstance(value, bool):
raise ValueError(f"{field} must be an integer")
if isinstance(value, int):
converted = value
elif isinstance(value, str) and value.isdigit():
converted = int(value)
else:
raise ValueError(f"{field} must be an integer")
if converted < 0:
raise ValueError(f"{field} cannot be negative")
return converted
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
converted = Decimal(str(value))
if converted < 0:
raise ValueError(f"{field} cannot be negative")
return converted
class CostTracker:
def __init__(self, *, currency: str = "USD") -> None:
self._input_tokens = 0
self._output_tokens = 0
self._total_tokens = 0
self._cost = Decimal("0")
self._currency = currency
def add_usage(self, usage: Mapping[str, Any]) -> None:
input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
total_tokens = usage.get("total_tokens")
cost = usage.get("cost", "0")
currency = usage.get("currency")
normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
normalized_total = (
_to_non_negative_int(total_tokens, field="total_tokens")
if total_tokens is not None
else normalized_input + normalized_output
)
normalized_cost = _to_non_negative_decimal(cost, field="cost")
self._input_tokens += normalized_input
self._output_tokens += normalized_output
self._total_tokens += normalized_total
self._cost += normalized_cost
if currency is not None:
normalized_currency = str(currency)
if normalized_currency != self._currency:
raise ValueError("currency mismatch")
def snapshot(self) -> dict[str, Any]:
return {
"input_tokens": self._input_tokens,
"output_tokens": self._output_tokens,
"total_tokens": self._total_tokens,
"cost": self._cost,
"currency": self._currency,
}
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,82 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class CrewAITemplate:
agents: dict[str, Any]
tasks: dict[str, Any]
workflow: dict[str, Any]
prompts: dict[str, str]
tools_whitelist: set[str]
def _default_static_root() -> Path:
return Path(__file__).resolve().parents[3] / "config" / "static" / "agent_chat"
def _read_yaml(file_path: Path) -> dict[str, Any]:
if not file_path.is_file():
raise FileNotFoundError(f"Required config file not found: {file_path}")
with file_path.open("r", encoding="utf-8") as file:
loaded = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"YAML file must be a mapping: {file_path}")
return loaded
def _read_prompt(file_path: Path) -> str:
if not file_path.is_file():
raise FileNotFoundError(f"Required prompt file not found: {file_path}")
return file_path.read_text(encoding="utf-8").strip()
def validate_workflow_stages(stages: list[str]) -> None:
expected = ["intent", "execution", "organization"]
if stages != expected:
raise ValueError(f"Invalid workflow stages: {stages}, expected: {expected}")
def load_tools_whitelist(static_root: Path | None = None) -> set[str]:
root = static_root or _default_static_root()
tools = _read_yaml(root / "tools.yaml")
raw_tools = tools.get("tools", [])
if not isinstance(raw_tools, list):
raise ValueError("tools.yaml field 'tools' must be a list")
if not all(isinstance(item, str) and item.strip() for item in raw_tools):
raise ValueError("tools.yaml list items must be non-empty strings")
whitelist = {item.strip() for item in raw_tools}
return whitelist
def load_crewai_template(static_root: Path | None = None) -> CrewAITemplate:
root = static_root or _default_static_root()
crewai_root = root / "crewai"
agents = _read_yaml(crewai_root / "agents.yaml")
tasks = _read_yaml(crewai_root / "tasks.yaml")
workflow = _read_yaml(crewai_root / "workflow.yaml")
stages = workflow.get("stages")
if not isinstance(stages, list):
raise ValueError("workflow.yaml field 'stages' must be a list")
validate_workflow_stages([str(stage) for stage in stages])
prompts = {
"intent": _read_prompt(crewai_root / "prompts" / "intent.md"),
"execution": _read_prompt(crewai_root / "prompts" / "execution.md"),
"organization": _read_prompt(crewai_root / "prompts" / "organization.md"),
}
return CrewAITemplate(
agents=agents,
tasks=tasks,
workflow=workflow,
prompts=prompts,
tools_whitelist=load_tools_whitelist(root),
)
@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Any
def _require_fields(event: dict[str, Any], *, kind: str, required: list[str]) -> None:
missing = [field for field in required if field not in event]
if missing:
raise ValueError(f"Missing fields for {kind}: {missing}")
def map_internal_event(event: dict[str, Any]) -> dict[str, Any]:
kind = event.get("kind")
if kind == "run_started":
_require_fields(event, kind=kind, required=["session_id"])
return {
"type": "run.started",
"run_id": event["session_id"],
}
if kind == "message_delta":
_require_fields(event, kind=kind, required=["message_id", "delta"])
return {
"type": "message.delta",
"message_id": event["message_id"],
"delta": event["delta"],
}
if kind == "tool_started":
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
return {
"type": "tool.started",
"message_id": event["message_id"],
"tool_name": event["tool_name"],
}
if kind == "tool_completed":
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
return {
"type": "tool.completed",
"message_id": event["message_id"],
"tool_name": event["tool_name"],
"result": event.get("result"),
}
if kind == "run_completed":
_require_fields(event, kind=kind, required=["session_id"])
return {
"type": "run.completed",
"run_id": event["session_id"],
"output": event.get("output", ""),
}
if kind == "run_failed":
_require_fields(event, kind=kind, required=["session_id"])
return {
"type": "run.failed",
"run_id": event["session_id"],
"error": event.get("error", ""),
}
raise ValueError(f"Unsupported event kind: {kind}")
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Any
def run_started(*, run_id: str) -> dict[str, Any]:
return {"type": "run.started", "run_id": run_id}
def stage_completed(
*, run_id: str, stage: str, usage: dict[str, Any] | None = None
) -> dict[str, Any]:
event: dict[str, Any] = {
"type": "stage.completed",
"run_id": run_id,
"stage": stage,
}
if usage is not None:
event["usage"] = usage
return event
def run_completed(*, run_id: str, output: str, usage: dict[str, Any]) -> dict[str, Any]:
return {
"type": "run.completed",
"run_id": run_id,
"output": output,
"usage": usage,
}
def run_failed(*, run_id: str, error: str) -> dict[str, Any]:
return {
"type": "run.failed",
"run_id": run_id,
"error": error,
}
+112
View File
@@ -0,0 +1,112 @@
from __future__ import annotations
from dataclasses import dataclass
import hashlib
from pathlib import Path
from typing import Protocol
from core.agent_chat.storage_adapter import StorageAdapter
_ALLOWED_MIME_TYPES = {
"audio/mpeg",
"audio/wav",
"audio/x-wav",
"image/jpeg",
"image/png",
"image/webp",
"application/pdf",
"text/plain",
}
class _AsrTool(Protocol):
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, object]: ...
@dataclass(frozen=True)
class AttachmentInput:
filename: str
mime_type: str
content: bytes
origin: str = "user_upload"
@dataclass(frozen=True)
class ProcessedAttachmentContext:
attachments: list[dict[str, object]]
preview_texts: list[str]
class MultimodalProcessor:
def __init__(
self,
*,
storage: StorageAdapter,
asr_tool: _AsrTool,
max_file_size_mb: int = 20,
) -> None:
self._storage = storage
self._asr_tool = asr_tool
self._max_size_bytes = max_file_size_mb * 1024 * 1024
def process(
self,
*,
user_id: str,
session_id: str,
message_seq: int,
attachments: list[AttachmentInput],
) -> ProcessedAttachmentContext:
metadata_list: list[dict[str, object]] = []
preview_texts: list[str] = []
for attachment in attachments:
self._validate_attachment(attachment)
checksum = hashlib.sha256(attachment.content).hexdigest()
extension = Path(attachment.filename).suffix.strip(".").lower() or "bin"
object_path = self._storage.build_object_path(
user_id=user_id,
session_id=session_id,
message_seq=message_seq,
checksum_sha256=checksum,
extension=extension,
)
preview_text = self._build_preview_text(attachment)
if preview_text:
preview_texts.append(preview_text)
metadata = self._storage.build_attachment_metadata(
object_path=object_path,
mime_type=attachment.mime_type,
size=len(attachment.content),
checksum_sha256=checksum,
origin=attachment.origin,
preview_text=preview_text,
)
metadata_list.append(metadata)
return ProcessedAttachmentContext(
attachments=metadata_list,
preview_texts=preview_texts,
)
def _validate_attachment(self, attachment: AttachmentInput) -> None:
if attachment.mime_type not in _ALLOWED_MIME_TYPES:
raise ValueError("Unsupported MIME type")
if len(attachment.content) > self._max_size_bytes:
raise ValueError("Attachment exceeds max file size")
def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
if attachment.mime_type.startswith("audio/"):
result = self._asr_tool.transcribe(
audio_bytes=attachment.content,
filename=attachment.filename,
)
text = result.get("text")
if isinstance(text, str):
return text
return None
if attachment.mime_type == "text/plain":
return attachment.content.decode("utf-8", errors="ignore")[:200]
return None
@@ -0,0 +1,88 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
from core.agent_chat.cost_tracker import CostTracker
from core.agent_chat import events
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
@dataclass(frozen=True)
class OrchestratorResult:
output: str
usage: dict[str, Any]
events: list[dict[str, Any]]
context: dict[str, Any]
failed: bool
error: str | None
class AgentChatOrchestrator:
def __init__(
self,
*,
intent_stage: StageCallable,
execution_stage: StageCallable,
organization_stage: StageCallable,
) -> None:
self._intent_stage = intent_stage
self._execution_stage = execution_stage
self._organization_stage = organization_stage
def run_sync(self, *, run_id: str, user_message: str) -> OrchestratorResult:
return asyncio.run(self.run(run_id=run_id, user_message=user_message))
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
tracker = CostTracker()
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
context: dict[str, Any] = {}
stage_pipeline: list[tuple[str, StageCallable]] = [
("intent", self._intent_stage),
("execution", self._execution_stage),
("organization", self._organization_stage),
]
stage_output = user_message
try:
for stage_name, stage_callable in stage_pipeline:
stage_result = await stage_callable(
message=stage_output, context=context
)
stage_output = str(stage_result.get("content", stage_output))
usage = stage_result.get("usage", {})
if isinstance(usage, dict):
tracker.add_usage(usage)
emitted_events.append(
events.stage_completed(
run_id=run_id,
stage=stage_name,
usage=tracker.snapshot(),
)
)
except Exception as exc: # noqa: BLE001
emitted_events.append(events.run_failed(run_id=run_id, error=str(exc)))
return OrchestratorResult(
output="",
usage=tracker.snapshot(),
events=emitted_events,
context=context,
failed=True,
error=str(exc),
)
summary = tracker.snapshot()
emitted_events.append(
events.run_completed(run_id=run_id, output=stage_output, usage=summary)
)
return OrchestratorResult(
output=stage_output,
usage=summary,
events=emitted_events,
context=context,
failed=False,
error=None,
)
@@ -0,0 +1,46 @@
from __future__ import annotations
class StorageAdapter:
_bucket: str
def __init__(self, bucket: str) -> None:
self._bucket = bucket
@property
def bucket(self) -> str:
return self._bucket
def build_object_path(
self,
*,
user_id: str,
session_id: str,
message_seq: int,
checksum_sha256: str,
extension: str,
) -> str:
normalized_ext = extension.strip(".").lower()
return (
f"agent-chat/{user_id}/{session_id}/{message_seq}/"
f"{checksum_sha256}.{normalized_ext}"
)
def build_attachment_metadata(
self,
*,
object_path: str,
mime_type: str,
size: int,
checksum_sha256: str,
origin: str,
preview_text: str | None = None,
) -> dict[str, object]:
return {
"object_path": object_path,
"mime_type": mime_type,
"size": size,
"checksum_sha256": checksum_sha256,
"origin": origin,
"preview_text": preview_text,
}
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,40 @@
from __future__ import annotations
import importlib
from collections.abc import Callable
from typing import Any
TranscribeCallable = Callable[..., dict[str, Any]]
class FunASRTool:
_transcribe_callable: TranscribeCallable
_model: str
def __init__(
self,
transcribe_callable: TranscribeCallable | None = None,
model: str = "fun-asr-realtime-2025-11-07",
) -> None:
self._transcribe_callable = transcribe_callable or self._dashscope_transcribe
self._model = model
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, Any]:
payload = self._transcribe_callable(audio_bytes=audio_bytes, filename=filename)
return {
"model": self._model,
**payload,
}
def _dashscope_transcribe(
self, *, audio_bytes: bytes, filename: str
) -> dict[str, Any]:
try:
importlib.import_module("dashscope")
except ImportError as exc:
raise RuntimeError("DashScope SDK is not installed") from exc
raise RuntimeError(
"DashScope transcribe runtime integration is not configured yet"
)
+9
View File
@@ -132,6 +132,14 @@ class SupabaseSettings(BaseModel):
return self.public_url
class StorageSettings(BaseModel):
provider: Literal["supabase"] = "supabase"
bucket: str = Field(default="agent-chat-attachments", min_length=3, max_length=63)
signed_url_ttl_seconds: int = Field(default=600, ge=60, le=3600)
max_file_size_mb: int = Field(default=20, ge=1, le=200)
retention_days: int = Field(default=30, ge=1, le=3650)
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
@@ -163,6 +171,7 @@ class Settings(BaseSettings):
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
supabase: SupabaseSettings = SupabaseSettings()
storage: StorageSettings = StorageSettings()
celery: CelerySettings = CelerySettings()
database: DatabaseSettings = DatabaseSettings()
@@ -0,0 +1,9 @@
intent:
role: Intent Agent
goal: Classify user intent and decide execution strategy
execution:
role: Execution Agent
goal: Execute tasks with available tools
organization:
role: Organization Agent
goal: Organize output for user-friendly response
@@ -0,0 +1,2 @@
你是任务执行代理。
基于输入意图与上下文调用可用工具,并生成可验证执行结果。
@@ -0,0 +1,2 @@
你是意图识别代理。
你的任务是识别用户输入的意图类型,并返回结构化意图标签。
@@ -0,0 +1,2 @@
你是结果整理代理。
将执行结果组织为面向用户的清晰回复,保留关键信息与必要引用。
@@ -0,0 +1,6 @@
intent:
description: Identify user intent and required capabilities
execution:
description: Execute intent with tools and model calls
organization:
description: Format final response and references
@@ -0,0 +1,9 @@
stages:
- intent
- execution
- organization
timeouts:
intent_seconds: 8
execution_seconds: 30
organization_seconds: 10
@@ -0,0 +1,25 @@
factories:
- name: qwen
request_url: https://dashscope.aliyuncs.com/compatible-mode/v1
avatar: https://cdn.simpleicons.org/alibabacloud/FF6A00
- name: minimax
request_url: https://api.minimax.chat/v1
avatar: https://cdn.simpleicons.org/minimax/1A1A1A
- name: kimi
request_url: https://api.moonshot.cn/v1
avatar: https://cdn.simpleicons.org/moonrepo/3B82F6
- name: deepseek
request_url: https://api.deepseek.com/v1
avatar: https://cdn.simpleicons.org/deepseek/4D6BFE
- name: doubao
request_url: https://ark.cn-beijing.volces.com/api/v3
avatar: https://cdn.simpleicons.org/volkswagen/001E50
- name: zai
request_url: https://api.z.ai/v1
avatar: https://cdn.simpleicons.org/zotero/CC2936
llms:
- model_code: qwen3.5-flash
factory_id: qwen
- model_code: deepseek-v3.2
factory_id: deepseek
@@ -0,0 +1,3 @@
tools:
- asr_fun_asr
- attachment_extract
+126 -1
View File
@@ -1,9 +1,134 @@
from __future__ import annotations
import uuid
from pathlib import Path
from typing import Any
import yaml
from pydantic import BaseModel, ValidationError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from models.llm import Llm
from models.llm_factory import LlmFactory
logger = get_logger("core.initialization.init_data")
class LlmFactorySeed(BaseModel):
name: str
request_url: str
avatar: str | None = None
class LlmSeed(BaseModel):
model_code: str
factory_id: str
class LlmCatalogSeed(BaseModel):
factories: list[LlmFactorySeed]
llms: list[LlmSeed]
def _default_catalog_path() -> Path:
return (
Path(__file__).resolve().parents[1]
/ "config"
/ "static"
/ "agent_chat"
/ "llm_catalog.yaml"
)
def load_llm_catalog(catalog_path: Path | None = None) -> dict[str, Any]:
path = catalog_path or _default_catalog_path()
with path.open("r", encoding="utf-8") as file:
loaded = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"Invalid LLM catalog format: {path}")
raw_factories = loaded.get("factories", [])
raw_llms = loaded.get("llms", [])
if not isinstance(raw_factories, list) or not isinstance(raw_llms, list):
raise ValueError(f"Invalid LLM catalog sections: {path}")
try:
parsed = LlmCatalogSeed.model_validate(
{
"factories": list(raw_factories),
"llms": list(raw_llms),
}
)
except ValidationError as exc:
raise ValueError(f"Invalid LLM catalog data: {path}") from exc
return parsed.model_dump()
async def _upsert_factory(
session: AsyncSession,
*,
name: str,
request_url: str,
avatar: str | None,
) -> uuid.UUID:
result = await session.execute(select(LlmFactory).where(LlmFactory.name == name))
factory = result.scalar_one_or_none()
if factory is None:
factory = LlmFactory(name=name, request_url=request_url, avatar=avatar)
session.add(factory)
await session.flush()
else:
factory.request_url = request_url
factory.avatar = avatar
return factory.id
async def _upsert_llm(
session: AsyncSession,
*,
model_code: str,
factory_id: uuid.UUID,
) -> None:
result = await session.execute(select(Llm).where(Llm.model_code == model_code))
llm = result.scalar_one_or_none()
if llm is None:
session.add(Llm(model_code=model_code, factory_id=factory_id))
return
llm.factory_id = factory_id
async def initialize_data() -> bool:
"""Initialize bootstrap data."""
logger.info("Initializing data (no-op)")
catalog = load_llm_catalog()
async with AsyncSessionLocal() as session:
async with session.begin():
factory_id_by_name: dict[str, uuid.UUID] = {}
for factory in catalog["factories"]:
factory_id = await _upsert_factory(
session,
name=factory["name"],
request_url=factory["request_url"],
avatar=factory.get("avatar"),
)
factory_id_by_name[factory["name"]] = factory_id
for llm in catalog["llms"]:
factory_name = llm["factory_id"]
resolved_factory_id = factory_id_by_name.get(factory_name)
if resolved_factory_id is None:
raise RuntimeError(
f"Factory '{factory_name}' not found for model '{llm['model_code']}'"
)
await _upsert_llm(
session,
model_code=llm["model_code"],
factory_id=resolved_factory_id,
)
logger.info("Initialized LLM factory/model seed data")
return True