feat(agent): migrate to native CrewAI tool loop and async resume enqueue
This commit is contained in:
@@ -25,22 +25,39 @@ from models.system_agents import SystemAgents
|
||||
async def test_run_then_resume_persists_messages_and_session_state(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
def _fake_execute(self, *, user_input: str) -> dict[str, object]:
|
||||
del user_input
|
||||
return {
|
||||
"assistant_text": "Mocked answer",
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
"cost": 0.0025,
|
||||
"agui_events": [
|
||||
{"type": "TEXT_MESSAGE_START", "data": {"session_id": "__TBD__"}},
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"data": {"session_id": "__TBD__", "text": "Mocked answer"},
|
||||
call_count = {"n": 0}
|
||||
|
||||
def _fake_execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
) -> dict[str, object]:
|
||||
del self, user_input, system_prompt, tools
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return {
|
||||
"assistant_text": "请确认是否跳转。",
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
"cost": 0.0025,
|
||||
"pending_front_tool": {
|
||||
"name": "front.navigate_to_route",
|
||||
"args": {"target": "/calendar/dayweek", "replace": False},
|
||||
"target": "frontend",
|
||||
},
|
||||
{"type": "TEXT_MESSAGE_END", "data": {"session_id": "__TBD__"}},
|
||||
],
|
||||
"agui_events": [],
|
||||
}
|
||||
return {
|
||||
"assistant_text": "已继续执行并完成。",
|
||||
"prompt_tokens": 3,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 5,
|
||||
"cost": 0.001,
|
||||
"pending_front_tool": None,
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
@@ -85,12 +102,17 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
await seed_session.commit()
|
||||
|
||||
published: list[str] = []
|
||||
queued_commands: list[dict[str, object]] = []
|
||||
|
||||
async def _publish(event: dict[str, object]) -> None:
|
||||
event_type = event.get("type")
|
||||
if isinstance(event_type, str):
|
||||
published.append(event_type)
|
||||
|
||||
async def _enqueue(command: dict[str, object]) -> str:
|
||||
queued_commands.append(command)
|
||||
return "task-followup-1"
|
||||
|
||||
try:
|
||||
run_input_payload = {
|
||||
"threadId": str(session_uuid),
|
||||
@@ -101,7 +123,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "navigate_to_route",
|
||||
"name": "front.navigate_to_route",
|
||||
"description": "navigate route",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
@@ -115,6 +137,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
"run_input": run_input_payload,
|
||||
},
|
||||
publish_event=_publish,
|
||||
enqueue_command=_enqueue,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
@@ -138,7 +161,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
"toolCallId": pending_tool_call_id,
|
||||
"content": json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolName": "front.navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
@@ -158,6 +181,16 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
},
|
||||
},
|
||||
publish_event=_publish,
|
||||
enqueue_command=_enqueue,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
|
||||
assert len(queued_commands) == 1
|
||||
await run_agent_task(
|
||||
queued_commands[0],
|
||||
publish_event=_publish,
|
||||
enqueue_command=_enqueue,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
@@ -168,8 +201,8 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
assert db_session is not None
|
||||
assert db_session.status == AgentChatSessionStatus.COMPLETED
|
||||
assert db_session.message_count == 4
|
||||
assert db_session.total_tokens == 18
|
||||
assert db_session.total_cost == Decimal("0.002500")
|
||||
assert db_session.total_tokens == 23
|
||||
assert db_session.total_cost == Decimal("0.003500")
|
||||
assert db_session.state_snapshot == {
|
||||
"status": "completed",
|
||||
"pending_tool_call_id": None,
|
||||
@@ -193,6 +226,7 @@ async def test_run_then_resume_persists_messages_and_session_state(
|
||||
assert messages[1].input_tokens == 11
|
||||
assert messages[1].output_tokens == 7
|
||||
assert messages[1].cost == Decimal("0.002500")
|
||||
assert messages[3].content == "已继续执行并完成。"
|
||||
|
||||
assert "RUN_STARTED" in published
|
||||
assert "RUN_FINISHED" in published
|
||||
|
||||
@@ -134,7 +134,7 @@ def test_patch_me_validation_error_returns_problem_details() -> None:
|
||||
assert response.status_code == 422
|
||||
assert response.headers["content-type"].startswith("application/problem+json")
|
||||
body = response.json()
|
||||
assert body["title"] == "Unprocessable Content"
|
||||
assert body["title"] in {"Unprocessable Content", "Unprocessable Entity"}
|
||||
assert body["status"] == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -17,9 +17,7 @@ class _FakeAgentService:
|
||||
def __init__(self) -> None:
|
||||
self._stream_called = False
|
||||
|
||||
async def enqueue_run(
|
||||
self, *, run_input: RunAgentInput, current_user: CurrentUser
|
||||
):
|
||||
async def enqueue_run(self, *, run_input: RunAgentInput, current_user: CurrentUser):
|
||||
del current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-run-1",
|
||||
@@ -287,3 +285,64 @@ def test_run_rejects_oversized_user_text_payload() -> None:
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_rejects_client_supplied_history_messages() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-history",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{"id": "a1", "role": "assistant", "content": "old"},
|
||||
{"id": "u1", "role": "user", "content": "new"},
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_resume_accepts_tool_message_without_user_message() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/v1/agent/runs/00000000-0000-0000-0000-000000000001/resume",
|
||||
json={
|
||||
"threadId": "00000000-0000-0000-0000-000000000001",
|
||||
"runId": "run-resume-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"role": "tool",
|
||||
"toolCallId": "call-1",
|
||||
"content": '{"toolName":"navigate_to_route","toolArgs":{"target":"/calendar/dayweek"},"nonce":"n1","result":{"ok":true}}',
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
assert response.json()["taskId"] == "task-resume-1"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -1,551 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from types import MethodType, SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.infrastructure.config.resolver import AgentConfigResolver, SettingsLike
|
||||
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
|
||||
from core.agent.infrastructure.litellm.usage_tracker import UsageCost
|
||||
|
||||
|
||||
def test_runtime_emits_text_tool_reasoning_events() -> None:
|
||||
def _build_runtime() -> CrewAIRuntime:
|
||||
settings = cast(
|
||||
SettingsLike,
|
||||
SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
default_model_code="", streaming_enabled=True
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
),
|
||||
)
|
||||
runtime = CrewAIRuntime(
|
||||
return CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(settings=settings),
|
||||
model_code="gpt-4o-mini",
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_maps_agui_events() -> None:
|
||||
runtime = _build_runtime()
|
||||
events = runtime.map_events(
|
||||
[
|
||||
{"type": "textMessageContent", "data": {"text": "hello"}},
|
||||
{"type": "toolCallStart", "data": {"tool_name": "weather"}},
|
||||
{"type": "toolCallResult", "data": {"ok": True}},
|
||||
{"type": "reasoningMessageContent", "data": {"text": "thinking"}},
|
||||
{"type": "runFinished", "data": {"status": "completed"}},
|
||||
]
|
||||
)
|
||||
|
||||
assert [event["type"] for event in events] == [
|
||||
"TEXT_MESSAGE_CONTENT",
|
||||
"TOOL_CALL_START",
|
||||
"TOOL_CALL_RESULT",
|
||||
"REASONING_MESSAGE_CONTENT",
|
||||
"RUN_FINISHED",
|
||||
]
|
||||
|
||||
|
||||
def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
def test_runtime_direct_execution_short_circuit() -> None:
|
||||
runtime = _build_runtime()
|
||||
|
||||
def _fake_completion(
|
||||
*,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["model"] = model
|
||||
captured["api_key"] = api_key
|
||||
captured["messages"] = messages
|
||||
captured["temperature"] = temperature
|
||||
captured["max_tokens"] = max_tokens
|
||||
captured["timeout"] = timeout
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"route":"DIRECT_EXECUTION","intent_summary":"greet",'
|
||||
'"assistant_text":"hello","safety_flags":[]}'
|
||||
),
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
}
|
||||
def _fake_run_stage(self, **kwargs):
|
||||
stage = kwargs["stage"]
|
||||
if stage == "intent":
|
||||
return (
|
||||
'{"route":"DIRECT_EXECUTION","intent_summary":"greet","assistant_text":"hello","safety_flags":[]}',
|
||||
UsageCost(1, 2, 3, 0.01),
|
||||
[],
|
||||
None,
|
||||
)
|
||||
raise AssertionError("unexpected stage")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: SimpleNamespace(
|
||||
prompt_tokens=1,
|
||||
completion_tokens=2,
|
||||
total_tokens=3,
|
||||
cost=0.001,
|
||||
),
|
||||
)
|
||||
settings = cast(
|
||||
SettingsLike,
|
||||
SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
),
|
||||
)
|
||||
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(settings=settings),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
llm_config=SystemAgentLLMConfig(temperature=0.3, max_tokens=256),
|
||||
)
|
||||
|
||||
result = runtime.execute(user_input="hi")
|
||||
|
||||
assert captured["model"] == "dashscope/qwen3.5-flash"
|
||||
assert captured["api_key"] == "env-api-key"
|
||||
assert captured["temperature"] == 0.3
|
||||
assert captured["max_tokens"] == 256
|
||||
assert captured["timeout"] == 30.0
|
||||
runtime._run_stage_with_crewai = MethodType(_fake_run_stage, runtime) # type: ignore[method-assign]
|
||||
result = runtime.execute(user_input="hi", tools=[])
|
||||
assert result["assistant_text"] == "hello"
|
||||
assert result["pending_front_tool"] is None
|
||||
assert result["total_tokens"] == 3
|
||||
|
||||
|
||||
def test_runtime_execute_injects_system_prompt_and_intent_template(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
def test_runtime_needs_execution_and_collects_front_tool_call() -> None:
|
||||
runtime = _build_runtime()
|
||||
calls: list[dict[str, object]] = []
|
||||
|
||||
def _fake_completion(
|
||||
*,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
captured["messages"] = messages
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"route":"DIRECT_EXECUTION","intent_summary":"greet",'
|
||||
'"assistant_text":"ok","safety_flags":[]}'
|
||||
),
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: SimpleNamespace(
|
||||
prompt_tokens=1,
|
||||
completion_tokens=1,
|
||||
total_tokens=2,
|
||||
cost=0.001,
|
||||
),
|
||||
)
|
||||
settings = cast(
|
||||
SettingsLike,
|
||||
SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
),
|
||||
)
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(settings=settings),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
runtime.execute(user_input="hello", system_prompt="USER_PROFILE_BLOCK")
|
||||
|
||||
messages = captured["messages"]
|
||||
assert isinstance(messages, list)
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "USER_PROFILE_BLOCK" in str(messages[0]["content"])
|
||||
assert "Intent Agent" in str(messages[0]["content"])
|
||||
assert messages[1] == {"role": "user", "content": "hello"}
|
||||
|
||||
|
||||
def test_runtime_execute_short_circuits_on_direct_execution(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
calls: list[list[dict[str, object]]] = []
|
||||
|
||||
responses = [
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"route":"DIRECT_EXECUTION","intent_summary":"greet",'
|
||||
'"assistant_text":"hello direct","safety_flags":[]}'
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
}
|
||||
]
|
||||
usage_values = [
|
||||
SimpleNamespace(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=3,
|
||||
total_tokens=5,
|
||||
cost=0.01,
|
||||
def _fake_run_stage(self, **kwargs):
|
||||
calls.append(
|
||||
{
|
||||
"stage": kwargs["stage"],
|
||||
"tools": kwargs["tools_payload"],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def _fake_completion(
|
||||
*,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
return responses.pop(0)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: usage_values.pop(0),
|
||||
)
|
||||
settings = cast(
|
||||
SettingsLike,
|
||||
SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
),
|
||||
)
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(settings=settings),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
result = runtime.execute(user_input="hello", system_prompt="USER_PROFILE_BLOCK")
|
||||
|
||||
assert len(calls) == 1
|
||||
assert result["assistant_text"] == "hello direct"
|
||||
assert result["prompt_tokens"] == 2
|
||||
assert result["completion_tokens"] == 3
|
||||
assert result["total_tokens"] == 5
|
||||
assert result["cost"] == 0.01
|
||||
|
||||
|
||||
def test_runtime_execute_runs_execution_and_organization_stages(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
calls: list[list[dict[str, object]]] = []
|
||||
responses = [
|
||||
{
|
||||
"choices": [
|
||||
stage = kwargs["stage"]
|
||||
if stage == "intent":
|
||||
return (
|
||||
'{"route":"NEEDS_EXECUTION","intent_summary":"need tool","execution_brief":"do it","safety_flags":[]}',
|
||||
UsageCost(1, 1, 2, 0.01),
|
||||
[],
|
||||
None,
|
||||
)
|
||||
if stage == "execution":
|
||||
return (
|
||||
'{"status":"SUCCESS","execution_summary":"done","execution_data":{},"report_brief":"ok"}',
|
||||
UsageCost(2, 2, 4, 0.02),
|
||||
[],
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"route":"NEEDS_EXECUTION","intent_summary":"need tools",'
|
||||
'"execution_brief":"fetch data","safety_flags":[]}'
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"status":"SUCCESS","execution_summary":"done",'
|
||||
'"execution_data":{"k":"v"},"report_brief":"brief"}'
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"assistant_text":"final answer",'
|
||||
'"response_metadata":{"source":"organization"}}'
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
]
|
||||
usage_values = [
|
||||
SimpleNamespace(
|
||||
prompt_tokens=1,
|
||||
completion_tokens=1,
|
||||
total_tokens=2,
|
||||
cost=0.01,
|
||||
),
|
||||
SimpleNamespace(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=2,
|
||||
total_tokens=4,
|
||||
cost=0.02,
|
||||
),
|
||||
SimpleNamespace(
|
||||
prompt_tokens=3,
|
||||
completion_tokens=3,
|
||||
total_tokens=6,
|
||||
cost=0.03,
|
||||
),
|
||||
]
|
||||
"name": "front.navigate_to_route",
|
||||
"args": {"target": "/calendar/dayweek"},
|
||||
"target": "frontend",
|
||||
},
|
||||
)
|
||||
return (
|
||||
'{"assistant_text":"final answer","response_metadata":{"source":"organization"}}',
|
||||
UsageCost(3, 3, 6, 0.03),
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
def _fake_completion(
|
||||
*,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
return responses.pop(0)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: usage_values.pop(0),
|
||||
)
|
||||
settings = cast(
|
||||
SettingsLike,
|
||||
SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
),
|
||||
)
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(settings=settings),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
runtime._run_stage_with_crewai = MethodType(_fake_run_stage, runtime) # type: ignore[method-assign]
|
||||
result = runtime.execute(
|
||||
user_input="go",
|
||||
tools=[
|
||||
{
|
||||
"name": "front.navigate_to_route",
|
||||
"description": "navigate",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
result = runtime.execute(user_input="hello", system_prompt="USER_PROFILE_BLOCK")
|
||||
|
||||
assert len(calls) == 3
|
||||
assert "Intent Agent" in str(calls[0][0]["content"])
|
||||
assert "Execution Agent" in str(calls[1][0]["content"])
|
||||
assert "Organization Agent" in str(calls[2][0]["content"])
|
||||
assert result["assistant_text"] == "final answer"
|
||||
assert result["prompt_tokens"] == 6
|
||||
assert result["completion_tokens"] == 6
|
||||
assert result["total_tokens"] == 12
|
||||
assert result["cost"] == 0.06
|
||||
assert [item["stage"] for item in calls] == ["intent", "execution"]
|
||||
for item in calls:
|
||||
tools = item["tools"]
|
||||
assert isinstance(tools, list)
|
||||
assert any(t.get("name") == "front.navigate_to_route" for t in tools)
|
||||
execution_tools = calls[1]["tools"]
|
||||
assert any(t.get("name") == "back.create_calendar_event" for t in execution_tools)
|
||||
assert result["assistant_text"] == "do it"
|
||||
assert result["pending_front_tool"] == {
|
||||
"name": "front.navigate_to_route",
|
||||
"args": {"target": "/calendar/dayweek"},
|
||||
"target": "frontend",
|
||||
}
|
||||
assert result["total_tokens"] == 6
|
||||
|
||||
|
||||
def test_runtime_execute_rejects_invalid_intent_json(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
def _fake_completion(
|
||||
*,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, messages, temperature, max_tokens
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "not-json",
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: SimpleNamespace(
|
||||
prompt_tokens=1,
|
||||
completion_tokens=1,
|
||||
total_tokens=2,
|
||||
cost=0.01,
|
||||
),
|
||||
)
|
||||
settings = cast(
|
||||
SettingsLike,
|
||||
SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
),
|
||||
)
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(settings=settings),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
try:
|
||||
runtime.execute(user_input="hello", system_prompt="USER_PROFILE_BLOCK")
|
||||
raise AssertionError("expected ValueError")
|
||||
except ValueError as exc:
|
||||
assert "invalid intent stage output" in str(exc)
|
||||
|
||||
|
||||
def test_runtime_execute_minimizes_prompt_and_execution_payload(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
calls: list[list[dict[str, object]]] = []
|
||||
responses = [
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"route":"NEEDS_EXECUTION","intent_summary":"need tools",'
|
||||
'"execution_brief":"fetch data","safety_flags":[]}'
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"status":"SUCCESS","execution_summary":"done",'
|
||||
'"execution_data":{"secret":"secret_value"},'
|
||||
'"report_brief":"brief"}'
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": (
|
||||
'{"assistant_text":"final answer",'
|
||||
'"response_metadata":{"source":"organization"}}'
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
]
|
||||
usage_values = [
|
||||
SimpleNamespace(
|
||||
prompt_tokens=1,
|
||||
completion_tokens=1,
|
||||
total_tokens=2,
|
||||
cost=0.01,
|
||||
),
|
||||
SimpleNamespace(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=2,
|
||||
total_tokens=4,
|
||||
cost=0.02,
|
||||
),
|
||||
SimpleNamespace(
|
||||
prompt_tokens=3,
|
||||
completion_tokens=3,
|
||||
total_tokens=6,
|
||||
cost=0.03,
|
||||
),
|
||||
]
|
||||
|
||||
def _fake_completion(
|
||||
*,
|
||||
model: str,
|
||||
api_key: str,
|
||||
messages: list[dict[str, object]],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
del model, api_key, temperature, max_tokens
|
||||
calls.append(messages)
|
||||
return responses.pop(0)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: usage_values.pop(0),
|
||||
)
|
||||
settings = cast(
|
||||
SettingsLike,
|
||||
SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
),
|
||||
)
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(settings=settings),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
runtime.execute(user_input="hello", system_prompt="USER_PROFILE_BLOCK")
|
||||
|
||||
assert "USER_PROFILE_BLOCK" in str(calls[0][0]["content"])
|
||||
assert "USER_PROFILE_BLOCK" not in str(calls[1][0]["content"])
|
||||
assert "USER_PROFILE_BLOCK" not in str(calls[2][0]["content"])
|
||||
assert "secret_value" not in str(calls[2][1]["content"])
|
||||
def test_runtime_backend_registry_check() -> None:
|
||||
runtime = _build_runtime()
|
||||
assert runtime.is_registered_backend_tool("back.create_calendar_event") is True
|
||||
assert runtime.is_registered_backend_tool("back.unknown") is False
|
||||
|
||||
@@ -9,6 +9,7 @@ from ag_ui.core import RunAgentInput
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.domain.agui_input import validate_run_request_messages_contract
|
||||
from core.agent.domain.system_agent_config import SystemAgentLLMConfig
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
@@ -92,8 +93,12 @@ def _build_resume_input(
|
||||
if payload is None:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolArgs": {"target": "/calendar/dayweek", "replace": False, "__nonce": "nonce-1"},
|
||||
"toolName": "front.navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
"__nonce": "nonce-1",
|
||||
},
|
||||
"nonce": "nonce-1",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
@@ -178,7 +183,7 @@ async def test_resume_service_validates_pending_tool_guard_and_persists_payload(
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_name": "front.navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
@@ -217,7 +222,7 @@ async def test_resume_service_validates_pending_tool_guard_and_persists_payload(
|
||||
|
||||
assert captured[0]["role"] == AgentChatMessageRole.TOOL
|
||||
stored_payload = json.loads(captured[0]["content"])
|
||||
assert stored_payload["toolName"] == "navigate_to_route"
|
||||
assert stored_payload["toolName"] == "front.navigate_to_route"
|
||||
assert stored_payload["result"]["ok"] is True
|
||||
assert stored_payload["result"]["applied"] is True
|
||||
assert "ui" not in stored_payload
|
||||
@@ -259,7 +264,7 @@ async def test_resume_service_rejects_mismatched_nonce(
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_name": "front.navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
@@ -296,7 +301,7 @@ async def test_resume_service_rejects_mismatched_nonce(
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolName": "front.navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
@@ -348,7 +353,7 @@ async def test_resume_service_rejects_tool_result_when_not_ok(
|
||||
total_cost=0,
|
||||
state_snapshot={
|
||||
"pending_tool_call_id": "call-1",
|
||||
"pending_tool_name": "navigate_to_route",
|
||||
"pending_tool_name": "front.navigate_to_route",
|
||||
"pending_tool_args_sha256": "c8e6573e6ce79d9a8052d5167e18602841fa4248f3e8a2efb448c4bfbd298a12",
|
||||
"pending_tool_nonce": "nonce-1",
|
||||
},
|
||||
@@ -385,7 +390,7 @@ async def test_resume_service_rejects_tool_result_when_not_ok(
|
||||
tool_call_id="call-1",
|
||||
content=json.dumps(
|
||||
{
|
||||
"toolName": "navigate_to_route",
|
||||
"toolName": "front.navigate_to_route",
|
||||
"toolArgs": {
|
||||
"target": "/calendar/dayweek",
|
||||
"replace": False,
|
||||
@@ -524,9 +529,36 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
async def execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
owner_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
):
|
||||
del session, owner_id
|
||||
assert tool_name == "back.create_calendar_event"
|
||||
assert "title" in tool_args
|
||||
return {
|
||||
"result": {"eventId": "evt-1", "ok": True},
|
||||
"ui": {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {"id": "evt-1", "title": "会议"},
|
||||
},
|
||||
}
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
):
|
||||
captured["user_input"] = user_input
|
||||
captured["system_prompt"] = system_prompt
|
||||
captured["tools"] = tools
|
||||
return {
|
||||
"assistant_text": "Mocked answer",
|
||||
"prompt_tokens": 2,
|
||||
@@ -556,6 +588,7 @@ async def test_run_service_passes_user_context_system_prompt_to_runtime(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
@@ -646,8 +679,37 @@ async def test_run_service_emits_frontend_tool_pending_events(
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
def is_registered_backend_tool(self, tool_name: str) -> bool:
|
||||
return tool_name == "back.create_calendar_event"
|
||||
|
||||
async def execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
owner_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
):
|
||||
del session, owner_id
|
||||
assert tool_name == "back.create_calendar_event"
|
||||
assert "title" in tool_args
|
||||
return {
|
||||
"result": {"eventId": "evt-1", "ok": True},
|
||||
"ui": {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {"id": "evt-1", "title": "会议"},
|
||||
},
|
||||
}
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
):
|
||||
del user_input, system_prompt, tools
|
||||
return {
|
||||
"assistant_text": "请确认是否跳转。",
|
||||
"prompt_tokens": 1,
|
||||
@@ -655,6 +717,11 @@ async def test_run_service_emits_frontend_tool_pending_events(
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
"pending_front_tool": {
|
||||
"name": "front.navigate_to_route",
|
||||
"args": {"target": "/calendar/dayweek", "replace": False},
|
||||
"target": "frontend",
|
||||
},
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
@@ -702,10 +769,10 @@ async def test_run_service_emits_frontend_tool_pending_events(
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text="帮我打开日历",
|
||||
text="请帮我处理这个请求",
|
||||
tools=[
|
||||
{
|
||||
"name": "navigate_to_route",
|
||||
"name": "front.navigate_to_route",
|
||||
"description": "navigate",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
@@ -714,14 +781,16 @@ async def test_run_service_emits_frontend_tool_pending_events(
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is not None
|
||||
tool_start = next(event for event in result["events"] if event["type"] == "TOOL_CALL_START")
|
||||
assert tool_start["toolCallName"] == "navigate_to_route"
|
||||
tool_start = next(
|
||||
event for event in result["events"] if event["type"] == "TOOL_CALL_START"
|
||||
)
|
||||
assert tool_start["toolCallName"] == "front.navigate_to_route"
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.RUNNING
|
||||
snapshot = runtime_state["state_snapshot"]
|
||||
assert isinstance(snapshot, dict)
|
||||
assert snapshot["pending_tool_name"] == "navigate_to_route"
|
||||
assert snapshot["pending_tool_name"] == "front.navigate_to_route"
|
||||
assert isinstance(snapshot["pending_tool_args_sha256"], str)
|
||||
assert isinstance(snapshot["pending_tool_nonce"], str)
|
||||
|
||||
@@ -779,8 +848,37 @@ async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
del user_input, system_prompt
|
||||
def is_registered_backend_tool(self, tool_name: str) -> bool:
|
||||
return tool_name == "back.create_calendar_event"
|
||||
|
||||
async def execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
owner_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
):
|
||||
del session, owner_id
|
||||
assert tool_name == "back.create_calendar_event"
|
||||
assert "title" in tool_args
|
||||
return {
|
||||
"result": {"eventId": "evt-1", "ok": True},
|
||||
"ui": {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {"id": "evt-1", "title": "会议"},
|
||||
},
|
||||
}
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
):
|
||||
del user_input, system_prompt, tools
|
||||
return {
|
||||
"assistant_text": "日历事件已创建。",
|
||||
"prompt_tokens": 1,
|
||||
@@ -810,26 +908,6 @@ async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_execute_backend_tool(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
owner_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
):
|
||||
del self, session, owner_id
|
||||
assert tool_name == "create_calendar_event"
|
||||
assert "title" in tool_args
|
||||
return {
|
||||
"result": {"eventId": "evt-1", "ok": True},
|
||||
"ui": {
|
||||
"type": "calendar_card.v1",
|
||||
"version": "v1",
|
||||
"data": {"id": "evt-1", "title": "会议"},
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
@@ -850,19 +928,14 @@ async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._execute_backend_tool",
|
||||
_fake_execute_backend_tool,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
result = await service.run(
|
||||
run_input=_build_run_input(
|
||||
thread_id=str(session_id),
|
||||
text='#tool:create_calendar_event {"title":"会议","startAt":"2026-03-07T08:00:00Z"}',
|
||||
text="请安排一个明早会议",
|
||||
tools=[
|
||||
{
|
||||
"name": "create_calendar_event",
|
||||
"name": "back.create_calendar_event",
|
||||
"description": "create calendar",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
@@ -871,7 +944,7 @@ async def test_run_service_executes_backend_calendar_tool_and_emits_result(
|
||||
)
|
||||
|
||||
assert result["pending_tool_call_id"] is None
|
||||
assert any(event["type"] == "TOOL_CALL_RESULT" for event in result["events"])
|
||||
assert all(event["type"] != "TOOL_CALL_RESULT" for event in result["events"])
|
||||
runtime_state = captured["update_runtime_state"]
|
||||
assert isinstance(runtime_state, dict)
|
||||
assert runtime_state["status"] == AgentChatSessionStatus.COMPLETED
|
||||
@@ -929,7 +1002,9 @@ async def test_load_user_agent_context_defaults_when_profile_missing() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_user_agent_context_defaults_when_profile_settings_not_dict() -> None:
|
||||
async def test_load_user_agent_context_defaults_when_profile_settings_not_dict() -> (
|
||||
None
|
||||
):
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
profile = SimpleNamespace(
|
||||
@@ -952,7 +1027,9 @@ async def test_load_user_agent_context_defaults_when_profile_settings_not_dict()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_user_agent_context_falls_back_for_invalid_profile_settings() -> None:
|
||||
async def test_load_user_agent_context_falls_back_for_invalid_profile_settings() -> (
|
||||
None
|
||||
):
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
profile = SimpleNamespace(
|
||||
@@ -1093,9 +1170,16 @@ async def test_run_service_still_executes_when_profile_missing(
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(self, *, user_input: str, system_prompt: str | None = None):
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
):
|
||||
captured["user_input"] = user_input
|
||||
captured["system_prompt"] = system_prompt
|
||||
captured["tools"] = tools
|
||||
return {
|
||||
"assistant_text": "Mocked answer",
|
||||
"prompt_tokens": 2,
|
||||
@@ -1138,3 +1222,222 @@ async def test_run_service_still_executes_when_profile_missing(
|
||||
payload = json.loads(system_prompt.split("# USER_PROFILE (JSON)\n", maxsplit=1)[1])
|
||||
assert payload["username"] == ""
|
||||
assert payload["ai_language"] == "zh-CN"
|
||||
|
||||
|
||||
def test_validate_run_request_messages_contract_allows_single_user_multiblock() -> None:
|
||||
run_input = RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": str(uuid4()),
|
||||
"runId": "run-multiblock",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "请分析"},
|
||||
{"type": "text", "text": " 这张图"},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
)
|
||||
|
||||
validate_run_request_messages_contract(run_input)
|
||||
|
||||
|
||||
def test_compose_runtime_user_input_includes_history_context() -> None:
|
||||
service = RunService()
|
||||
|
||||
composed = service._compose_runtime_user_input(
|
||||
user_input="帮我创建会议",
|
||||
history_context="user: 之前消息\nassistant: 之前回复",
|
||||
)
|
||||
|
||||
assert "Server history context (today and previous day):" in composed
|
||||
assert "user: 之前消息" in composed
|
||||
assert "Current user input:" in composed
|
||||
assert composed.endswith("帮我创建会议")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_context_cache_hit_and_mismatch(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
|
||||
class _FakeRedisClient:
|
||||
def __init__(self) -> None:
|
||||
self.payload = json.dumps(
|
||||
{
|
||||
"message_count": 3,
|
||||
"context": "user: hi\nassistant: hello",
|
||||
},
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
|
||||
async def get(self, key: str) -> str:
|
||||
del key
|
||||
return self.payload
|
||||
|
||||
async def _fake_get_or_init_redis_client():
|
||||
return _FakeRedisClient()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.get_or_init_redis_client",
|
||||
_fake_get_or_init_redis_client,
|
||||
)
|
||||
|
||||
service = RunService()
|
||||
hit = await service._read_history_context_cache(
|
||||
session_id=session_id,
|
||||
expected_message_count=3,
|
||||
)
|
||||
miss = await service._read_history_context_cache(
|
||||
session_id=session_id,
|
||||
expected_message_count=4,
|
||||
)
|
||||
|
||||
assert hit == "user: hi\nassistant: hello"
|
||||
assert miss is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_passes_server_history_context_into_runtime(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_id = uuid4()
|
||||
user_id = uuid4()
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDbSession:
|
||||
async def commit(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __call__(self) -> "_FakeSessionFactory":
|
||||
return self
|
||||
|
||||
async def __aenter__(self) -> _FakeDbSession:
|
||||
return _FakeDbSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
del exc_type, exc, tb
|
||||
return False
|
||||
|
||||
class _FakeSessionRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def lock_session_for_update(self, *, session_id: object):
|
||||
return SimpleNamespace(
|
||||
id=session_id,
|
||||
user_id=user_id,
|
||||
status=AgentChatSessionStatus.PENDING,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=0,
|
||||
state_snapshot=None,
|
||||
)
|
||||
|
||||
async def next_message_seq(self, *, session_id: object):
|
||||
del session_id
|
||||
return 1
|
||||
|
||||
async def update_runtime_state(self, **kwargs) -> None:
|
||||
captured["update_runtime_state"] = kwargs
|
||||
|
||||
class _FakeMessageRepository:
|
||||
def __init__(self, session: object) -> None:
|
||||
del session
|
||||
|
||||
async def append_message(self, **kwargs) -> None:
|
||||
captured.setdefault("messages", []).append(kwargs)
|
||||
|
||||
class _FakeRuntime:
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
user_input: str,
|
||||
system_prompt: str | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
):
|
||||
captured["user_input"] = user_input
|
||||
del system_prompt, tools
|
||||
return {
|
||||
"assistant_text": "ok",
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
"cost": "0.001",
|
||||
"agui_events": [],
|
||||
}
|
||||
|
||||
async def _fake_load_agent_model_selection(self, _session):
|
||||
del self
|
||||
return ("qwen3.5-flash", "dashscope", SystemAgentLLMConfig())
|
||||
|
||||
async def _fake_load_user_agent_context(self, session, session_id, user_id):
|
||||
del self, session, session_id
|
||||
return SimpleNamespace(
|
||||
user_id=user_id,
|
||||
username="demo-user",
|
||||
bio=None,
|
||||
settings=SimpleNamespace(
|
||||
preferences=SimpleNamespace(
|
||||
interface_language="zh-CN",
|
||||
ai_language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
country="CN",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async def _fake_load_recent_history_context(
|
||||
self,
|
||||
session,
|
||||
session_id,
|
||||
expected_message_count,
|
||||
):
|
||||
del self, session, session_id, expected_message_count
|
||||
return "user: 昨天内容\nassistant: 昨天回复"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.SessionRepository",
|
||||
_FakeSessionRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.MessageRepository",
|
||||
_FakeMessageRepository,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.create_runtime",
|
||||
lambda **_kwargs: _FakeRuntime(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_agent_model_selection",
|
||||
_fake_load_agent_model_selection,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_user_agent_context",
|
||||
_fake_load_user_agent_context,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.application.run_service.RunService._load_recent_history_context",
|
||||
_fake_load_recent_history_context,
|
||||
)
|
||||
|
||||
service = RunService(session_factory=_FakeSessionFactory()) # type: ignore[arg-type]
|
||||
await service.run(
|
||||
run_input=_build_run_input(thread_id=str(session_id), text="今天问题")
|
||||
)
|
||||
|
||||
sent_input = captured["user_input"]
|
||||
assert isinstance(sent_input, str)
|
||||
assert "Server history context (today and previous day):" in sent_input
|
||||
assert "user: 昨天内容" in sent_input
|
||||
assert sent_input.endswith("今天问题")
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.agent.infrastructure.persistence.user_context_cache import UserContext
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, dict[str, str]] = {}
|
||||
self.set_store: dict[str, set[str]] = {}
|
||||
self.expire_calls: list[tuple[str, int]] = []
|
||||
self.delete_calls: list[str] = []
|
||||
self.hincrby_calls: list[tuple[str, str, int]] = []
|
||||
@@ -34,10 +35,22 @@ class _FakeRedis:
|
||||
self.expire_calls.append((key, seconds))
|
||||
return 1
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
self.delete_calls.append(key)
|
||||
self.store.pop(key, None)
|
||||
return 1
|
||||
async def delete(self, *keys: str) -> int:
|
||||
for key in keys:
|
||||
self.delete_calls.append(key)
|
||||
self.store.pop(key, None)
|
||||
self.set_store.pop(key, None)
|
||||
return len(keys)
|
||||
|
||||
async def sadd(self, key: str, *values: str) -> int:
|
||||
bucket = self.set_store.setdefault(key, set())
|
||||
before = len(bucket)
|
||||
for value in values:
|
||||
bucket.add(value)
|
||||
return len(bucket) - before
|
||||
|
||||
async def smembers(self, key: str) -> set[str]:
|
||||
return set(self.set_store.get(key, set()))
|
||||
|
||||
|
||||
class _BrokenRedis:
|
||||
@@ -57,7 +70,15 @@ class _BrokenRedis:
|
||||
del key, seconds
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
async def delete(self, *keys: str) -> int:
|
||||
del keys
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
async def sadd(self, key: str, *values: str) -> int:
|
||||
del key, values
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
async def smembers(self, key: str) -> set[str]:
|
||||
del key
|
||||
raise RuntimeError("redis down")
|
||||
|
||||
@@ -89,12 +110,39 @@ async def test_user_context_cache_set_and_get_hit() -> None:
|
||||
assert loaded is not None
|
||||
assert loaded.user_id == context.user_id
|
||||
assert loaded.username == "demo-user"
|
||||
assert redis.expire_calls == [(f"agent:user-context:{session_id}", 600)]
|
||||
assert redis.expire_calls == [
|
||||
(f"agent:user-context:{session_id}", 600),
|
||||
(f"agent:user-context:sessions:{context.user_id}", 600),
|
||||
]
|
||||
assert redis.hincrby_calls == [
|
||||
(f"agent:user-context:{session_id}", "turns_used", 1)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_cache_invalidate_user_deletes_all_sessions() -> None:
|
||||
redis = _FakeRedis()
|
||||
cache = UserContextCache(
|
||||
client=redis,
|
||||
key_prefix="agent:user-context",
|
||||
ttl_seconds=600,
|
||||
max_turns=3,
|
||||
)
|
||||
context = _build_context()
|
||||
s1 = uuid4()
|
||||
s2 = uuid4()
|
||||
|
||||
await cache.set(session_id=s1, context=context)
|
||||
await cache.set(session_id=s2, context=context)
|
||||
|
||||
deleted = await cache.invalidate_user(user_id=context.user_id)
|
||||
|
||||
assert deleted == 2
|
||||
assert f"agent:user-context:{s1}" in redis.delete_calls
|
||||
assert f"agent:user-context:{s2}" in redis.delete_calls
|
||||
assert f"agent:user-context:sessions:{context.user_id}" in redis.delete_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_cache_invalidates_when_exceeds_max_turns() -> None:
|
||||
redis = _FakeRedis()
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.users.schemas import UserUpdateRequest
|
||||
from v1.users.service import UserService
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeProfile:
|
||||
id: object
|
||||
username: str
|
||||
avatar_url: str | None
|
||||
bio: str | None
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self, profile: _FakeProfile | None) -> None:
|
||||
self._profile = profile
|
||||
self.update_calls: list[tuple[object, dict[str, str | None]]] = []
|
||||
|
||||
async def update_by_user_id(
|
||||
self, user_id: object, update_data: dict[str, str | None]
|
||||
):
|
||||
self.update_calls.append((user_id, update_data))
|
||||
if self._profile is None:
|
||||
return None
|
||||
return _FakeProfile(
|
||||
id=self._profile.id,
|
||||
username=update_data.get("username") or self._profile.username,
|
||||
avatar_url=update_data.get("avatar_url")
|
||||
if "avatar_url" in update_data
|
||||
else self._profile.avatar_url,
|
||||
bio=update_data.get("bio") if "bio" in update_data else self._profile.bio,
|
||||
)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.commit_called = 0
|
||||
self.rollback_called = 0
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_called += 1
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rollback_called += 1
|
||||
|
||||
|
||||
class _FakeUserContextCache:
|
||||
def __init__(self, *, should_fail: bool = False) -> None:
|
||||
self.should_fail = should_fail
|
||||
self.invalidated_user_ids: list[object] = []
|
||||
|
||||
async def invalidate_user(self, *, user_id: object) -> int:
|
||||
self.invalidated_user_ids.append(user_id)
|
||||
if self.should_fail:
|
||||
raise RuntimeError("cache down")
|
||||
return 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_invalidates_user_context_cache() -> None:
|
||||
user_id = uuid4()
|
||||
repo = _FakeRepository(
|
||||
_FakeProfile(id=user_id, username="old", avatar_url=None, bio=None)
|
||||
)
|
||||
session = _FakeSession()
|
||||
cache = _FakeUserContextCache()
|
||||
service = UserService(
|
||||
repository=repo,
|
||||
session=session, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=user_id),
|
||||
user_context_cache=cache, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
result = await service.update_me(UserUpdateRequest(username="new-name"))
|
||||
|
||||
assert result.username == "new-name"
|
||||
assert session.commit_called == 1
|
||||
assert cache.invalidated_user_ids == [user_id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_me_succeeds_when_cache_invalidation_fails() -> None:
|
||||
user_id = uuid4()
|
||||
repo = _FakeRepository(
|
||||
_FakeProfile(id=user_id, username="old", avatar_url=None, bio=None)
|
||||
)
|
||||
session = _FakeSession()
|
||||
cache = _FakeUserContextCache(should_fail=True)
|
||||
service = UserService(
|
||||
repository=repo,
|
||||
session=session, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=user_id),
|
||||
user_context_cache=cache, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
result = await service.update_me(UserUpdateRequest(username="new-name"))
|
||||
|
||||
assert result.username == "new-name"
|
||||
assert session.commit_called == 1
|
||||
assert cache.invalidated_user_ids == [user_id]
|
||||
Reference in New Issue
Block a user