feat(agent-chat): complete core workflow and strengthen auth rate limiting
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.agent_chat.event_bridge import map_internal_event
|
||||
|
||||
|
||||
class AguiAdapter:
    """Convert inbound chat payloads to commands and internal events to protocol events."""

    def to_command(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Build a command dict from an inbound payload.

        Args:
            payload: Request data; ``message`` is mandatory, ``session_id``
                is optional.

        Returns:
            A dict holding the untouched ``message`` and the ``session_id``
            (``None`` when the payload has none).

        Raises:
            ValueError: If ``message`` is absent, not a string, or blank.
        """
        text = payload.get("message")
        has_text = isinstance(text, str) and bool(text.strip())
        if not has_text:
            raise ValueError("message is required")

        command: dict[str, Any] = {"message": text}
        command["session_id"] = payload.get("session_id")
        return command

    def to_protocol_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Translate an internal event by delegating to ``map_internal_event``."""
        return map_internal_event(event)
|
||||
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
def _to_non_negative_int(value: Any, *, field: str) -> int:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{field} must be an integer")
|
||||
if isinstance(value, int):
|
||||
converted = value
|
||||
elif isinstance(value, str) and value.isdigit():
|
||||
converted = int(value)
|
||||
else:
|
||||
raise ValueError(f"{field} must be an integer")
|
||||
if converted < 0:
|
||||
raise ValueError(f"{field} cannot be negative")
|
||||
return converted
|
||||
|
||||
|
||||
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
|
||||
converted = Decimal(str(value))
|
||||
if converted < 0:
|
||||
raise ValueError(f"{field} cannot be negative")
|
||||
return converted
|
||||
|
||||
|
||||
class CostTracker:
    """Accumulate token counts and cost across usage reports.

    Every report added to a tracker must either omit ``currency`` or carry
    the tracker's own currency.
    """

    def __init__(self, *, currency: str = "USD") -> None:
        """Start with zeroed totals in the given currency."""
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")
        self._currency = currency

    def add_usage(self, usage: Mapping[str, Any]) -> None:
        """Fold one usage report into the running totals.

        ``input_tokens`` / ``output_tokens`` fall back to ``prompt_tokens`` /
        ``completion_tokens`` and then to zero. A missing ``total_tokens`` is
        derived as input plus output, and a missing ``cost`` counts as zero.

        The report is validated completely before any total changes, so a
        rejected report leaves the tracker untouched.

        Raises:
            ValueError: If a token count or the cost is invalid, or if the
                report's currency differs from the tracker's.
        """
        input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
        output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
        total_tokens = usage.get("total_tokens")
        cost = usage.get("cost", "0")
        currency = usage.get("currency")

        normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
        normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
        if total_tokens is not None:
            normalized_total = _to_non_negative_int(total_tokens, field="total_tokens")
        else:
            normalized_total = normalized_input + normalized_output
        normalized_cost = _to_non_negative_decimal(cost, field="cost")

        # Check the currency BEFORE mutating: amounts in another currency
        # must never be added to the totals.
        if currency is not None and str(currency) != self._currency:
            raise ValueError("currency mismatch")

        self._input_tokens += normalized_input
        self._output_tokens += normalized_output
        self._total_tokens += normalized_total
        self._cost += normalized_cost

    def snapshot(self) -> dict[str, Any]:
        """Return the current totals; ``cost`` is a ``Decimal``."""
        return {
            "input_tokens": self._input_tokens,
            "output_tokens": self._output_tokens,
            "total_tokens": self._total_tokens,
            "cost": self._cost,
            "currency": self._currency,
        }
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CrewAITemplate:
    """Immutable bundle of the static CrewAI configuration files."""

    # Mapping parsed from crewai/agents.yaml.
    agents: dict[str, Any]
    # Mapping parsed from crewai/tasks.yaml.
    tasks: dict[str, Any]
    # Mapping parsed from crewai/workflow.yaml.
    workflow: dict[str, Any]
    # Prompt text keyed by stage name.
    prompts: dict[str, str]
    # Tool names listed in tools.yaml.
    tools_whitelist: set[str]
|
||||
|
||||
|
||||
def _default_static_root() -> Path:
|
||||
return Path(__file__).resolve().parents[3] / "config" / "static" / "agent_chat"
|
||||
|
||||
|
||||
def _read_yaml(file_path: Path) -> dict[str, Any]:
|
||||
if not file_path.is_file():
|
||||
raise FileNotFoundError(f"Required config file not found: {file_path}")
|
||||
with file_path.open("r", encoding="utf-8") as file:
|
||||
loaded = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"YAML file must be a mapping: {file_path}")
|
||||
return loaded
|
||||
|
||||
|
||||
def _read_prompt(file_path: Path) -> str:
|
||||
if not file_path.is_file():
|
||||
raise FileNotFoundError(f"Required prompt file not found: {file_path}")
|
||||
return file_path.read_text(encoding="utf-8").strip()
|
||||
|
||||
|
||||
def validate_workflow_stages(stages: list[str]) -> None:
    """Require the exact stage sequence intent, execution, organization.

    Raises:
        ValueError: If ``stages`` is not equal to that list.
    """
    expected = ["intent", "execution", "organization"]
    if stages == expected:
        return
    raise ValueError(f"Invalid workflow stages: {stages}, expected: {expected}")
