test(backend): 更新后端测试文件
This commit is contained in:
@@ -1,53 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import RunAgentInput
|
||||
from agentscope.message import Msg
|
||||
|
||||
from core.agentscope.runtime.runner import (
|
||||
AgentScopeRunner,
|
||||
StageExecutionResult,
|
||||
SystemAgentRuntimeConfig,
|
||||
)
|
||||
from core.agentscope.utils import safe_json_loads_with_repair
|
||||
from schemas.agent.runtime_models import (
|
||||
RouterAgentOutput,
|
||||
UiMode,
|
||||
WorkerAgentOutputRich,
|
||||
)
|
||||
from schemas.agent.system_agent import AgentType, SystemAgentLLMConfig
|
||||
from schemas.user.context import UserContext, parse_profile_settings
|
||||
|
||||
|
||||
class _FakePipeline:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[dict[str, object]] = []
|
||||
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
self.events.append({"session_id": session_id, "event": event})
|
||||
return "1-0"
|
||||
|
||||
|
||||
class _FakeSessionCtx:
|
||||
def __init__(self, session: object) -> None:
|
||||
self._session = session
|
||||
|
||||
async def __aenter__(self) -> object:
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
|
||||
def _user_context() -> UserContext:
|
||||
return UserContext(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
settings=parse_profile_settings(None),
|
||||
)
|
||||
from core.agentscope.runtime.runner import AgentScopeRunner
|
||||
from schemas.agent.system_agent import AgentType
|
||||
|
||||
|
||||
def _run_input() -> RunAgentInput:
|
||||
@@ -64,215 +21,19 @@ def _run_input() -> RunAgentInput:
|
||||
)
|
||||
|
||||
|
||||
def _router_output(*, ui_mode: UiMode) -> RouterAgentOutput:
|
||||
return RouterAgentOutput.model_validate(
|
||||
{
|
||||
"normalized_task_input": {
|
||||
"user_text": "hello",
|
||||
"multimodal_summary": [],
|
||||
},
|
||||
"key_entities": [],
|
||||
"constraints": [],
|
||||
"task_typing": {"primary": "knowledge", "secondary": []},
|
||||
"execution_mode": "onestep",
|
||||
"result_typing": {"primary": "direct_answer", "secondary": []},
|
||||
"ui": {
|
||||
"ui_mode": ui_mode.value,
|
||||
"ui_decision_reason": "need structure"
|
||||
if ui_mode == UiMode.RICH
|
||||
else "plain text",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_resolve_stage_agent_type_defaults_to_worker() -> None:
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("") == AgentType.WORKER
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("worker") == AgentType.WORKER
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("unknown") == AgentType.WORKER
|
||||
|
||||
|
||||
def test_build_worker_input_messages_includes_field_guide() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
|
||||
messages = runner._build_worker_input_messages(
|
||||
router_output=_router_output(ui_mode=UiMode.NONE)
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
content = str(messages[0].content)
|
||||
assert "[Worker Contract]" in content
|
||||
assert "Use normalized_task_input as objective text." in content
|
||||
assert "multimodal_summary/key_entities/constraints" in content
|
||||
assert "key_entities" in content
|
||||
assert "constraints" in content
|
||||
assert "Infer deterministic missing required tool args" in content
|
||||
|
||||
|
||||
def test_safe_json_loads_with_repair_parses_valid_json() -> None:
|
||||
parsed = safe_json_loads_with_repair('{"operation":"create","title":"test"}')
|
||||
|
||||
assert parsed["operation"] == "create"
|
||||
assert parsed["title"] == "test"
|
||||
def test_resolve_stage_agent_type_supports_memory() -> None:
|
||||
assert AgentScopeRunner._resolve_stage_agent_type("memory") == AgentType.MEMORY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_uses_router_ui_mode_to_select_worker_output_model(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def test_resolve_runtime_client_time_from_forwarded_props() -> None:
|
||||
runner = AgentScopeRunner()
|
||||
pipeline = _FakePipeline()
|
||||
worker_model_holder: dict[str, type[object]] = {}
|
||||
|
||||
class _CommitSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.runner.AsyncSessionLocal",
|
||||
lambda: _FakeSessionCtx(_CommitSession()),
|
||||
)
|
||||
|
||||
async def _load_system_agent_config(**kwargs):
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=kwargs["agent_type"],
|
||||
model_code="qwen3.5-flash"
|
||||
if kwargs["agent_type"] == AgentType.ROUTER
|
||||
else "deepseek-chat",
|
||||
api_base_url="https://example.com/v1",
|
||||
api_key="sk-test",
|
||||
llm_config=SystemAgentLLMConfig(
|
||||
temperature=0.1, max_tokens=256, timeout_seconds=30
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(runner, "_load_system_agent_config", _load_system_agent_config)
|
||||
|
||||
async def _run_router_stage(**kwargs):
|
||||
return StageExecutionResult(
|
||||
message=Msg(name="router", content="", role="assistant"),
|
||||
payload=_router_output(ui_mode=UiMode.RICH).model_dump(mode="json"),
|
||||
response_metadata={
|
||||
"model": "qwen3.5-flash",
|
||||
"inputTokens": 12,
|
||||
"outputTokens": 6,
|
||||
"cost": 0.001,
|
||||
"latencyMs": 50,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(runner, "_run_router_stage", _run_router_stage)
|
||||
|
||||
async def _persist_router_message(**kwargs) -> None:
|
||||
assert kwargs["model_code"] == "qwen3.5-flash"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.runner.persist_router_message",
|
||||
_persist_router_message,
|
||||
)
|
||||
|
||||
async def _run_worker_stage(**kwargs):
|
||||
worker_model_holder["model"] = kwargs["worker_output_model"]
|
||||
return StageExecutionResult(
|
||||
message=Msg(name="worker", content="done", role="assistant"),
|
||||
payload=WorkerAgentOutputRich.model_validate(
|
||||
{
|
||||
"status": "success",
|
||||
"answer": "done",
|
||||
"key_points": [],
|
||||
"result_type": "direct_answer",
|
||||
"suggested_actions": [],
|
||||
"error": None,
|
||||
"ui_hints": None,
|
||||
}
|
||||
).model_dump(mode="json", exclude_none=True),
|
||||
response_metadata={
|
||||
"model": "deepseek-chat",
|
||||
"inputTokens": 8,
|
||||
"outputTokens": 4,
|
||||
"cost": 0.002,
|
||||
"latencyMs": 40,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(runner, "_run_worker_stage", _run_worker_stage)
|
||||
|
||||
result = await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=pipeline,
|
||||
run_input=_run_input(),
|
||||
)
|
||||
|
||||
assert worker_model_holder["model"].__name__ == "WorkerAgentOutputRich"
|
||||
event_types = []
|
||||
for item in pipeline.events:
|
||||
event = item.get("event")
|
||||
if isinstance(event, dict):
|
||||
event_types.append(event.get("type"))
|
||||
assert event_types == [
|
||||
"STEP_STARTED",
|
||||
"STEP_FINISHED",
|
||||
"STEP_STARTED",
|
||||
"STEP_FINISHED",
|
||||
]
|
||||
assert result["router"]["ui"]["ui_mode"] == "rich"
|
||||
assert result["worker"]["answer"] == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_passes_runtime_client_time_to_router_and_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
runner = AgentScopeRunner()
|
||||
pipeline = _FakePipeline()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _CommitSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.runner.AsyncSessionLocal",
|
||||
lambda: _FakeSessionCtx(_CommitSession()),
|
||||
)
|
||||
|
||||
async def _load_system_agent_config(**kwargs):
|
||||
return SystemAgentRuntimeConfig(
|
||||
agent_type=kwargs["agent_type"],
|
||||
model_code="model-a",
|
||||
api_base_url="https://example.com/v1",
|
||||
api_key="sk-test",
|
||||
llm_config=SystemAgentLLMConfig(
|
||||
temperature=0.1, max_tokens=256, timeout_seconds=30
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(runner, "_load_system_agent_config", _load_system_agent_config)
|
||||
|
||||
async def _run_router_stage(**kwargs):
|
||||
captured["router_timezone"] = kwargs["runtime_client_time"].device_timezone
|
||||
return StageExecutionResult(
|
||||
message=Msg(name="router", content="", role="assistant"),
|
||||
payload=_router_output(ui_mode=UiMode.NONE).model_dump(mode="json"),
|
||||
response_metadata={},
|
||||
)
|
||||
|
||||
async def _run_worker_stage(**kwargs):
|
||||
captured["worker_timezone"] = kwargs["runtime_client_time"].device_timezone
|
||||
return StageExecutionResult(
|
||||
message=Msg(name="worker", content="ok", role="assistant"),
|
||||
payload={
|
||||
"status": "success",
|
||||
"answer": "ok",
|
||||
"key_points": [],
|
||||
"result_type": "direct_answer",
|
||||
"suggested_actions": [],
|
||||
"error": None,
|
||||
},
|
||||
response_metadata={},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(runner, "_run_router_stage", _run_router_stage)
|
||||
monkeypatch.setattr(runner, "_run_worker_stage", _run_worker_stage)
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.runner.persist_router_message", AsyncMock()
|
||||
)
|
||||
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "00000000-0000-0000-0000-000000000010",
|
||||
@@ -291,38 +52,6 @@ async def test_execute_passes_runtime_client_time_to_router_and_worker(
|
||||
}
|
||||
)
|
||||
|
||||
await runner.execute(
|
||||
user_context=_user_context(),
|
||||
context_messages=[],
|
||||
pipeline=pipeline,
|
||||
run_input=run_input,
|
||||
)
|
||||
|
||||
assert captured["router_timezone"] == "America/Los_Angeles"
|
||||
assert captured["worker_timezone"] == "America/Los_Angeles"
|
||||
|
||||
|
||||
def test_resolve_provider_api_key_maps_volcengine_to_ark(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.runner.config.llm.provider_keys",
|
||||
{"ARK": "ark-key", "DASHSCOPE": "dash-key"},
|
||||
)
|
||||
|
||||
assert (
|
||||
AgentScopeRunner._resolve_provider_api_key(factory_name="volcengine")
|
||||
== "ark-key"
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_provider_api_key_raises_when_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.agentscope.runtime.runner.config.llm.provider_keys",
|
||||
{"DASHSCOPE": "dash-key"},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="provider api key missing"):
|
||||
AgentScopeRunner._resolve_provider_api_key(factory_name="deepseek")
|
||||
resolved = runner._resolve_runtime_client_time(run_input=run_input)
|
||||
assert resolved is not None
|
||||
assert resolved.device_timezone == "America/Los_Angeles"
|
||||
|
||||
Reference in New Issue
Block a user