refactor: 移除 crewai agent 架构相关代码并更新 LLM 配置
This commit is contained in:
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -1,20 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.agent.event_bridge import map_internal_event
|
||||
|
||||
|
||||
class AguiAdapter:
|
||||
def to_command(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
message = payload.get("message")
|
||||
if not isinstance(message, str) or not message.strip():
|
||||
raise ValueError("message is required")
|
||||
|
||||
return {
|
||||
"message": message,
|
||||
"session_id": payload.get("session_id"),
|
||||
}
|
||||
|
||||
def to_protocol_event(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
return map_internal_event(event)
|
||||
@@ -1,67 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CrewAITemplate:
|
||||
agents: dict[str, Any]
|
||||
tasks: dict[str, Any]
|
||||
workflow: dict[str, Any]
|
||||
tools_whitelist: set[str]
|
||||
|
||||
|
||||
def _default_static_root() -> Path:
|
||||
return Path(__file__).resolve().parents[3] / "config" / "static" / "crewai"
|
||||
|
||||
|
||||
def _read_yaml(file_path: Path) -> dict[str, Any]:
|
||||
if not file_path.is_file():
|
||||
raise FileNotFoundError(f"Required config file not found: {file_path}")
|
||||
with file_path.open("r", encoding="utf-8") as file:
|
||||
loaded = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"YAML file must be a mapping: {file_path}")
|
||||
return loaded
|
||||
|
||||
|
||||
def validate_workflow_stages(stages: list[str]) -> None:
|
||||
expected = ["intent", "execution", "organization"]
|
||||
if stages != expected:
|
||||
raise ValueError(f"Invalid workflow stages: {stages}, expected: {expected}")
|
||||
|
||||
|
||||
def load_tools_whitelist(static_root: Path | None = None) -> set[str]:
|
||||
root = static_root or _default_static_root()
|
||||
tools = _read_yaml(root / "tools.yaml")
|
||||
raw_tools = tools.get("tools", [])
|
||||
if not isinstance(raw_tools, list):
|
||||
raise ValueError("tools.yaml field 'tools' must be a list")
|
||||
if not all(isinstance(item, str) and item.strip() for item in raw_tools):
|
||||
raise ValueError("tools.yaml list items must be non-empty strings")
|
||||
whitelist = {item.strip() for item in raw_tools}
|
||||
return whitelist
|
||||
|
||||
|
||||
def load_crewai_template(static_root: Path | None = None) -> CrewAITemplate:
|
||||
root = static_root or _default_static_root()
|
||||
|
||||
agents = _read_yaml(root / "agents.yaml")
|
||||
tasks = _read_yaml(root / "tasks.yaml")
|
||||
workflow = _read_yaml(root / "workflow.yaml")
|
||||
|
||||
stages = workflow.get("stages")
|
||||
if not isinstance(stages, list):
|
||||
raise ValueError("workflow.yaml field 'stages' must be a list")
|
||||
validate_workflow_stages([str(stage) for stage in stages])
|
||||
|
||||
return CrewAITemplate(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
workflow=workflow,
|
||||
tools_whitelist=load_tools_whitelist(root),
|
||||
)
|
||||
@@ -1,63 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _require_fields(event: dict[str, Any], *, kind: str, required: list[str]) -> None:
|
||||
missing = [field for field in required if field not in event]
|
||||
if missing:
|
||||
raise ValueError(f"Missing fields for {kind}: {missing}")
|
||||
|
||||
|
||||
def map_internal_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
kind = event.get("kind")
|
||||
|
||||
if kind == "run_started":
|
||||
_require_fields(event, kind=kind, required=["session_id"])
|
||||
return {
|
||||
"type": "run.started",
|
||||
"run_id": event["session_id"],
|
||||
}
|
||||
|
||||
if kind == "message_delta":
|
||||
_require_fields(event, kind=kind, required=["message_id", "delta"])
|
||||
return {
|
||||
"type": "message.delta",
|
||||
"message_id": event["message_id"],
|
||||
"delta": event["delta"],
|
||||
}
|
||||
|
||||
if kind == "tool_started":
|
||||
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
|
||||
return {
|
||||
"type": "tool.started",
|
||||
"message_id": event["message_id"],
|
||||
"tool_name": event["tool_name"],
|
||||
}
|
||||
|
||||
if kind == "tool_completed":
|
||||
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
|
||||
return {
|
||||
"type": "tool.completed",
|
||||
"message_id": event["message_id"],
|
||||
"tool_name": event["tool_name"],
|
||||
"result": event.get("result"),
|
||||
}
|
||||
|
||||
if kind == "run_completed":
|
||||
_require_fields(event, kind=kind, required=["session_id"])
|
||||
return {
|
||||
"type": "run.completed",
|
||||
"run_id": event["session_id"],
|
||||
"output": event.get("output", ""),
|
||||
}
|
||||
|
||||
if kind == "run_failed":
|
||||
_require_fields(event, kind=kind, required=["session_id"])
|
||||
return {
|
||||
"type": "run.failed",
|
||||
"run_id": event["session_id"],
|
||||
"error": event.get("error", ""),
|
||||
}
|
||||
|
||||
raise ValueError(f"Unsupported event kind: {kind}")
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def run_started(*, run_id: str) -> dict[str, Any]:
|
||||
return {"type": "run.started", "run_id": run_id}
|
||||
|
||||
|
||||
def stage_completed(
|
||||
*, run_id: str, stage: str, usage: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
event: dict[str, Any] = {
|
||||
"type": "stage.completed",
|
||||
"run_id": run_id,
|
||||
"stage": stage,
|
||||
}
|
||||
if usage is not None:
|
||||
event["usage"] = usage
|
||||
return event
|
||||
|
||||
|
||||
def run_completed(*, run_id: str, output: str, usage: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "run.completed",
|
||||
"run_id": run_id,
|
||||
"output": output,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
|
||||
def run_failed(*, run_id: str, error: str) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "run.failed",
|
||||
"run_id": run_id,
|
||||
"error": error,
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig:
|
||||
model_code: str
|
||||
factory_name: str
|
||||
litellm_model: str
|
||||
request_url: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMResponse:
|
||||
content: str
|
||||
usage: dict[str, Any]
|
||||
|
||||
|
||||
class LiteLLMClient:
|
||||
def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
|
||||
self._config = config
|
||||
self._api_key = api_key or self._get_api_key(config.factory_name)
|
||||
|
||||
@staticmethod
|
||||
def _get_api_key(factory_name: str) -> str:
|
||||
key_map = {
|
||||
"dashscope": "DASHSCOPE_API_KEY",
|
||||
"minimax": "MINIMAX_API_KEY",
|
||||
"moonshot": "MOONSHOT_API_KEY",
|
||||
"deepseek": "DEEPSEEK_API_KEY",
|
||||
"volcengine-ark": "ARK_API_KEY",
|
||||
"z-ai": "ZAI_API_KEY",
|
||||
}
|
||||
env_key = key_map.get(factory_name)
|
||||
if not env_key:
|
||||
raise ValueError(f"No API key mapping for factory: {factory_name}")
|
||||
key = os.environ.get(env_key)
|
||||
if not key:
|
||||
raise ValueError(f"Environment variable {env_key} is not set")
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def load_config(
|
||||
model_code: str,
|
||||
static_root: Path | None = None,
|
||||
) -> LLMConfig:
|
||||
root = static_root or (
|
||||
Path(__file__).resolve().parents[3] / "config" / "static" / "database"
|
||||
)
|
||||
yaml_path = root / "llm_catalog.yaml"
|
||||
with yaml_path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
factories = {f["name"]: f for f in data.get("factories", [])}
|
||||
llms = data.get("llms", [])
|
||||
|
||||
for llm in llms:
|
||||
if llm.get("model_code") == model_code:
|
||||
factory_name = llm["factory_name"]
|
||||
factory = factories.get(factory_name)
|
||||
if not factory:
|
||||
raise ValueError(f"Factory not found: {factory_name}")
|
||||
return LLMConfig(
|
||||
model_code=model_code,
|
||||
factory_name=factory_name,
|
||||
litellm_model=llm.get("litellm_model", model_code),
|
||||
request_url=factory["request_url"],
|
||||
)
|
||||
|
||||
raise ValueError(f"Model not found: {model_code}")
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
) -> LLMResponse:
|
||||
import litellm
|
||||
|
||||
response = litellm.completion( # type: ignore[attr-defined]
|
||||
model=self._config.litellm_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self._config.request_url,
|
||||
api_key=self._api_key,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
|
||||
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
|
||||
|
||||
return LLMResponse(content=content, usage=usage)
|
||||
|
||||
async def achat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
) -> LLMResponse:
|
||||
import litellm
|
||||
|
||||
response = await litellm.acompletion( # type: ignore[attr-defined]
|
||||
model=self._config.litellm_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self._config.request_url,
|
||||
api_key=self._api_key,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
|
||||
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
|
||||
|
||||
return LLMResponse(content=content, usage=usage)
|
||||
|
||||
|
||||
def get_model_cost(usage: dict[str, Any]) -> Decimal:
|
||||
cost = usage.get("cost")
|
||||
if cost is None:
|
||||
return Decimal("0")
|
||||
return Decimal(str(cost))
|
||||
@@ -1,117 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from core.agent import events
|
||||
from core.agent.litellm_client import get_model_cost
|
||||
|
||||
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OrchestratorResult:
|
||||
output: str
|
||||
usage: dict[str, Any]
|
||||
events: list[dict[str, Any]]
|
||||
context: dict[str, Any]
|
||||
failed: bool
|
||||
error: str | None
|
||||
|
||||
|
||||
class _UsageTracker:
|
||||
def __init__(self) -> None:
|
||||
self._input_tokens = 0
|
||||
self._output_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._cost = Decimal("0")
|
||||
|
||||
def add_usage(self, usage: dict[str, Any]) -> None:
|
||||
input_tokens = usage.get("prompt_tokens", 0) or usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0) or usage.get(
|
||||
"output_tokens", 0
|
||||
)
|
||||
total = usage.get("total_tokens")
|
||||
|
||||
self._input_tokens += input_tokens
|
||||
self._output_tokens += output_tokens
|
||||
self._total_tokens += total if total else (input_tokens + output_tokens)
|
||||
self._cost += get_model_cost(usage)
|
||||
|
||||
def snapshot(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_tokens": self._input_tokens,
|
||||
"output_tokens": self._output_tokens,
|
||||
"total_tokens": self._total_tokens,
|
||||
"cost": str(self._cost),
|
||||
}
|
||||
|
||||
|
||||
class AgentChatOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
intent_stage: StageCallable,
|
||||
execution_stage: StageCallable,
|
||||
organization_stage: StageCallable,
|
||||
) -> None:
|
||||
self._intent_stage = intent_stage
|
||||
self._execution_stage = execution_stage
|
||||
self._organization_stage = organization_stage
|
||||
|
||||
def run_sync(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
return asyncio.run(self.run(run_id=run_id, user_message=user_message))
|
||||
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
tracker = _UsageTracker()
|
||||
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
|
||||
context: dict[str, Any] = {}
|
||||
|
||||
stage_pipeline: list[tuple[str, StageCallable]] = [
|
||||
("intent", self._intent_stage),
|
||||
("execution", self._execution_stage),
|
||||
("organization", self._organization_stage),
|
||||
]
|
||||
|
||||
stage_output = user_message
|
||||
try:
|
||||
for stage_name, stage_callable in stage_pipeline:
|
||||
stage_result = await stage_callable(
|
||||
message=stage_output, context=context
|
||||
)
|
||||
stage_output = str(stage_result.get("content", stage_output))
|
||||
usage = stage_result.get("usage", {})
|
||||
if isinstance(usage, dict):
|
||||
tracker.add_usage(usage)
|
||||
emitted_events.append(
|
||||
events.stage_completed(
|
||||
run_id=run_id,
|
||||
stage=stage_name,
|
||||
usage=tracker.snapshot(),
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
emitted_events.append(events.run_failed(run_id=run_id, error=str(exc)))
|
||||
return OrchestratorResult(
|
||||
output="",
|
||||
usage=tracker.snapshot(),
|
||||
events=emitted_events,
|
||||
context=context,
|
||||
failed=True,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
summary = tracker.snapshot()
|
||||
emitted_events.append(
|
||||
events.run_completed(run_id=run_id, output=stage_output, usage=summary)
|
||||
)
|
||||
return OrchestratorResult(
|
||||
output=stage_output,
|
||||
usage=summary,
|
||||
events=emitted_events,
|
||||
context=context,
|
||||
failed=False,
|
||||
error=None,
|
||||
)
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
TranscribeCallable = Callable[..., dict[str, Any]]
|
||||
|
||||
|
||||
class FunASRTool:
|
||||
_transcribe_callable: TranscribeCallable
|
||||
_model: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transcribe_callable: TranscribeCallable | None = None,
|
||||
model: str = "fun-asr-realtime-2025-11-07",
|
||||
) -> None:
|
||||
self._transcribe_callable = transcribe_callable or self._dashscope_transcribe
|
||||
self._model = model
|
||||
|
||||
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, Any]:
|
||||
payload = self._transcribe_callable(audio_bytes=audio_bytes, filename=filename)
|
||||
return {
|
||||
"model": self._model,
|
||||
**payload,
|
||||
}
|
||||
|
||||
def _dashscope_transcribe(
|
||||
self, *, audio_bytes: bytes, filename: str
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
importlib.import_module("dashscope")
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("DashScope SDK is not installed") from exc
|
||||
|
||||
raise RuntimeError(
|
||||
"DashScope transcribe runtime integration is not configured yet"
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- attachment_extract
|
||||
@@ -1,9 +0,0 @@
|
||||
stages:
|
||||
- intent
|
||||
- execution
|
||||
- organization
|
||||
|
||||
timeouts:
|
||||
intent_seconds: 8
|
||||
execution_seconds: 30
|
||||
organization_seconds: 10
|
||||
Reference in New Issue
Block a user