|
||||
|
||||
|
||||
def load_tools_whitelist(static_root: Path | None = None) -> set[str]:
    """Read the allowed tool names from ``tools.yaml``.

    Args:
        static_root: Directory containing ``tools.yaml``; the default static
            root is used when omitted.

    Returns:
        The whitespace-stripped tool names.

    Raises:
        ValueError: If ``tools`` is not a list, or if any item is not a
            non-blank string.
    """
    base = static_root or _default_static_root()
    document = _read_yaml(base / "tools.yaml")
    entries = document.get("tools", [])
    if not isinstance(entries, list):
        raise ValueError("tools.yaml field 'tools' must be a list")

    names: set[str] = set()
    for entry in entries:
        if not isinstance(entry, str) or not entry.strip():
            raise ValueError("tools.yaml list items must be non-empty strings")
        names.add(entry.strip())
    return names
|
||||
|
||||
|
||||
def load_crewai_template(static_root: Path | None = None) -> CrewAITemplate:
    """Assemble the CrewAI template from the static configuration directory.

    Reads ``agents.yaml``, ``tasks.yaml`` and ``workflow.yaml`` from the
    ``crewai`` sub-directory, validates the workflow stages, loads the three
    stage prompts and finally the tool whitelist.

    Args:
        static_root: Static configuration root; the default is used when
            omitted.

    Raises:
        ValueError: If ``stages`` in ``workflow.yaml`` is not a list (further
            errors propagate from the helpers).
    """
    base = static_root or _default_static_root()
    crewai_dir = base / "crewai"

    agents, tasks, workflow = [
        _read_yaml(crewai_dir / file_name)
        for file_name in ("agents.yaml", "tasks.yaml", "workflow.yaml")
    ]

    declared_stages = workflow.get("stages")
    if not isinstance(declared_stages, list):
        raise ValueError("workflow.yaml field 'stages' must be a list")
    validate_workflow_stages(list(map(str, declared_stages)))

    prompt_dir = crewai_dir / "prompts"
    prompt_files = (
        ("intent", "intent.md"),
        ("execution", "execution.md"),
        ("organization", "organization.md"),
    )
    prompts = {stage: _read_prompt(prompt_dir / file_name) for stage, file_name in prompt_files}

    whitelist = load_tools_whitelist(base)
    return CrewAITemplate(
        agents=agents,
        tasks=tasks,
        workflow=workflow,
        prompts=prompts,
        tools_whitelist=whitelist,
    )
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _require_fields(event: dict[str, Any], *, kind: str, required: list[str]) -> None:
|
||||
missing = [field for field in required if field not in event]
|
||||
if missing:
|
||||
raise ValueError(f"Missing fields for {kind}: {missing}")
|
||||
|
||||
|
||||
# Translation table keyed by internal event kind. Each entry holds:
#   * the protocol event type,
#   * required (output key, event key) pairs, copied verbatim,
#   * optional (output key, event key, default) triples.
_EVENT_SPECS = {
    "run_started": ("run.started", (("run_id", "session_id"),), ()),
    "message_delta": (
        "message.delta",
        (("message_id", "message_id"), ("delta", "delta")),
        (),
    ),
    "tool_started": (
        "tool.started",
        (("message_id", "message_id"), ("tool_name", "tool_name")),
        (),
    ),
    "tool_completed": (
        "tool.completed",
        (("message_id", "message_id"), ("tool_name", "tool_name")),
        (("result", "result", None),),
    ),
    "run_completed": (
        "run.completed",
        (("run_id", "session_id"),),
        (("output", "output", ""),),
    ),
    "run_failed": (
        "run.failed",
        (("run_id", "session_id"),),
        (("error", "error", ""),),
    ),
}


