feat(agent): complete closed-loop runtime and pricing fallback
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.agui.bridge import to_agui_events
|
||||
from core.agent.infrastructure.agui.stream import to_sse_event
|
||||
|
||||
|
||||
def test_bridge_normalizes_event_type_to_upper_snake() -> None:
|
||||
events = [{"type": "runStarted", "data": {"ok": True}}]
|
||||
|
||||
out = to_agui_events(events)
|
||||
|
||||
assert out[0]["type"] == "RUN_STARTED"
|
||||
|
||||
|
||||
def test_bridge_supports_core_agui_event_taxonomy() -> None:
|
||||
events = [
|
||||
{"type": "runStarted", "data": {}},
|
||||
{"type": "runFinished", "data": {}},
|
||||
{"type": "stepStarted", "data": {}},
|
||||
{"type": "stepFinished", "data": {}},
|
||||
{"type": "textMessageStart", "data": {}},
|
||||
{"type": "textMessageContent", "data": {}},
|
||||
{"type": "textMessageEnd", "data": {}},
|
||||
{"type": "toolCallStart", "data": {}},
|
||||
{"type": "toolCallArgs", "data": {}},
|
||||
{"type": "toolCallEnd", "data": {}},
|
||||
{"type": "toolCallResult", "data": {}},
|
||||
{"type": "stateSnapshot", "data": {}},
|
||||
{"type": "stateDelta", "data": {}},
|
||||
{"type": "reasoningMessageStart", "data": {}},
|
||||
{"type": "reasoningMessageContent", "data": {}},
|
||||
{"type": "reasoningMessageEnd", "data": {}},
|
||||
]
|
||||
|
||||
out = to_agui_events(events)
|
||||
|
||||
assert [event["type"] for event in out] == [
|
||||
"RUN_STARTED",
|
||||
"RUN_FINISHED",
|
||||
"STEP_STARTED",
|
||||
"STEP_FINISHED",
|
||||
"TEXT_MESSAGE_START",
|
||||
"TEXT_MESSAGE_CONTENT",
|
||||
"TEXT_MESSAGE_END",
|
||||
"TOOL_CALL_START",
|
||||
"TOOL_CALL_ARGS",
|
||||
"TOOL_CALL_END",
|
||||
"TOOL_CALL_RESULT",
|
||||
"STATE_SNAPSHOT",
|
||||
"STATE_DELTA",
|
||||
"REASONING_MESSAGE_START",
|
||||
"REASONING_MESSAGE_CONTENT",
|
||||
"REASONING_MESSAGE_END",
|
||||
]
|
||||
|
||||
|
||||
def test_bridge_preserves_common_agui_fields() -> None:
|
||||
events = [
|
||||
{
|
||||
"type": "toolCallResult",
|
||||
"id": "evt-1",
|
||||
"run_id": "run-1",
|
||||
"timestamp": "2026-03-05T12:00:00Z",
|
||||
"parent_message_id": "msg-1",
|
||||
"data": {"ok": True},
|
||||
}
|
||||
]
|
||||
|
||||
out = to_agui_events(events)
|
||||
|
||||
assert out[0]["type"] == "TOOL_CALL_RESULT"
|
||||
assert out[0]["id"] == "evt-1"
|
||||
assert out[0]["run_id"] == "run-1"
|
||||
assert out[0]["timestamp"] == "2026-03-05T12:00:00Z"
|
||||
assert out[0]["parent_message_id"] == "msg-1"
|
||||
|
||||
|
||||
def test_bridge_rejects_empty_event_type() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
to_agui_events([{"type": "", "data": {}}])
|
||||
|
||||
|
||||
def test_bridge_rejects_non_object_data() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
to_agui_events([{"type": "runStarted", "data": "not-object"}])
|
||||
|
||||
|
||||
def test_bridge_redacts_sensitive_fields_in_data() -> None:
|
||||
out = to_agui_events(
|
||||
[
|
||||
{
|
||||
"type": "toolCallArgs",
|
||||
"data": {
|
||||
"api_key": "k-1",
|
||||
"payload": {"authorization": "Bearer x"},
|
||||
"safe": "ok",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert out[0]["data"]["api_key"] == "***REDACTED***"
|
||||
assert out[0]["data"]["payload"]["authorization"] == "***REDACTED***"
|
||||
assert out[0]["data"]["safe"] == "ok"
|
||||
|
||||
|
||||
def test_bridge_redacts_sensitive_key_variants() -> None:
|
||||
out = to_agui_events(
|
||||
[
|
||||
{
|
||||
"type": "toolCallArgs",
|
||||
"data": {
|
||||
"x-api-key": "k-2",
|
||||
"auth_token": "t-1",
|
||||
"openaiApiKey": "k-3",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert out[0]["data"]["x-api-key"] == "***REDACTED***"
|
||||
assert out[0]["data"]["auth_token"] == "***REDACTED***"
|
||||
assert out[0]["data"]["openaiApiKey"] == "***REDACTED***"
|
||||
|
||||
|
||||
def test_bridge_rejects_unknown_event_type() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
to_agui_events([{"type": "NOT_A_REAL_EVENT", "data": {}}])
|
||||
|
||||
|
||||
def test_sse_format_includes_id_event_data() -> None:
|
||||
payload = to_sse_event(
|
||||
stream_id="1-0", event={"type": "RUN_STARTED", "data": {"a": 1}}
|
||||
)
|
||||
|
||||
assert payload.startswith("id: 1-0\nevent: RUN_STARTED\ndata: {")
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.agent.infrastructure.config.resolver import AgentConfigResolver
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_runtime_raises_if_model_or_api_key_missing() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="", streaming_enabled=True
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={}),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
resolver.resolve(model_code="", provider_name="dashscope")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
resolver.resolve(model_code="gpt-4o-mini", provider_name="dashscope")
|
||||
|
||||
|
||||
def test_runtime_reads_provider_api_key_from_settings() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="gpt-4o-mini",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-like-api-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
resolved = resolver.resolve(model_code="", provider_name="dashscope")
|
||||
|
||||
assert resolved.model_code == "gpt-4o-mini"
|
||||
assert resolved.provider_api_key == "env-like-api-key"
|
||||
|
||||
|
||||
def test_runtime_reads_provider_api_key_from_env(monkeypatch: MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE", "env-key")
|
||||
resolver = AgentConfigResolver(settings=Settings())
|
||||
|
||||
resolved = resolver.resolve(model_code="gpt-4o-mini", provider_name="dashscope")
|
||||
|
||||
assert resolved.provider_api_key == "env-key"
|
||||
|
||||
|
||||
def test_runtime_supports_provider_alias_to_env_key() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="deepseek-v3.2",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"ark": "ark-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
resolved = resolver.resolve(model_code="", provider_name="volcengine-ark")
|
||||
|
||||
assert resolved.provider_api_key == "ark-key"
|
||||
|
||||
|
||||
def test_runtime_rejects_unsupported_provider() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="qwen3.5-flash", streaming_enabled=True
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "dash-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
resolver.resolve(model_code="", provider_name="unknown-provider")
|
||||
|
||||
|
||||
def test_runtime_config_repr_does_not_expose_api_key() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="qwen3.5-flash", streaming_enabled=True
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "very-secret-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
resolved = resolver.resolve(model_code="", provider_name="dashscope")
|
||||
|
||||
assert "very-secret-key" not in repr(resolved)
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.agent.infrastructure.config.resolver import AgentConfigResolver
|
||||
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
|
||||
|
||||
|
||||
def test_runtime_emits_text_tool_reasoning_events() -> None:
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
)
|
||||
),
|
||||
model_code="gpt-4o-mini",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
events = runtime.map_events(
|
||||
[
|
||||
{"type": "textMessageContent", "data": {"text": "hello"}},
|
||||
{"type": "toolCallStart", "data": {"tool_name": "weather"}},
|
||||
{"type": "toolCallResult", "data": {"ok": True}},
|
||||
{"type": "reasoningMessageContent", "data": {"text": "thinking"}},
|
||||
{"type": "runFinished", "data": {"status": "completed"}},
|
||||
]
|
||||
)
|
||||
|
||||
assert [event["type"] for event in events] == [
|
||||
"TEXT_MESSAGE_CONTENT",
|
||||
"TOOL_CALL_START",
|
||||
"TOOL_CALL_RESULT",
|
||||
"REASONING_MESSAGE_CONTENT",
|
||||
"RUN_FINISHED",
|
||||
]
|
||||
|
||||
|
||||
def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_completion(
|
||||
*, model: str, api_key: str, messages: list[dict[str, object]]
|
||||
):
|
||||
captured["model"] = model
|
||||
captured["api_key"] = api_key
|
||||
captured["messages"] = messages
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "hello",
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: SimpleNamespace(
|
||||
prompt_tokens=1,
|
||||
completion_tokens=2,
|
||||
total_tokens=3,
|
||||
cost=0.001,
|
||||
),
|
||||
)
|
||||
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
)
|
||||
),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
result = runtime.execute(user_input="hi")
|
||||
|
||||
assert captured["model"] == "dashscope/qwen3.5-flash"
|
||||
assert captured["api_key"] == "env-api-key"
|
||||
assert result["assistant_text"] == "hello"
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
|
||||
|
||||
|
||||
def test_usage_tracker_extracts_tokens_and_cost(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
||||
lambda completion_response: 0.123,
|
||||
)
|
||||
response = {
|
||||
"usage": {"prompt_tokens": 11, "completion_tokens": 7, "total_tokens": 18},
|
||||
}
|
||||
|
||||
usage = extract_usage_and_cost(response)
|
||||
|
||||
assert usage.prompt_tokens == 11
|
||||
assert usage.completion_tokens == 7
|
||||
assert usage.total_tokens == 18
|
||||
assert usage.cost == 0.123
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("prompt_tokens", "completion_tokens", "expected_cost"),
|
||||
[
|
||||
(128000, 1000, 0.0276),
|
||||
(200000, 1000, 0.168),
|
||||
(300000, 1000, 0.372),
|
||||
],
|
||||
)
|
||||
def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
expected_cost: float,
|
||||
) -> None:
|
||||
def _raise_unmapped(*, completion_response): # type: ignore[no-untyped-def]
|
||||
del completion_response
|
||||
raise Exception("This model isn't mapped yet")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
||||
_raise_unmapped,
|
||||
)
|
||||
response = {
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
usage = extract_usage_and_cost(response)
|
||||
|
||||
assert usage.cost == pytest.approx(expected_cost)
|
||||
assert usage.cost_source == "custom_pricing"
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
|
||||
|
||||
class _FakeResumeService:
|
||||
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "tool_call_id": tool_call_id}
|
||||
|
||||
|
||||
def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
events: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
|
||||
result = run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
|
||||
|
||||
|
||||
def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
class _BrokenRunService(_FakeRunService):
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
del session_id, user_input
|
||||
raise RuntimeError("boom")
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_BrokenRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert events == ["RUN_STARTED", "RUN_ERROR"]
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
|
||||
|
||||
class _FakeRedisClient:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[tuple[str, dict[str, str]]] = []
|
||||
|
||||
def xadd(self, stream: str, fields: dict[str, str]) -> str:
|
||||
self.calls.append((stream, fields))
|
||||
return "1-0"
|
||||
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key, start_id = next(iter(streams.items()))
|
||||
if start_id == "$":
|
||||
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
|
||||
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
|
||||
|
||||
|
||||
def test_append_event_writes_json_payload() -> None:
|
||||
client = _FakeRedisClient()
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(client=client, stream_prefix="agent:events")
|
||||
|
||||
stream_id = store.append_event_sync(
|
||||
session_id=session_id, event={"type": "RUN_STARTED"}
|
||||
)
|
||||
|
||||
assert stream_id == "1-0"
|
||||
assert len(client.calls) == 1
|
||||
stream, fields = client.calls[0]
|
||||
assert stream == f"agent:events:{session_id}"
|
||||
assert fields["event"] == '{"type":"RUN_STARTED"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_respects_last_event_id() -> None:
|
||||
client = _FakeRedisClient()
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(client=client, stream_prefix="agent:events")
|
||||
|
||||
from_start = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
from_last = await store.read_events(session_id=session_id, last_event_id="11-0")
|
||||
|
||||
assert from_start[0]["id"] == "11-0"
|
||||
assert from_last[0]["id"] == "12-0"
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_rejects_invalid_session_id() -> None:
|
||||
run_service = RunService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await run_service.run(session_id="session-1", user_input="hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_requires_pending_tool_call() -> None:
|
||||
resume_service = ResumeService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await resume_service.resume(session_id="session-1", tool_call_id="call-1")
|
||||
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.domain.state_snapshot import AgentStateSnapshot
|
||||
|
||||
|
||||
def test_state_snapshot_serialization_round_trip() -> None:
|
||||
snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")
|
||||
|
||||
payload = snapshot.model_dump()
|
||||
|
||||
assert payload["status"] == "running"
|
||||
assert payload["pending_tool_call_id"] == "call-1"
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.domain.tool_correlation import build_tool_result_metadata
|
||||
|
||||
|
||||
def test_tool_correlation_builds_tool_result_metadata() -> None:
|
||||
metadata = build_tool_result_metadata(
|
||||
run_id="run-1",
|
||||
turn_id="turn-1",
|
||||
tool_call_id="call-1",
|
||||
tool_name="weather",
|
||||
storage_bucket="private",
|
||||
storage_path="tool-results/run-1/call-1.json",
|
||||
payload_sha256="sha256",
|
||||
payload_bytes=128,
|
||||
payload_format="json",
|
||||
)
|
||||
|
||||
assert metadata["type"] == "tool_result"
|
||||
assert metadata["tool_call_id"] == "call-1"
|
||||
Reference in New Issue
Block a user