refactor: 移除 crewai agent 架构相关代码并更新 LLM 配置
This commit is contained in:
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.agui_adapter import AguiAdapter
|
||||
|
||||
|
||||
def test_to_command_maps_payload_fields() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
command = adapter.to_command(
|
||||
{
|
||||
"message": "hello",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
)
|
||||
|
||||
assert command["message"] == "hello"
|
||||
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_to_protocol_event_maps_internal_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
mapped = adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_completed",
|
||||
"session_id": "run-1",
|
||||
"output": "done",
|
||||
}
|
||||
)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
|
||||
|
||||
|
||||
def test_to_protocol_event_raises_for_invalid_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
adapter.to_protocol_event({"kind": "unknown"})
|
||||
@@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_transcribe_uses_injected_dashscope_callable() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
assert filename == "voice.wav"
|
||||
assert audio_bytes == b"audio"
|
||||
return {"text": "你好", "request_id": "req-1"}
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
|
||||
assert result["text"] == "你好"
|
||||
assert result["request_id"] == "req-1"
|
||||
assert result["model"] == "fun-asr-realtime-2025-11-07"
|
||||
|
||||
|
||||
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
raise RuntimeError("upstream timeout")
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
@@ -1,48 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from core.agent.litellm_client import get_model_cost
|
||||
|
||||
|
||||
def test_get_model_cost_returns_decimal() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
"cost": "0.002500",
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0.002500")
|
||||
|
||||
|
||||
def test_get_model_cost_with_no_cost() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0")
|
||||
|
||||
|
||||
def test_get_model_cost_with_zero_cost() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
"cost": "0",
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0")
|
||||
|
||||
|
||||
def test_get_model_cost_with_numeric_cost() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
"cost": 0.0025,
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0.0025")
|
||||
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.event_bridge import map_internal_event
|
||||
|
||||
|
||||
def test_map_run_started_event() -> None:
|
||||
event = {"kind": "run_started", "session_id": "s1"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.started", "run_id": "s1"}
|
||||
|
||||
|
||||
def test_map_message_delta_event() -> None:
|
||||
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
|
||||
def test_map_tool_events() -> None:
|
||||
started = {
|
||||
"kind": "tool_started",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
}
|
||||
completed = {
|
||||
"kind": "tool_completed",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
"result": "ok",
|
||||
}
|
||||
|
||||
mapped_started = map_internal_event(started)
|
||||
mapped_completed = map_internal_event(completed)
|
||||
|
||||
assert mapped_started["type"] == "tool.started"
|
||||
assert mapped_started["tool_name"] == "asr_fun_asr"
|
||||
assert mapped_completed["type"] == "tool.completed"
|
||||
assert mapped_completed["result"] == "ok"
|
||||
|
||||
|
||||
def test_map_run_completed_event() -> None:
|
||||
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
|
||||
|
||||
|
||||
def test_map_unknown_event_raises() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "unknown"})
|
||||
|
||||
|
||||
def test_map_event_missing_required_field_raises_value_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "message_delta", "message_id": "m1"})
|
||||
@@ -1,104 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.orchestrator import AgentChatOrchestrator
|
||||
|
||||
|
||||
async def _intent_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("intent")
|
||||
return {
|
||||
"content": f"intent:{message}",
|
||||
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
|
||||
}
|
||||
|
||||
|
||||
async def _execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
return {
|
||||
"content": f"execution:{message}",
|
||||
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
|
||||
}
|
||||
|
||||
|
||||
async def _organization_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("organization")
|
||||
return {
|
||||
"content": "final answer",
|
||||
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
|
||||
}
|
||||
|
||||
|
||||
def test_orchestrator_runs_three_stages_in_order() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution", "organization"]
|
||||
assert result.output == "final answer"
|
||||
assert result.usage["total_tokens"] == 13
|
||||
assert result.events[0]["type"] == "run.started"
|
||||
assert result.events[-1]["type"] == "run.completed"
|
||||
|
||||
|
||||
async def _failing_execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_failing_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution"]
|
||||
assert result.events[-1]["type"] == "run.failed"
|
||||
assert result.events[-1]["run_id"] == "run-2"
|
||||
assert "boom" in (result.events[-1].get("error") or "")
|
||||
assert result.failed is True
|
||||
assert "boom" in (result.error or "")
|
||||
|
||||
|
||||
def test_orchestrator_emits_stage_event_payload_shape() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
|
||||
|
||||
for event in result.events:
|
||||
assert "type" in event
|
||||
assert event.get("run_id") == "run-3"
|
||||
|
||||
stage_events = [
|
||||
event for event in result.events if event["type"] == "stage.completed"
|
||||
]
|
||||
assert [event["stage"] for event in stage_events] == [
|
||||
"intent",
|
||||
"execution",
|
||||
"organization",
|
||||
]
|
||||
@@ -1,23 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from v1.agent.service import build_session_title
|
||||
|
||||
|
||||
def test_build_session_title_truncates_first_message() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title(
|
||||
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
|
||||
)
|
||||
|
||||
assert len(title) == 24
|
||||
|
||||
|
||||
def test_build_session_title_falls_back_when_message_empty() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title("\n ", now=now)
|
||||
|
||||
assert title == "新对话 2026-02-25 10:30"
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.agui_adapter import AguiAdapter
|
||||
|
||||
|
||||
def test_to_command_maps_payload_fields() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
command = adapter.to_command(
|
||||
{
|
||||
"message": "hello",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
)
|
||||
|
||||
assert command["message"] == "hello"
|
||||
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_to_protocol_event_maps_internal_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
mapped = adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_completed",
|
||||
"session_id": "run-1",
|
||||
"output": "done",
|
||||
}
|
||||
)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
|
||||
|
||||
|
||||
def test_to_protocol_event_raises_for_invalid_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
adapter.to_protocol_event({"kind": "unknown"})
|
||||
@@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_transcribe_uses_injected_dashscope_callable() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
assert filename == "voice.wav"
|
||||
assert audio_bytes == b"audio"
|
||||
return {"text": "你好", "request_id": "req-1"}
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
|
||||
assert result["text"] == "你好"
|
||||
assert result["request_id"] == "req-1"
|
||||
assert result["model"] == "fun-asr-realtime-2025-11-07"
|
||||
|
||||
|
||||
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
raise RuntimeError("upstream timeout")
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cost_tracker import CostTracker
|
||||
|
||||
|
||||
def test_normalize_usage_and_cost_aggregation() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"cost": "0.002500",
|
||||
}
|
||||
)
|
||||
tracker.add_usage(
|
||||
{
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 3,
|
||||
"cost": "0.003000",
|
||||
"currency": "USD",
|
||||
}
|
||||
)
|
||||
|
||||
snapshot = tracker.snapshot()
|
||||
|
||||
assert snapshot["input_tokens"] == 12
|
||||
assert snapshot["output_tokens"] == 8
|
||||
assert snapshot["total_tokens"] == 20
|
||||
assert snapshot["cost"] == Decimal("0.005500")
|
||||
assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_negative_values() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"input_tokens": -1})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"cost": "-0.010000"})
|
||||
|
||||
|
||||
def test_snapshot_is_zero_before_any_usage() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
snapshot = tracker.snapshot()
|
||||
|
||||
assert snapshot["input_tokens"] == 0
|
||||
assert snapshot["output_tokens"] == 0
|
||||
assert snapshot["total_tokens"] == 0
|
||||
assert snapshot["cost"] == Decimal("0")
|
||||
assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_currency_mismatch() -> None:
|
||||
tracker = CostTracker(currency="USD")
|
||||
tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage(
|
||||
{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"cost": "0.001000",
|
||||
"currency": "CNY",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_add_usage_rejects_non_integral_token_values() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"input_tokens": 1.5})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"output_tokens": True})
|
||||
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.event_bridge import map_internal_event
|
||||
|
||||
|
||||
def test_map_run_started_event() -> None:
|
||||
event = {"kind": "run_started", "session_id": "s1"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.started", "run_id": "s1"}
|
||||
|
||||
|
||||
def test_map_message_delta_event() -> None:
|
||||
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
|
||||
def test_map_tool_events() -> None:
|
||||
started = {
|
||||
"kind": "tool_started",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
}
|
||||
completed = {
|
||||
"kind": "tool_completed",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
"result": "ok",
|
||||
}
|
||||
|
||||
mapped_started = map_internal_event(started)
|
||||
mapped_completed = map_internal_event(completed)
|
||||
|
||||
assert mapped_started["type"] == "tool.started"
|
||||
assert mapped_started["tool_name"] == "asr_fun_asr"
|
||||
assert mapped_completed["type"] == "tool.completed"
|
||||
assert mapped_completed["result"] == "ok"
|
||||
|
||||
|
||||
def test_map_run_completed_event() -> None:
|
||||
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
|
||||
|
||||
|
||||
def test_map_unknown_event_raises() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "unknown"})
|
||||
|
||||
|
||||
def test_map_event_missing_required_field_raises_value_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "message_delta", "message_id": "m1"})
|
||||
@@ -1,104 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.orchestrator import AgentChatOrchestrator
|
||||
|
||||
|
||||
async def _intent_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("intent")
|
||||
return {
|
||||
"content": f"intent:{message}",
|
||||
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
|
||||
}
|
||||
|
||||
|
||||
async def _execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
return {
|
||||
"content": f"execution:{message}",
|
||||
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
|
||||
}
|
||||
|
||||
|
||||
async def _organization_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("organization")
|
||||
return {
|
||||
"content": "final answer",
|
||||
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
|
||||
}
|
||||
|
||||
|
||||
def test_orchestrator_runs_three_stages_in_order() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution", "organization"]
|
||||
assert result.output == "final answer"
|
||||
assert result.usage["total_tokens"] == 13
|
||||
assert result.events[0]["type"] == "run.started"
|
||||
assert result.events[-1]["type"] == "run.completed"
|
||||
|
||||
|
||||
async def _failing_execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_failing_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution"]
|
||||
assert result.events[-1]["type"] == "run.failed"
|
||||
assert result.events[-1]["run_id"] == "run-2"
|
||||
assert "boom" in (result.events[-1].get("error") or "")
|
||||
assert result.failed is True
|
||||
assert "boom" in (result.error or "")
|
||||
|
||||
|
||||
def test_orchestrator_emits_stage_event_payload_shape() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
|
||||
|
||||
for event in result.events:
|
||||
assert "type" in event
|
||||
assert event.get("run_id") == "run-3"
|
||||
|
||||
stage_events = [
|
||||
event for event in result.events if event["type"] == "stage.completed"
|
||||
]
|
||||
assert [event["stage"] for event in stage_events] == [
|
||||
"intent",
|
||||
"execution",
|
||||
"organization",
|
||||
]
|
||||
@@ -1,23 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from v1.agent.service import build_session_title
|
||||
|
||||
|
||||
def test_build_session_title_truncates_first_message() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title(
|
||||
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
|
||||
)
|
||||
|
||||
assert len(title) == 24
|
||||
|
||||
|
||||
def test_build_session_title_falls_back_when_message_empty() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title("\n ", now=now)
|
||||
|
||||
assert title == "新对话 2026-02-25 10:30"
|
||||
@@ -1,132 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.crewai.template_loader import (
|
||||
load_crewai_template,
|
||||
load_tools_whitelist,
|
||||
validate_workflow_stages,
|
||||
)
|
||||
|
||||
|
||||
def _write(path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def _prepare_static_root(root: Path) -> Path:
|
||||
_write(
|
||||
root / "agents.yaml",
|
||||
"""
|
||||
intent:
|
||||
role: Intent Agent
|
||||
execution:
|
||||
role: Execution Agent
|
||||
organization:
|
||||
role: Organization Agent
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "tasks.yaml",
|
||||
"""
|
||||
intent:
|
||||
description: classify
|
||||
execution:
|
||||
description: run task
|
||||
organization:
|
||||
description: summarize
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "workflow.yaml",
|
||||
"""
|
||||
stages:
|
||||
- intent
|
||||
- execution
|
||||
- organization
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "tools.yaml",
|
||||
"""
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- doc_extract
|
||||
""".strip(),
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def test_load_crewai_template_success_when_all_files_valid(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
|
||||
template = load_crewai_template(static_root)
|
||||
|
||||
assert set(template.agents.keys()) == {"intent", "execution", "organization"}
|
||||
assert set(template.tasks.keys()) == {"intent", "execution", "organization"}
|
||||
assert template.workflow["stages"] == ["intent", "execution", "organization"]
|
||||
assert template.tools_whitelist == {"asr_fun_asr", "doc_extract"}
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_file_not_found_when_required_file_missing(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
(static_root / "tasks.yaml").unlink()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_value_error_when_workflow_stages_invalid(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
_write(
|
||||
static_root / "workflow.yaml",
|
||||
"""
|
||||
stages:
|
||||
- execution
|
||||
- intent
|
||||
- organization
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_tools_whitelist_from_tools_yaml(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
|
||||
whitelist = load_tools_whitelist(static_root)
|
||||
|
||||
assert whitelist == {"asr_fun_asr", "doc_extract"}
|
||||
|
||||
|
||||
def test_validate_workflow_stages_accepts_exact_intent_execution_organization() -> None:
|
||||
validate_workflow_stages(["intent", "execution", "organization"])
|
||||
|
||||
|
||||
def test_validate_workflow_stages_rejects_extra_or_missing_stage() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
validate_workflow_stages(["intent", "execution"])
|
||||
with pytest.raises(ValueError):
|
||||
validate_workflow_stages(["intent", "execution", "organization", "extra"])
|
||||
|
||||
|
||||
def test_load_tools_whitelist_rejects_non_string_item(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
_write(
|
||||
static_root / "tools.yaml",
|
||||
"""
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- 123
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_tools_whitelist(static_root)
|
||||
@@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from core.config.initial import init_data
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
|
||||
def test_llm_catalog_file_exists_and_has_required_fields() -> None:
|
||||
catalog_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "src"
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "database"
|
||||
/ "llm_catalog.yaml"
|
||||
)
|
||||
|
||||
catalog = init_data.load_llm_catalog(catalog_path)
|
||||
|
||||
assert len(catalog["factories"]) == 6
|
||||
assert len(catalog["llms"]) == 2
|
||||
assert set(catalog["factories"][0].keys()) == {"name", "request_url", "avatar"}
|
||||
assert set(catalog["llms"][0].keys()) == {"model_code", "factory_name"}
|
||||
|
||||
|
||||
def test_load_llm_catalog_raises_on_invalid_structure(tmp_path: Path) -> None:
|
||||
catalog_path = tmp_path / "llm_catalog.yaml"
|
||||
catalog_path.write_text(
|
||||
"""
|
||||
factories:
|
||||
- name: qwen
|
||||
llms:
|
||||
- model_code: qwen3.5-flash
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
init_data.load_llm_catalog(catalog_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_data_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
session_maker = async_sessionmaker(
|
||||
bind=engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
|
||||
|
||||
first = await init_data.initialize_data()
|
||||
second = await init_data.initialize_data()
|
||||
|
||||
assert first is True
|
||||
assert second is True
|
||||
|
||||
async with session_maker() as session:
|
||||
factory_count = await session.scalar(
|
||||
select(func.count()).select_from(LlmFactory)
|
||||
)
|
||||
llm_count = await session.scalar(select(func.count()).select_from(Llm))
|
||||
|
||||
assert factory_count == 6
|
||||
assert llm_count == 2
|
||||
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_data_rolls_back_on_invalid_factory_mapping(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
session_maker = async_sessionmaker(
|
||||
bind=engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
|
||||
monkeypatch.setattr(
|
||||
init_data,
|
||||
"load_llm_catalog",
|
||||
lambda *_: {
|
||||
"factories": [
|
||||
{
|
||||
"name": "qwen",
|
||||
"request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"avatar": "https://cdn.example.com/qwen.png",
|
||||
}
|
||||
],
|
||||
"llms": [
|
||||
{
|
||||
"model_code": "qwen3.5-flash",
|
||||
"factory_id": "missing_factory",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await init_data.initialize_data()
|
||||
|
||||
async with session_maker() as session:
|
||||
factory_count = await session.scalar(
|
||||
select(func.count()).select_from(LlmFactory)
|
||||
)
|
||||
llm_count = await session.scalar(select(func.count()).select_from(Llm))
|
||||
|
||||
assert factory_count == 0
|
||||
assert llm_count == 0
|
||||
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def test_user_agent_catalog_file_exists_and_has_required_fields() -> None:
|
||||
catalog_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "src"
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "database"
|
||||
/ "user_agent_catalog.yaml"
|
||||
)
|
||||
|
||||
assert catalog_path.exists(), f"Catalog file not found: {catalog_path}"
|
||||
|
||||
catalog = init_data.load_user_agent_catalog(catalog_path)
|
||||
|
||||
assert "agents" in catalog
|
||||
assert isinstance(catalog["agents"], list)
|
||||
assert len(catalog["agents"]) == 3
|
||||
|
||||
for agent in catalog["agents"]:
|
||||
assert "agent_type" in agent
|
||||
assert "llm_model_code" in agent
|
||||
assert "status" in agent
|
||||
assert "config" in agent
|
||||
assert isinstance(agent["config"], dict)
|
||||
|
||||
|
||||
def test_load_user_agent_catalog_raises_on_invalid_structure(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
catalog_path = tmp_path / "user_agent_catalog.yaml"
|
||||
catalog_path.write_text(
|
||||
"""
|
||||
agents:
|
||||
- agent_type: TEST
|
||||
llm_model_code: test-model
|
||||
status: ACTIVE
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user agent catalog"):
|
||||
init_data.load_user_agent_catalog(catalog_path)
|
||||
Reference in New Issue
Block a user