def map_internal_event(event: dict[str, Any]) -> dict[str, Any]:
    """Translate an internal event dict into its protocol representation.

    Args:
        event: Internal event; ``kind`` selects the translation.

    Returns:
        A dict with ``type`` first, then the required fields, then any
        optional fields (filled with their default when absent).

    Raises:
        ValueError: If ``kind`` is not supported or a required field is
            missing from ``event``.
    """
    kind = event.get("kind")
    # Only string kinds can be known; this also keeps unhashable values away
    # from the dict lookup.
    spec = _EVENT_SPECS.get(kind) if isinstance(kind, str) else None
    if spec is None:
        raise ValueError(f"Unsupported event kind: {kind}")

    protocol_type, required, optional = spec
    missing = [source for _, source in required if source not in event]
    if missing:
        raise ValueError(f"Missing fields for {kind}: {missing}")

    mapped: dict[str, Any] = {"type": protocol_type}
    for target, source in required:
        mapped[target] = event[source]
    for target, source, default in optional:
        mapped[target] = event.get(source, default)
    return mapped
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def run_started(*, run_id: str) -> dict[str, Any]:
    """Build the ``run.started`` event for ``run_id``."""
    event: dict[str, Any] = {"type": "run.started"}
    event["run_id"] = run_id
    return event
|
||||
|
||||
|
||||
def stage_completed(
    *, run_id: str, stage: str, usage: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Build a ``stage.completed`` event.

    The ``usage`` key is present only when a usage value is supplied.
    """
    base: dict[str, Any] = {
        "type": "stage.completed",
        "run_id": run_id,
        "stage": stage,
    }
    if usage is None:
        return base
    return {**base, "usage": usage}
|
||||
|
||||
|
||||
def run_completed(*, run_id: str, output: str, usage: dict[str, Any]) -> dict[str, Any]:
    """Build the ``run.completed`` event carrying the final output and usage."""
    event: dict[str, Any] = {"type": "run.completed", "run_id": run_id}
    event["output"] = output
    event["usage"] = usage
    return event
|
||||
|
||||
|
||||
def run_failed(*, run_id: str, error: str) -> dict[str, Any]:
    """Build the ``run.failed`` event carrying the error text."""
    event: dict[str, Any] = {"type": "run.failed", "run_id": run_id}
    event["error"] = error
    return event
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from core.agent_chat.storage_adapter import StorageAdapter
|
||||
|
||||
_ALLOWED_MIME_TYPES = {
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/x-wav",
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/webp",
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
}
|
||||
|
||||
|
||||
class _AsrTool(Protocol):
|
||||
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, object]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AttachmentInput:
    """One attachment handed to the multimodal processor."""

    # File name as supplied; its suffix becomes the stored extension.
    filename: str
    # Declared MIME type, checked against the allow-list.
    mime_type: str
    # Raw file bytes.
    content: bytes
    # Origin label copied into the attachment metadata.
    origin: str = "user_upload"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProcessedAttachmentContext:
    """Result of processing one batch of attachments."""

    # One metadata dict per processed attachment, in input order.
    attachments: list[dict[str, object]]
    # Preview texts of the attachments that produced one.
    preview_texts: list[str]
|
||||
|
||||
|
||||
class MultimodalProcessor:
    """Validate attachments and derive their storage metadata and previews."""

    # Maximum number of characters kept from a text/plain attachment.
    _TEXT_PREVIEW_CHARS = 200

    def __init__(
        self,
        *,
        storage: StorageAdapter,
        asr_tool: _AsrTool,
        max_file_size_mb: int = 20,
    ) -> None:
        """Store the collaborators and convert the size limit to bytes."""
        self._storage = storage
        self._asr_tool = asr_tool
        self._max_size_bytes = max_file_size_mb * 1024 * 1024

    def process(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        attachments: list[AttachmentInput],
    ) -> ProcessedAttachmentContext:
        """Build metadata and preview text for every attachment.

        Args:
            user_id: Owner identifier used in the object path.
            session_id: Session identifier used in the object path.
            message_seq: Message sequence number used in the object path.
            attachments: Attachments to process, in order.

        Returns:
            The per-attachment metadata plus the non-empty preview texts.

        Raises:
            ValueError: If any attachment has an unsupported MIME type or
                exceeds the size limit.
        """
        # Validate the whole batch up front so a rejected upload never
        # triggers transcription for the attachments that precede it.
        for attachment in attachments:
            self._validate_attachment(attachment)

        metadata_list: list[dict[str, object]] = []
        preview_texts: list[str] = []

        for attachment in attachments:
            checksum = hashlib.sha256(attachment.content).hexdigest()
            # Files without a suffix are stored with a generic extension.
            extension = Path(attachment.filename).suffix.strip(".").lower() or "bin"
            object_path = self._storage.build_object_path(
                user_id=user_id,
                session_id=session_id,
                message_seq=message_seq,
                checksum_sha256=checksum,
                extension=extension,
            )

            preview_text = self._build_preview_text(attachment)
            if preview_text:
                preview_texts.append(preview_text)

            metadata_list.append(
                self._storage.build_attachment_metadata(
                    object_path=object_path,
                    mime_type=attachment.mime_type,
                    size=len(attachment.content),
                    checksum_sha256=checksum,
                    origin=attachment.origin,
                    preview_text=preview_text,
                )
            )

        return ProcessedAttachmentContext(
            attachments=metadata_list,
            preview_texts=preview_texts,
        )

    def _validate_attachment(self, attachment: AttachmentInput) -> None:
        """Reject attachments with a disallowed MIME type or excessive size."""
        if attachment.mime_type not in _ALLOWED_MIME_TYPES:
            raise ValueError("Unsupported MIME type")
        if len(attachment.content) > self._max_size_bytes:
            raise ValueError("Attachment exceeds max file size")

    def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
        """Return a text preview for audio and plain-text attachments.

        Audio is transcribed through the ASR tool and the ``text`` entry of
        its result is used when it is a string. Plain text yields its leading
        characters. All other types, and audio without a text result, yield
        ``None``.
        """
        if attachment.mime_type.startswith("audio/"):
            result = self._asr_tool.transcribe(
                audio_bytes=attachment.content,
                filename=attachment.filename,
            )
            text = result.get("text")
            return text if isinstance(text, str) else None

        if attachment.mime_type == "text/plain":
            limit = self._TEXT_PREVIEW_CHARS
            # A UTF-8 character occupies at most four bytes, so this prefix
            # holds the first `limit` characters of any well-formed text and
            # avoids decoding a multi-megabyte payload for a short preview.
            # A character cut at the prefix boundary is dropped by
            # errors="ignore"; malformed input may yield a shorter preview.
            head = attachment.content[: limit * 4]
            return head.decode("utf-8", errors="ignore")[:limit]

        return None
|
||||
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from core.agent_chat.cost_tracker import CostTracker
|
||||
from core.agent_chat import events
|
||||
|
||||
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class OrchestratorResult:
    """Outcome of one orchestrator run."""

    # Final text output; empty when the run failed.
    output: str
    # Usage totals snapshot taken from the cost tracker.
    usage: dict[str, Any]
    # Events emitted during the run, in order.
    events: list[dict[str, Any]]
    # Context dict that was passed to every stage.
    context: dict[str, Any]
    # True when a stage raised and the run was aborted.
    failed: bool
    # Error text of the failure, or None on success.
    error: str | None
|
||||
|
||||
|
||||
class AgentChatOrchestrator:
    """Run the intent, execution and organization stages in sequence."""

    def __init__(
        self,
        *,
        intent_stage: StageCallable,
        execution_stage: StageCallable,
        organization_stage: StageCallable,
    ) -> None:
        """Keep the three stage callables for later runs."""
        self._intent_stage = intent_stage
        self._execution_stage = execution_stage
        self._organization_stage = organization_stage

    def run_sync(self, *, run_id: str, user_message: str) -> OrchestratorResult:
        """Execute ``run`` to completion via ``asyncio.run``."""
        coroutine = self.run(run_id=run_id, user_message=user_message)
        return asyncio.run(coroutine)

    async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
        """Feed the message through every stage and collect events and usage.

        Each stage receives the previous stage's ``content`` (the user
        message for the first one) and the shared context dict. After each
        stage, a ``stage.completed`` event with the cumulative usage snapshot
        is recorded.

        Any exception raised inside the pipeline ends the run: a
        ``run.failed`` event is recorded and a failed result with empty
        output is returned instead of propagating the exception.
        """
        tracker = CostTracker()
        emitted: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
        context: dict[str, Any] = {}
        pipeline: tuple[tuple[str, StageCallable], ...] = (
            ("intent", self._intent_stage),
            ("execution", self._execution_stage),
            ("organization", self._organization_stage),
        )

        current = user_message
        try:
            for name, stage in pipeline:
                result = await stage(message=current, context=context)
                # A stage without "content" passes its input through unchanged.
                current = str(result.get("content", current))
                reported = result.get("usage", {})
                if isinstance(reported, dict):
                    tracker.add_usage(reported)
                completed = events.stage_completed(
                    run_id=run_id,
                    stage=name,
                    usage=tracker.snapshot(),
                )
                emitted.append(completed)
        except Exception as exc:  # noqa: BLE001
            reason = str(exc)
            emitted.append(events.run_failed(run_id=run_id, error=reason))
            return OrchestratorResult(
                output="",
                usage=tracker.snapshot(),
                events=emitted,
                context=context,
                failed=True,
                error=reason,
            )

        totals = tracker.snapshot()
        emitted.append(events.run_completed(run_id=run_id, output=current, usage=totals))
        return OrchestratorResult(
            output=current,
            usage=totals,
            events=emitted,
            context=context,
            failed=False,
            error=None,
        )
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class StorageAdapter:
    """Build object paths and attachment metadata for one storage bucket."""

    _bucket: str

    def __init__(self, bucket: str) -> None:
        """Remember the bucket name."""
        self._bucket = bucket

    @property
    def bucket(self) -> str:
        """Name of the bucket given at construction."""
        return self._bucket

    @staticmethod
    def _checked_segment(value: object, *, field: str) -> str:
        """Return ``value`` as text after verifying it is one safe path segment."""
        segment = str(value)
        if segment in ("", ".", "..") or "/" in segment or "\\" in segment:
            raise ValueError(f"{field} is not a valid object path segment")
        return segment

    def build_object_path(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        checksum_sha256: str,
        extension: str,
    ) -> str:
        """Compose the object key for an attachment.

        The key has the shape
        ``agent-chat/<user_id>/<session_id>/<message_seq>/<checksum>.<ext>``.
        The extension is lower-cased, stripped of dots, and falls back to
        ``bin`` when empty so the key never ends in a bare dot.

        Raises:
            ValueError: If ``user_id`` or ``session_id`` is empty, is ``.``
                or ``..``, or contains a path separator, or if the extension
                contains a path separator.
        """
        # Identifiers may come from client input; refuse anything that could
        # move the key outside the caller's own prefix.
        owner = self._checked_segment(user_id, field="user_id")
        session = self._checked_segment(session_id, field="session_id")

        normalized_ext = extension.strip(".").lower() or "bin"
        if "/" in normalized_ext or "\\" in normalized_ext:
            raise ValueError("extension must not contain path separators")

        return (
            f"agent-chat/{owner}/{session}/{message_seq}/"
            f"{checksum_sha256}.{normalized_ext}"
        )

    def build_attachment_metadata(
        self,
        *,
        object_path: str,
        mime_type: str,
        size: int,
        checksum_sha256: str,
        origin: str,
        preview_text: str | None = None,
    ) -> dict[str, object]:
        """Return the metadata record for a stored attachment."""
        return {
            "object_path": object_path,
            "mime_type": mime_type,
            "size": size,
            "checksum_sha256": checksum_sha256,
            "origin": origin,
            "preview_text": preview_text,
        }
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
TranscribeCallable = Callable[..., dict[str, Any]]
|
||||
|
||||
|
||||
class FunASRTool:
    """Speech-to-text tool that delegates to an injectable transcription callable."""

    _transcribe_callable: TranscribeCallable
    _model: str

    def __init__(
        self,
        transcribe_callable: TranscribeCallable | None = None,
        model: str = "fun-asr-realtime-2025-11-07",
    ) -> None:
        """Use ``transcribe_callable`` when given, else the DashScope default."""
        if transcribe_callable:
            self._transcribe_callable = transcribe_callable
        else:
            self._transcribe_callable = self._dashscope_transcribe
        self._model = model

    def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, Any]:
        """Transcribe audio and return the callable's result plus the model name.

        Keys from the callable's result win over the ``model`` entry added here.
        """
        outcome = self._transcribe_callable(audio_bytes=audio_bytes, filename=filename)
        return {"model": self._model, **outcome}

    def _dashscope_transcribe(
        self, *, audio_bytes: bytes, filename: str
    ) -> dict[str, Any]:
        """Default backend; it currently always raises ``RuntimeError``.

        The message distinguishes a missing ``dashscope`` package from the
        runtime integration not being configured yet.
        """
        try:
            importlib.import_module("dashscope")
        except ImportError as exc:
            raise RuntimeError("DashScope SDK is not installed") from exc
        else:
            raise RuntimeError(
                "DashScope transcribe runtime integration is not configured yet"
            )
|
||||
@@ -132,6 +132,14 @@ class SupabaseSettings(BaseModel):
|
||||
return self.public_url
|
||||
|
||||
|
||||
class StorageSettings(BaseModel):
    """Settings for attachment storage."""

    # Storage backend; "supabase" is the only accepted value.
    provider: Literal["supabase"] = "supabase"
    # Bucket name, 3 to 63 characters long.
    bucket: str = Field(default="agent-chat-attachments", min_length=3, max_length=63)
    # Signed URL lifetime in seconds (60 to 3600).
    signed_url_ttl_seconds: int = Field(default=600, ge=60, le=3600)
    # Largest accepted file size in megabytes (1 to 200).
    max_file_size_mb: int = Field(default=20, ge=1, le=200)
    # Retention period in days (1 to 3650).
    retention_days: int = Field(default=30, ge=1, le=3650)
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
@@ -163,6 +171,7 @@ class Settings(BaseSettings):
|
||||
cors: CorsSettings = CorsSettings()
|
||||
redis: RedisSettings = RedisSettings()
|
||||
supabase: SupabaseSettings = SupabaseSettings()
|
||||
storage: StorageSettings = StorageSettings()
|
||||
celery: CelerySettings = CelerySettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
intent:
|
||||
role: Intent Agent
|
||||
goal: Classify user intent and decide execution strategy
|
||||
execution:
|
||||
role: Execution Agent
|
||||
goal: Execute tasks with available tools
|
||||
organization:
|
||||
role: Organization Agent
|
||||
goal: Organize output for user-friendly response
|
||||
@@ -0,0 +1,2 @@
|
||||
你是任务执行代理。
|
||||
基于输入意图与上下文调用可用工具,并生成可验证执行结果。
|
||||
@@ -0,0 +1,2 @@
|
||||
你是意图识别代理。
|
||||
你的任务是识别用户输入的意图类型,并返回结构化意图标签。
|
||||
@@ -0,0 +1,2 @@
|
||||
你是结果整理代理。
|
||||
将执行结果组织为面向用户的清晰回复,保留关键信息与必要引用。
|
||||
@@ -0,0 +1,6 @@
|
||||
intent:
|
||||
description: Identify user intent and required capabilities
|
||||
execution:
|
||||
description: Execute intent with tools and model calls
|
||||
organization:
|
||||
description: Format final response and references
|
||||
@@ -0,0 +1,9 @@
|
||||
stages:
|
||||
- intent
|
||||
- execution
|
||||
- organization
|
||||
|
||||
timeouts:
|
||||
intent_seconds: 8
|
||||
execution_seconds: 30
|
||||
organization_seconds: 10
|
||||
@@ -0,0 +1,25 @@
|
||||
factories:
|
||||
- name: qwen
|
||||
request_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
avatar: https://cdn.simpleicons.org/alibabacloud/FF6A00
|
||||
- name: minimax
|
||||
request_url: https://api.minimax.chat/v1
|
||||
avatar: https://cdn.simpleicons.org/minimax/1A1A1A
|
||||
- name: kimi
|
||||
request_url: https://api.moonshot.cn/v1
|
||||
avatar: https://cdn.simpleicons.org/moonrepo/3B82F6
|
||||
- name: deepseek
|
||||
request_url: https://api.deepseek.com/v1
|
||||
avatar: https://cdn.simpleicons.org/deepseek/4D6BFE
|
||||
- name: doubao
|
||||
request_url: https://ark.cn-beijing.volces.com/api/v3
|
||||
avatar: https://cdn.simpleicons.org/volkswagen/001E50
|
||||
- name: zai
|
||||
request_url: https://api.z.ai/v1
|
||||
avatar: https://cdn.simpleicons.org/zotero/CC2936
|
||||
|
||||
llms:
|
||||
- model_code: qwen3.5-flash
|
||||
factory_id: qwen
|
||||
- model_code: deepseek-v3.2
|
||||
factory_id: deepseek
|
||||
@@ -0,0 +1,3 @@
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- attachment_extract
|
||||
@@ -1,9 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
logger = get_logger("core.initialization.init_data")
|
||||
|
||||
|
||||
class LlmFactorySeed(BaseModel):
    """Catalog entry describing one LLM factory."""

    # Factory name; also the key used to look up an existing row.
    name: str
    # Request URL stored on the factory row.
    request_url: str
    # Optional avatar value; None when the catalog omits it.
    avatar: str | None = None
|
||||
|
||||
|
||||
class LlmSeed(BaseModel):
    """Catalog entry describing one LLM model."""

    # Model code; also the key used to look up an existing row.
    # NOTE(review): pydantic v2 may warn about field names starting with
    # "model_" - confirm whether that shows up at import time.
    model_code: str
    # Name of the owning factory as written in the catalog (not a database id).
    factory_id: str
|
||||
|
||||
|
||||
class LlmCatalogSeed(BaseModel):
    """Validated shape of the whole LLM catalog."""

    # Entries of the catalog's "factories" section.
    factories: list[LlmFactorySeed]
    # Entries of the catalog's "llms" section.
    llms: list[LlmSeed]
|
||||
|
||||
|
||||
def _default_catalog_path() -> Path:
|
||||
return (
|
||||
Path(__file__).resolve().parents[1]
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "agent_chat"
|
||||
/ "llm_catalog.yaml"
|
||||
)
|
||||
|
||||
|
||||
def load_llm_catalog(catalog_path: Path | None = None) -> dict[str, Any]:
    """Load and validate the LLM catalog YAML.

    Args:
        catalog_path: Catalog file; the default path is used when omitted.

    Returns:
        A dict with ``factories`` and ``llms`` lists, as dumped from the
        validated ``LlmCatalogSeed`` model.

    Raises:
        ValueError: If the document is not a mapping, a section is not a
            list, or the entries fail model validation.
    """
    path = catalog_path or _default_catalog_path()
    with path.open("r", encoding="utf-8") as handle:
        document = yaml.safe_load(handle) or {}
    if not isinstance(document, dict):
        raise ValueError(f"Invalid LLM catalog format: {path}")

    # A missing section is treated as an empty list.
    sections = {key: document.get(key, []) for key in ("factories", "llms")}
    if not all(isinstance(section, list) for section in sections.values()):
        raise ValueError(f"Invalid LLM catalog sections: {path}")

    candidate = {key: list(section) for key, section in sections.items()}
    try:
        seed = LlmCatalogSeed.model_validate(candidate)
    except ValidationError as exc:
        raise ValueError(f"Invalid LLM catalog data: {path}") from exc

    return seed.model_dump()
|
||||
|
||||
|
||||
async def _upsert_factory(
    session: AsyncSession,
    *,
    name: str,
    request_url: str,
    avatar: str | None,
) -> uuid.UUID:
    """Create or update the factory called ``name`` and return its id.

    An existing row has its ``request_url`` and ``avatar`` overwritten. A new
    row is added and the session is flushed before its id is read.
    """
    query = select(LlmFactory).where(LlmFactory.name == name)
    existing = (await session.execute(query)).scalar_one_or_none()

    if existing is not None:
        existing.request_url = request_url
        existing.avatar = avatar
        return existing.id

    created = LlmFactory(name=name, request_url=request_url, avatar=avatar)
    session.add(created)
    await session.flush()
    return created.id
|
||||
|
||||
|
||||
async def _upsert_llm(
    session: AsyncSession,
    *,
    model_code: str,
    factory_id: uuid.UUID,
) -> None:
    """Create the model ``model_code`` or re-point an existing one at ``factory_id``."""
    query = select(Llm).where(Llm.model_code == model_code)
    existing = (await session.execute(query)).scalar_one_or_none()

    if existing is not None:
        existing.factory_id = factory_id
    else:
        session.add(Llm(model_code=model_code, factory_id=factory_id))
|
||||
|
||||
|
||||
async def initialize_data() -> bool:
    """Seed LLM factories and models from the catalog.

    Loads the catalog, upserts every factory, then upserts every model with
    the id of the factory it names. All writes happen inside one
    ``session.begin()`` block.

    Returns:
        ``True`` once seeding has finished.

    Raises:
        RuntimeError: If a model references a factory name that is not in the
            catalog's ``factories`` section.
    """
    # This function now performs real work; the old "(no-op)" wording was stale.
    logger.info("Initializing LLM factory/model seed data")
    catalog = load_llm_catalog()

    async with AsyncSessionLocal() as session:
        async with session.begin():
            # Catalog models name their factory; map those names to row ids.
            factory_id_by_name: dict[str, uuid.UUID] = {}
            for factory in catalog["factories"]:
                factory_id = await _upsert_factory(
                    session,
                    name=factory["name"],
                    request_url=factory["request_url"],
                    avatar=factory.get("avatar"),
                )
                factory_id_by_name[factory["name"]] = factory_id

            for llm in catalog["llms"]:
                # The catalog's "factory_id" field holds a factory NAME.
                factory_name = llm["factory_id"]
                resolved_factory_id = factory_id_by_name.get(factory_name)
                if resolved_factory_id is None:
                    raise RuntimeError(
                        f"Factory '{factory_name}' not found for model '{llm['model_code']}'"
                    )
                await _upsert_llm(
                    session,
                    model_code=llm["model_code"],
                    factory_id=resolved_factory_id,
                )

    logger.info("Initialized LLM factory/model seed data")
    return True
|
||||
|
||||
Reference in New Issue
Block a user