feat(agent): complete closed-loop runtime and pricing fallback

This commit is contained in:
qzl
2026-03-05 15:34:37 +08:00
parent b02a322bf3
commit b486e78ff3
67 changed files with 3832 additions and 7 deletions
@@ -0,0 +1,97 @@
from __future__ import annotations
import os
from datetime import datetime, timedelta, timezone
from uuid import UUID
import httpx
import jwt
import pytest
from sqlalchemy import select
from core.config import config
from core.db.session import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.profile import Profile
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
async def _owner_id() -> UUID:
    """Return the id of the first profile row the database hands back.

    Raises:
        RuntimeError: when the profile query yields no row.
    """
    first_profile_id = select(Profile.id).limit(1)
    async with AsyncSessionLocal() as db:
        result = await db.execute(first_profile_id)
        profile_id = result.scalar_one_or_none()
    if profile_id is None:
        raise RuntimeError("profile owner not found")
    return profile_id
def _jwt_for(user_id: UUID) -> str:
    """Mint an HS256 bearer token whose subject is ``user_id``.

    The token is signed with ``config.supabase.jwt_secret`` and carries the
    ``authenticated`` role/audience plus an issuer derived from
    ``config.supabase.public_url``. It expires 30 minutes after issuance.

    Raises:
        RuntimeError: when no JWT secret is configured.
    """
    secret = config.supabase.jwt_secret
    if not secret:
        raise RuntimeError("JWT secret not configured")
    issuer = f"{config.supabase.public_url.rstrip('/')}/auth/v1"
    # Read the clock once so that `exp` is exactly 30 minutes after `iat`
    # instead of drifting by the time between two separate now() calls.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "role": "authenticated",
        "aud": "authenticated",
        "iss": issuer,
        "iat": issued_at,
        "exp": issued_at + timedelta(minutes=30),
    }
    return jwt.encode(payload, secret, algorithm="HS256")
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_closed_loop_live() -> None:
    """Drive run -> SSE -> persistence against a live server at ``BASE_URL``.

    Skipped unless ``AGENT_LIVE_E2E=1``. The test posts a run, follows the
    session's event stream until a terminal event shows up, then checks the
    session row and the message rows stored for that session.
    """
    if os.getenv("AGENT_LIVE_E2E") != "1":
        pytest.skip("set AGENT_LIVE_E2E=1 to run live closed-loop test")

    owner_id = await _owner_id()
    token = _jwt_for(owner_id)
    headers = {"Authorization": f"Bearer {token}"}
    terminal_events = ("RUN_FINISHED", "RUN_ERROR")

    async with httpx.AsyncClient(timeout=30.0) as client:
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={"prompt": "请用一句话介绍你自己"},
        )
        assert run_resp.status_code == 202
        accepted = run_resp.json()
        session_id = str(accepted["session_id"])
        assert session_id

        events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
        event_names: list[str] = []
        async with client.stream(
            "GET", events_url, headers=headers, timeout=20.0
        ) as sse_resp:
            assert sse_resp.status_code == 200
            assert sse_resp.headers.get("content-type", "").startswith(
                "text/event-stream"
            )
            async for line in sse_resp.aiter_lines():
                if not line.startswith("event:"):
                    continue
                event_name = line.split(":", 1)[1].strip()
                event_names.append(event_name)
                # Stop as soon as the run reaches a terminal event. Without
                # this the loop only ends when the server closes the stream,
                # and a connection kept open past the 20s read timeout would
                # fail the test even though the run itself succeeded.
                if event_name in terminal_events:
                    break

    assert "RUN_STARTED" in event_names
    assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names

    async with AsyncSessionLocal() as session:
        session_row = await session.get(AgentChatSession, UUID(session_id))
        assert session_row is not None
        assert session_row.message_count >= 1
        assert session_row.total_tokens >= 0
        assert session_row.total_cost >= 0
        rows = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == UUID(session_id)
            )
        )
        assert len(list(rows.scalars().all())) >= 1
@@ -0,0 +1,213 @@
from __future__ import annotations
import uuid
from decimal import Decimal
import pytest
from sqlalchemy import delete, select
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
from core.agent.infrastructure.persistence.session_repository import SessionRepository
from core.agent.infrastructure.queue.tasks import run_agent_task
from core.db import AsyncSessionLocal, engine
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.profile import Profile
from models.system_agents import SystemAgents
@pytest.mark.asyncio
async def test_run_then_resume_persists_messages_and_session_state(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Run then resume a session and check what ends up in the database.

    ``CrewAIRuntime.execute`` is replaced by a canned response, so the test
    exercises ``run_agent_task`` with real ``RunService``/``ResumeService``
    instances against the local database. Skipped when no existing session
    owner can be found to attach the seeded session to.
    """

    def _fake_execute(self, *, user_input: str) -> dict[str, object]:
        del user_input
        return {
            "assistant_text": "Mocked answer",
            "prompt_tokens": 11,
            "completion_tokens": 7,
            "total_tokens": 18,
            "cost": 0.0025,
            "agui_events": [
                {"type": "TEXT_MESSAGE_START", "data": {"session_id": "__TBD__"}},
                {
                    "type": "TEXT_MESSAGE_CONTENT",
                    "data": {"session_id": "__TBD__", "text": "Mocked answer"},
                },
                {"type": "TEXT_MESSAGE_END", "data": {"session_id": "__TBD__"}},
            ],
        }

    monkeypatch.setattr(
        "core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
        _fake_execute,
    )

    async with AsyncSessionLocal() as lookup_session:
        existing_owner = await lookup_session.execute(
            select(AgentChatSession.user_id).limit(1)
        )
        owner_id = existing_owner.scalar_one_or_none()
    if owner_id is None:
        pytest.skip("No existing session owner available in local database")

    factory_id = uuid.uuid4()
    session_uuid = uuid.uuid4()
    agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
    # Set only when this test had to create its own Llm/LlmFactory rows, so
    # cleanup removes exactly what was seeded and nothing pre-existing.
    seeded_llm_id: uuid.UUID | None = None

    async with AsyncSessionLocal() as seed_session:
        llm_row = await seed_session.execute(select(Llm.id).limit(1))
        llm_id = llm_row.scalar_one_or_none()
        if llm_id is None:
            seed_session.add(
                LlmFactory(
                    id=factory_id,
                    name=f"dashscope-test-{uuid.uuid4().hex[:8]}",
                    request_url="https://dashscope.example",
                )
            )
            llm_id = uuid.uuid4()
            seeded_llm_id = llm_id
            seed_session.add(
                Llm(
                    id=llm_id,
                    factory_id=factory_id,
                    model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
                )
            )
        seed_session.add(
            SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
        )
        seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
        await seed_session.commit()

    published: list[str] = []

    def _publish(event_type: str, payload: dict[str, object]) -> None:
        del payload
        published.append(event_type)

    try:
        run_result = run_agent_task(
            {
                "command": "run",
                "session_id": str(session_uuid),
                "user_input": "hello",
            },
            publish_event=_publish,
            run_service=RunService(),
            resume_service=ResumeService(),
        )
        pending_tool_call_id = str(run_result["pending_tool_call_id"])
        run_agent_task(
            {
                "command": "resume",
                "session_id": str(session_uuid),
                "tool_call_id": pending_tool_call_id,
            },
            publish_event=_publish,
            run_service=RunService(),
            resume_service=ResumeService(),
        )
        await engine.dispose()

        async with AsyncSessionLocal() as verify_session:
            db_session = await verify_session.get(AgentChatSession, session_uuid)
            assert db_session is not None
            assert db_session.status == AgentChatSessionStatus.COMPLETED
            assert db_session.message_count == 4
            assert db_session.total_tokens == 18
            assert db_session.total_cost == Decimal("0.002500")
            assert db_session.state_snapshot == {
                "status": "completed",
                "pending_tool_call_id": None,
            }
            rows = await verify_session.execute(
                select(AgentChatMessage)
                .where(AgentChatMessage.session_id == session_uuid)
                .order_by(AgentChatMessage.seq.asc())
            )
            messages = list(rows.scalars().all())
            assert [item.role for item in messages] == [
                AgentChatMessageRole.USER,
                AgentChatMessageRole.ASSISTANT,
                AgentChatMessageRole.TOOL,
                AgentChatMessageRole.ASSISTANT,
            ]
            assert messages[1].input_tokens == 11
            assert messages[1].output_tokens == 7
            assert messages[1].cost == Decimal("0.002500")

        assert "RUN_STARTED" in published
        assert "RUN_RESUMED" in published
        assert "TEXT_MESSAGE_CONTENT" in published
    finally:
        async with AsyncSessionLocal() as cleanup_session:
            # Delete children before parents, matching the soft-delete test's
            # cleanup, so removal does not rely on a database-level cascade.
            await cleanup_session.execute(
                delete(AgentChatMessage).where(
                    AgentChatMessage.session_id == session_uuid
                )
            )
            await cleanup_session.execute(
                delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
            )
            await cleanup_session.execute(
                delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
            )
            if seeded_llm_id is not None:
                # These rows exist only because this test created them; leaving
                # them behind would change what later runs find via limit(1).
                await cleanup_session.execute(
                    delete(Llm).where(Llm.id == seeded_llm_id)
                )
                await cleanup_session.execute(
                    delete(LlmFactory).where(LlmFactory.id == factory_id)
                )
            await cleanup_session.commit()
@pytest.mark.asyncio
async def test_soft_delete_session_cascades_to_messages() -> None:
    """Soft-deleting a session must also stamp ``deleted_at`` on its message."""
    session_uuid = uuid.uuid4()
    await engine.dispose()

    async with AsyncSessionLocal() as db:
        owner_id = (
            await db.execute(select(Profile.id).limit(1))
        ).scalar_one_or_none()
    if owner_id is None:
        pytest.skip("No profile owner available in local database")

    # Seed one session with a single user message.
    async with AsyncSessionLocal() as db:
        db.add(AgentChatSession(id=session_uuid, user_id=owner_id))
        await db.flush()
        seeded_message = AgentChatMessage(
            session_id=session_uuid,
            seq=1,
            role=AgentChatMessageRole.USER,
            content="hello",
        )
        db.add(seeded_message)
        await db.commit()

    try:
        async with AsyncSessionLocal() as db:
            affected = await SessionRepository(db).soft_delete_session_with_messages(
                session_id=session_uuid
            )
            await db.commit()
        assert affected == 1

        async with AsyncSessionLocal() as db:
            stored_session = await db.get(AgentChatSession, session_uuid)
            assert stored_session is not None
            assert stored_session.deleted_at is not None
            result = await db.execute(
                select(AgentChatMessage).where(
                    AgentChatMessage.session_id == session_uuid
                )
            )
            stored_messages = list(result.scalars().all())
            assert len(stored_messages) == 1
            assert stored_messages[0].deleted_at is not None
    finally:
        # Hard-delete the seeded rows, messages first, then the session.
        async with AsyncSessionLocal() as db:
            await db.execute(
                delete(AgentChatMessage).where(
                    AgentChatMessage.session_id == session_uuid
                )
            )
            await db.execute(
                delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
            )
            await db.commit()
@@ -0,0 +1,69 @@
from __future__ import annotations
from core.agent.application.session_state_persistence import persist_tool_result_payload
from core.agent.domain.tool_correlation import reconstruct_tool_call_result_event
from core.agent.infrastructure.queue.tasks import run_agent_task
class _FakeStorage:
    """In-memory stand-in for a JSON store that remembers every upload."""

    def __init__(self) -> None:
        # Keyed by "<bucket>/<path>"; the value is the payload uploaded there.
        self.writes: dict[str, dict[str, object]] = {}

    async def upload_json(
        self, *, bucket: str, path: str, payload: dict[str, object]
    ) -> str:
        """Record ``payload`` under ``bucket/path`` and return a fixed etag."""
        storage_key = f"{bucket}/{path}"
        self.writes[storage_key] = payload
        return "etag-1"
def test_closed_loop_run_flow_frontend_to_sse() -> None:
    """A ``run`` command publishes RUN_STARTED first and RUN_FINISHED last."""
    session_id = "00000000-0000-0000-0000-000000000001"
    seen_event_types: list[str] = []

    class _FakeRunService:
        """Echoes its arguments back instead of running an agent."""

        async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
            return {"session_id": session_id, "user_input": user_input}

    def _record(event_type: str, payload: dict[str, object]) -> None:
        # Only the event order matters here; the payload is ignored.
        del payload
        seen_event_types.append(event_type)

    command = {
        "command": "run",
        "session_id": session_id,
        "user_input": "hello",
    }
    result = run_agent_task(
        command,
        publish_event=_record,
        run_service=_FakeRunService(),
    )

    assert result["session_id"] == session_id
    assert seen_event_types[0] == "RUN_STARTED"
    assert seen_event_types[-1] == "RUN_FINISHED"
# NOTE(review): this coroutine test carries no asyncio marker and this file does
# not import pytest — presumably the project runs pytest-asyncio in auto mode;
# confirm, otherwise the test body is never awaited.
async def test_tool_result_full_payload_persist_and_reconstruct() -> None:
    """Persist a tool result payload, then rebuild its TOOL_CALL_RESULT event."""
    fake_storage = _FakeStorage()
    ui_payload = {
        "schema": "ui.v1",
        "components": [{"type": "card", "title": "Weather"}],
    }

    metadata = await persist_tool_result_payload(
        storage=fake_storage,
        run_id="run-1",
        turn_id="turn-1",
        tool_call_id="call-1",
        tool_name="weather",
        payload=ui_payload,
        bucket="private",
        path="tool-results/run-1/call-1.json",
    )
    rebuilt_event = reconstruct_tool_call_result_event(
        metadata=metadata, payload=ui_payload
    )

    assert metadata["type"] == "tool_result"
    assert metadata["storage_bucket"] == "private"
    assert rebuilt_event["type"] == "TOOL_CALL_RESULT"
    assert rebuilt_event["data"]["schema"] == "ui.v1"
@@ -0,0 +1,107 @@
from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
from fastapi.testclient import TestClient
from app import app
from core.auth.models import CurrentUser
from v1.agent.dependencies import get_agent_service
from v1.users.dependencies import get_current_user
class _FakeAgentService:
    """Canned agent service used through the dependency override in these tests."""

    def __init__(self) -> None:
        # Flipped on the first stream_events call so later calls return nothing.
        self._stream_called = False

    async def enqueue_run(
        self, *, session_id: str | None, prompt: str, current_user: CurrentUser
    ):
        """Accept a run; a falsy ``session_id`` maps to a placeholder session."""
        del prompt, current_user
        if session_id:
            resolved_session = session_id
        else:
            resolved_session = "auto-created-session"
        was_created = session_id is None
        return SimpleNamespace(
            task_id="task-run-1",
            session_id=resolved_session,
            created=was_created,
        )

    async def enqueue_resume(
        self,
        *,
        session_id: str,
        tool_call_id: str,
        current_user: CurrentUser,
    ):
        """Accept a resume for ``session_id`` with a fixed task id."""
        del tool_call_id, current_user
        accepted = SimpleNamespace(
            task_id="task-resume-1", session_id=session_id, created=False
        )
        return accepted

    async def stream_events(
        self,
        *,
        session_id: str,
        last_event_id: str | None,
        current_user: CurrentUser,
    ) -> list[dict[str, object]]:
        """Return one RUN_STARTED entry on the first call, then empty lists."""
        del session_id, current_user
        if not self._stream_called:
            self._stream_called = True
            return [
                {
                    "id": "2-0",
                    "event": {"type": "RUN_STARTED"},
                    "cursor": last_event_id,
                }
            ]
        return []
def test_run_requires_auth_and_returns_202_task_id() -> None:
    """POST /runs is 401 without a user and 202 with one, with or without a session."""
    runs_path = "/api/v1/agent/runs"
    app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
    client = TestClient(app)
    try:
        # No user override yet, so the request is expected to be rejected.
        unauthorized = client.post(
            runs_path,
            json={"session_id": "session-1", "prompt": "hello"},
        )
        assert unauthorized.status_code == 401

        app.dependency_overrides[get_current_user] = lambda: CurrentUser(
            id=uuid4(), email="user@example.com"
        )

        # Existing session: accepted and not flagged as newly created.
        authorized = client.post(
            runs_path,
            json={"session_id": "session-1", "prompt": "hello"},
        )
        assert authorized.status_code == 202
        authorized_body = authorized.json()
        assert authorized_body["task_id"] == "task-run-1"
        assert authorized_body["created"] is False

        # No session id supplied: the fake reports an auto-created session.
        first_chat = client.post(runs_path, json={"prompt": "hello"})
        assert first_chat.status_code == 202
        first_chat_body = first_chat.json()
        assert first_chat_body["session_id"] == "auto-created-session"
        assert first_chat_body["created"] is True
    finally:
        app.dependency_overrides = {}
def test_stream_reads_from_last_event_id() -> None:
    """GET /events with a Last-Event-ID header streams the fake's SSE entry."""
    overrides = app.dependency_overrides
    overrides[get_agent_service] = lambda: _FakeAgentService()
    overrides[get_current_user] = lambda: CurrentUser(
        id=uuid4(), email="user@example.com"
    )
    client = TestClient(app)
    try:
        response = client.get(
            "/api/v1/agent/runs/session-1/events?idle_limit=1",
            headers={"Last-Event-ID": "1-0"},
        )
        assert response.status_code == 200
        assert response.headers["content-type"].startswith("text/event-stream")
        body = response.text
        assert "id: 2-0" in body
        assert "event: RUN_STARTED" in body
    finally:
        app.dependency_overrides = {}
@@ -0,0 +1,138 @@
from __future__ import annotations
import pytest
from core.agent.infrastructure.agui.bridge import to_agui_events
from core.agent.infrastructure.agui.stream import to_sse_event
def test_bridge_normalizes_event_type_to_upper_snake() -> None:
    """A camelCase event type leaves the bridge as UPPER_SNAKE."""
    camel_case_event = {"type": "runStarted", "data": {"ok": True}}
    normalized = to_agui_events([camel_case_event])
    assert normalized[0]["type"] == "RUN_STARTED"
def test_bridge_supports_core_agui_event_taxonomy() -> None:
    """Every core event type is accepted and mapped to its UPPER_SNAKE name."""
    # (input type, expected bridged type), in the order they are sent.
    taxonomy = (
        ("runStarted", "RUN_STARTED"),
        ("runFinished", "RUN_FINISHED"),
        ("stepStarted", "STEP_STARTED"),
        ("stepFinished", "STEP_FINISHED"),
        ("textMessageStart", "TEXT_MESSAGE_START"),
        ("textMessageContent", "TEXT_MESSAGE_CONTENT"),
        ("textMessageEnd", "TEXT_MESSAGE_END"),
        ("toolCallStart", "TOOL_CALL_START"),
        ("toolCallArgs", "TOOL_CALL_ARGS"),
        ("toolCallEnd", "TOOL_CALL_END"),
        ("toolCallResult", "TOOL_CALL_RESULT"),
        ("stateSnapshot", "STATE_SNAPSHOT"),
        ("stateDelta", "STATE_DELTA"),
        ("reasoningMessageStart", "REASONING_MESSAGE_START"),
        ("reasoningMessageContent", "REASONING_MESSAGE_CONTENT"),
        ("reasoningMessageEnd", "REASONING_MESSAGE_END"),
    )
    events = [{"type": source_type, "data": {}} for source_type, _ in taxonomy]
    out = to_agui_events(events)
    bridged_types = [event["type"] for event in out]
    assert bridged_types == [expected_type for _, expected_type in taxonomy]
def test_bridge_preserves_common_agui_fields() -> None:
    """Envelope fields (id, run_id, timestamp, parent_message_id) pass through."""
    source_event = {
        "type": "toolCallResult",
        "id": "evt-1",
        "run_id": "run-1",
        "timestamp": "2026-03-05T12:00:00Z",
        "parent_message_id": "msg-1",
        "data": {"ok": True},
    }
    bridged = to_agui_events([source_event])[0]
    expected_fields = {
        "type": "TOOL_CALL_RESULT",
        "id": "evt-1",
        "run_id": "run-1",
        "timestamp": "2026-03-05T12:00:00Z",
        "parent_message_id": "msg-1",
    }
    for field_name, expected_value in expected_fields.items():
        assert bridged[field_name] == expected_value
def test_bridge_rejects_empty_event_type() -> None:
    """An event whose type is the empty string is refused with ValueError."""
    untyped_event = {"type": "", "data": {}}
    with pytest.raises(ValueError):
        to_agui_events([untyped_event])
def test_bridge_rejects_non_object_data() -> None:
    """An event whose data is a plain string is refused with ValueError."""
    string_data_event = {"type": "runStarted", "data": "not-object"}
    with pytest.raises(ValueError):
        to_agui_events([string_data_event])
def test_bridge_redacts_sensitive_fields_in_data() -> None:
    """Secret-looking keys are masked at top level and nested; others are kept."""
    tool_args = {
        "api_key": "k-1",
        "payload": {"authorization": "Bearer x"},
        "safe": "ok",
    }
    out = to_agui_events([{"type": "toolCallArgs", "data": tool_args}])
    redacted = out[0]["data"]
    assert redacted["api_key"] == "***REDACTED***"
    assert redacted["payload"]["authorization"] == "***REDACTED***"
    assert redacted["safe"] == "ok"
def test_bridge_redacts_sensitive_key_variants() -> None:
    """Hyphenated, snake_case and camelCase secret key spellings are all masked."""
    secret_args = {
        "x-api-key": "k-2",
        "auth_token": "t-1",
        "openaiApiKey": "k-3",
    }
    out = to_agui_events([{"type": "toolCallArgs", "data": secret_args}])
    redacted = out[0]["data"]
    for key_name in ("x-api-key", "auth_token", "openaiApiKey"):
        assert redacted[key_name] == "***REDACTED***"
def test_bridge_rejects_unknown_event_type() -> None:
    """An event type the bridge does not recognise is refused with ValueError."""
    unknown_event = {"type": "NOT_A_REAL_EVENT", "data": {}}
    with pytest.raises(ValueError):
        to_agui_events([unknown_event])
def test_sse_format_includes_id_event_data() -> None:
    """The SSE frame starts with id, event and data lines, in that order."""
    run_started = {"type": "RUN_STARTED", "data": {"a": 1}}
    frame = to_sse_event(stream_id="1-0", event=run_started)
    expected_prefix = "id: 1-0\nevent: RUN_STARTED\ndata: {"
    assert frame.startswith(expected_prefix)
@@ -0,0 +1,96 @@
from __future__ import annotations
import pytest
from types import SimpleNamespace
from pytest import MonkeyPatch
from core.agent.infrastructure.config.resolver import AgentConfigResolver
from core.config.settings import Settings
def test_runtime_raises_if_model_or_api_key_missing() -> None:
    """With no default model and no provider keys, resolve() raises ValueError."""
    empty_settings = SimpleNamespace(
        agent_runtime=SimpleNamespace(default_model_code="", streaming_enabled=True),
        llm=SimpleNamespace(provider_keys={}),
    )
    resolver = AgentConfigResolver(settings=empty_settings)
    # First with no model code at all, then with a model code but still no key.
    for model_code in ("", "gpt-4o-mini"):
        with pytest.raises(ValueError):
            resolver.resolve(model_code=model_code, provider_name="dashscope")
def test_runtime_reads_provider_api_key_from_settings() -> None:
    """An empty model code falls back to the default; the key comes from settings."""
    runtime_cfg = SimpleNamespace(
        default_model_code="gpt-4o-mini",
        streaming_enabled=True,
    )
    llm_cfg = SimpleNamespace(provider_keys={"dashscope": "env-like-api-key"})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_cfg, llm=llm_cfg)
    )
    resolved = resolver.resolve(model_code="", provider_name="dashscope")
    assert resolved.model_code == "gpt-4o-mini"
    assert resolved.provider_api_key == "env-like-api-key"
def test_runtime_reads_provider_api_key_from_env(monkeypatch: MonkeyPatch) -> None:
    """A key set through the SOCIAL_LLM__PROVIDER_KEYS__* variable is resolved."""
    monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE", "env-key")
    # Settings() is built after the variable is set so it can pick it up.
    env_backed_settings = Settings()
    resolved = AgentConfigResolver(settings=env_backed_settings).resolve(
        model_code="gpt-4o-mini", provider_name="dashscope"
    )
    assert resolved.provider_api_key == "env-key"
def test_runtime_supports_provider_alias_to_env_key() -> None:
    """Provider name "volcengine-ark" resolves to the key stored under "ark"."""
    runtime_cfg = SimpleNamespace(
        default_model_code="deepseek-v3.2",
        streaming_enabled=True,
    )
    llm_cfg = SimpleNamespace(provider_keys={"ark": "ark-key"})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_cfg, llm=llm_cfg)
    )
    resolved = resolver.resolve(model_code="", provider_name="volcengine-ark")
    assert resolved.provider_api_key == "ark-key"
def test_runtime_rejects_unsupported_provider() -> None:
    """A provider name the resolver does not know raises ValueError."""
    runtime_cfg = SimpleNamespace(
        default_model_code="qwen3.5-flash", streaming_enabled=True
    )
    llm_cfg = SimpleNamespace(provider_keys={"dashscope": "dash-key"})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_cfg, llm=llm_cfg)
    )
    with pytest.raises(ValueError):
        resolver.resolve(model_code="", provider_name="unknown-provider")
def test_runtime_config_repr_does_not_expose_api_key() -> None:
    """repr() of the resolved config must not contain the provider key."""
    secret_key = "very-secret-key"
    runtime_cfg = SimpleNamespace(
        default_model_code="qwen3.5-flash", streaming_enabled=True
    )
    llm_cfg = SimpleNamespace(provider_keys={"dashscope": secret_key})
    resolver = AgentConfigResolver(
        settings=SimpleNamespace(agent_runtime=runtime_cfg, llm=llm_cfg)
    )
    resolved = resolver.resolve(model_code="", provider_name="dashscope")
    rendered = repr(resolved)
    assert "very-secret-key" not in rendered
@@ -0,0 +1,97 @@
from __future__ import annotations
from types import SimpleNamespace
from core.agent.infrastructure.config.resolver import AgentConfigResolver
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
def test_runtime_emits_text_tool_reasoning_events() -> None:
    """map_events() keeps event order and upper-snake-cases each type."""
    settings = SimpleNamespace(
        agent_runtime=SimpleNamespace(
            default_model_code="",
            streaming_enabled=True,
        ),
        llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
    )
    runtime = CrewAIRuntime(
        resolver=AgentConfigResolver(settings=settings),
        model_code="gpt-4o-mini",
        provider_name="dashscope",
    )
    raw_events = [
        {"type": "textMessageContent", "data": {"text": "hello"}},
        {"type": "toolCallStart", "data": {"tool_name": "weather"}},
        {"type": "toolCallResult", "data": {"ok": True}},
        {"type": "reasoningMessageContent", "data": {"text": "thinking"}},
        {"type": "runFinished", "data": {"status": "completed"}},
    ]
    mapped = runtime.map_events(raw_events)
    mapped_types = [event["type"] for event in mapped]
    assert mapped_types == [
        "TEXT_MESSAGE_CONTENT",
        "TOOL_CALL_START",
        "TOOL_CALL_RESULT",
        "REASONING_MESSAGE_CONTENT",
        "RUN_FINISHED",
    ]
def test_runtime_execute_uses_provider_prefixed_litellm_model(
    monkeypatch,
) -> None:
    """execute() calls the completion with "<provider>/<model>" and the provider key."""
    captured: dict[str, object] = {}

    def _fake_completion(
        *, model: str, api_key: str, messages: list[dict[str, object]]
    ):
        # Remember what the runtime passed so it can be asserted on below.
        captured.update(model=model, api_key=api_key, messages=messages)
        return {
            "choices": [{"message": {"content": "hello"}}],
            "usage": {},
        }

    def _fake_usage(_response):
        return SimpleNamespace(
            prompt_tokens=1,
            completion_tokens=2,
            total_tokens=3,
            cost=0.001,
        )

    runtime_module = "core.agent.infrastructure.crewai.runtime"
    monkeypatch.setattr(f"{runtime_module}.run_completion", _fake_completion)
    monkeypatch.setattr(f"{runtime_module}.extract_usage_and_cost", _fake_usage)

    settings = SimpleNamespace(
        agent_runtime=SimpleNamespace(
            default_model_code="",
            streaming_enabled=True,
        ),
        llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
    )
    runtime = CrewAIRuntime(
        resolver=AgentConfigResolver(settings=settings),
        model_code="qwen3.5-flash",
        provider_name="dashscope",
    )

    result = runtime.execute(user_input="hi")

    assert captured["model"] == "dashscope/qwen3.5-flash"
    assert captured["api_key"] == "env-api-key"
    assert result["assistant_text"] == "hello"
@@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
def test_usage_tracker_extracts_tokens_and_cost(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Token counts come from ``usage``; cost comes from completion_cost()."""

    def _fixed_cost(completion_response):
        return 0.123

    monkeypatch.setattr(
        "core.agent.infrastructure.litellm.usage_tracker.completion_cost",
        _fixed_cost,
    )
    token_counts = {"prompt_tokens": 11, "completion_tokens": 7, "total_tokens": 18}
    usage = extract_usage_and_cost({"usage": token_counts})
    assert usage.prompt_tokens == 11
    assert usage.completion_tokens == 7
    assert usage.total_tokens == 18
    assert usage.cost == 0.123
@pytest.mark.parametrize(
    ("prompt_tokens", "completion_tokens", "expected_cost"),
    [
        (128000, 1000, 0.0276),
        (200000, 1000, 0.168),
        (300000, 1000, 0.372),
    ],
)
def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
    monkeypatch: pytest.MonkeyPatch,
    prompt_tokens: int,
    completion_tokens: int,
    expected_cost: float,
) -> None:
    """When completion_cost() raises, cost is computed from custom pricing."""

    def _raise_unmapped(*, completion_response):  # type: ignore[no-untyped-def]
        del completion_response
        raise Exception("This model isn't mapped yet")

    monkeypatch.setattr(
        "core.agent.infrastructure.litellm.usage_tracker.completion_cost",
        _raise_unmapped,
    )
    token_usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    usage = extract_usage_and_cost(
        {"model": "dashscope/qwen3.5-flash", "usage": token_usage}
    )
    assert usage.cost == pytest.approx(expected_cost)
    assert usage.cost_source == "custom_pricing"
@@ -0,0 +1,67 @@
from __future__ import annotations
import pytest
from core.agent.infrastructure.queue.tasks import run_agent_task
class _FakeRunService:
    """Run service double that echoes its arguments back."""

    async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
        """Return the received ``session_id`` and ``user_input`` as a dict."""
        echoed: dict[str, object] = {}
        echoed["session_id"] = session_id
        echoed["user_input"] = user_input
        return echoed
class _FakeResumeService:
    """Resume service double that echoes its arguments back."""

    async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
        """Return the received ``session_id`` and ``tool_call_id`` as a dict."""
        echoed: dict[str, object] = {}
        echoed["session_id"] = session_id
        echoed["tool_call_id"] = tool_call_id
        return echoed
def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
    """A successful run publishes RUN_STARTED, RUNTIME_EVENT, RUN_FINISHED in order."""
    session_id = "00000000-0000-0000-0000-000000000001"
    recorded: list[str] = []

    def _record(event_type: str, payload: dict[str, object]) -> None:
        del payload
        recorded.append(event_type)

    command = {
        "command": "run",
        "session_id": session_id,
        "user_input": "hello",
    }
    result = run_agent_task(
        command,
        publish_event=_record,
        run_service=_FakeRunService(),
        resume_service=_FakeResumeService(),
    )

    assert result["session_id"] == session_id
    assert recorded == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
def test_run_agent_task_emits_error_event_on_exception() -> None:
    """A failing run service re-raises and publishes RUN_STARTED then RUN_ERROR."""
    session_id = "00000000-0000-0000-0000-000000000001"
    recorded: list[str] = []

    class _BrokenRunService(_FakeRunService):
        """Run service whose run() always fails."""

        async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
            del session_id, user_input
            raise RuntimeError("boom")

    def _record(event_type: str, payload: dict[str, object]) -> None:
        del payload
        recorded.append(event_type)

    command = {
        "command": "run",
        "session_id": session_id,
        "user_input": "hello",
    }
    with pytest.raises(RuntimeError):
        run_agent_task(
            command,
            publish_event=_record,
            run_service=_BrokenRunService(),
            resume_service=_FakeResumeService(),
        )

    assert recorded == ["RUN_STARTED", "RUN_ERROR"]
@@ -0,0 +1,57 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
class _FakeRedisClient:
    """Redis double: records xadd calls and answers xread with canned entries."""

    def __init__(self) -> None:
        # One (stream, fields) tuple per xadd call, in call order.
        self.calls: list[tuple[str, dict[str, str]]] = []

    def xadd(self, stream: str, fields: dict[str, str]) -> str:
        """Record the write and return the fixed entry id "1-0"."""
        recorded_call = (stream, fields)
        self.calls.append(recorded_call)
        return "1-0"

    async def xread(
        self,
        streams: dict[str, str],
        count: int,
        block: int,
    ) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
        """Answer for the first requested stream only.

        A start id of "$" yields entry "11-0" (RUN_STARTED); any other start
        id yields entry "12-0" (RUN_FINISHED).
        """
        del count, block
        stream_key, start_id = next(iter(streams.items()))
        if start_id != "$":
            entry = ("12-0", {"event": '{"type":"RUN_FINISHED"}'})
        else:
            entry = ("11-0", {"event": '{"type":"RUN_STARTED"}'})
        return [(stream_key, [entry])]
def test_append_event_writes_json_payload() -> None:
    """append_event_sync() issues one xadd to "<prefix>:<session>" with compact JSON."""
    fake_redis = _FakeRedisClient()
    session_id = uuid4()
    store = RedisStreamEventStore(client=fake_redis, stream_prefix="agent:events")

    stream_id = store.append_event_sync(
        session_id=session_id, event={"type": "RUN_STARTED"}
    )

    assert stream_id == "1-0"
    assert len(fake_redis.calls) == 1
    written_stream, written_fields = fake_redis.calls[0]
    assert written_stream == f"agent:events:{session_id}"
    assert written_fields["event"] == '{"type":"RUN_STARTED"}'
@pytest.mark.asyncio
async def test_read_events_respects_last_event_id() -> None:
    """read_events() returns a different fake entry with and without a last id."""
    fake_redis = _FakeRedisClient()
    session_id = uuid4()
    store = RedisStreamEventStore(client=fake_redis, stream_prefix="agent:events")

    without_cursor = await store.read_events(
        session_id=session_id, last_event_id=None
    )
    after_cursor = await store.read_events(
        session_id=session_id, last_event_id="11-0"
    )

    assert without_cursor[0]["id"] == "11-0"
    assert after_cursor[0]["id"] == "12-0"
@@ -0,0 +1,22 @@
from __future__ import annotations
import pytest
from core.agent.application.resume_service import ResumeService
from core.agent.application.run_service import RunService
@pytest.mark.asyncio
async def test_run_service_rejects_invalid_session_id() -> None:
    """RunService.run() raises ValueError for the session id "session-1"."""
    service = RunService()
    with pytest.raises(ValueError):
        await service.run(session_id="session-1", user_input="hello")
@pytest.mark.asyncio
async def test_resume_service_requires_pending_tool_call() -> None:
    """ResumeService.resume() raises ValueError for this session/tool-call pair."""
    service = ResumeService()
    with pytest.raises(ValueError):
        await service.resume(session_id="session-1", tool_call_id="call-1")
@@ -0,0 +1,12 @@
from __future__ import annotations
from core.agent.domain.state_snapshot import AgentStateSnapshot
def test_state_snapshot_serialization_round_trip() -> None:
    """Dump a snapshot to a dict and load it back without losing any field."""
    snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")
    payload = snapshot.model_dump()
    assert payload["status"] == "running"
    assert payload["pending_tool_call_id"] == "call-1"

    # Complete the round trip the test name promises: the dumped payload must
    # load back into a snapshot that dumps to the very same payload.
    # NOTE(review): relies on AgentStateSnapshot being a pydantic v2 model
    # (it already exposes model_dump) — confirm model_validate is available.
    restored = AgentStateSnapshot.model_validate(payload)
    assert restored.model_dump() == payload
@@ -0,0 +1,20 @@
from __future__ import annotations
from core.agent.domain.tool_correlation import build_tool_result_metadata
def test_tool_correlation_builds_tool_result_metadata() -> None:
    """The built metadata is typed "tool_result" and keeps the tool call id."""
    correlation_fields = {
        "run_id": "run-1",
        "turn_id": "turn-1",
        "tool_call_id": "call-1",
        "tool_name": "weather",
        "storage_bucket": "private",
        "storage_path": "tool-results/run-1/call-1.json",
        "payload_sha256": "sha256",
        "payload_bytes": 128,
        "payload_format": "json",
    }
    metadata = build_tool_result_metadata(**correlation_fields)
    assert metadata["type"] == "tool_result"
    assert metadata["tool_call_id"] == "call-1"
@@ -0,0 +1,29 @@
from __future__ import annotations
from pathlib import Path
def test_session_has_state_snapshot_and_status_contract() -> None:
    """The session model source mentions the snapshot column and status enum."""
    project_root = Path(__file__).resolve().parents[3]
    model_path = project_root / "src" / "models" / "agent_chat_session.py"
    source_text = model_path.read_text(encoding="utf-8")
    for required_name in ("state_snapshot", "AgentChatSessionStatus"):
        assert required_name in source_text
def test_message_has_token_cost_and_metadata_contract() -> None:
    """The message model source names the token/cost/metadata fields.

    Also checks that the closed-loop contract migration file is present.
    """
    project_root = Path(__file__).resolve().parents[3]

    model_path = project_root / "src" / "models" / "agent_chat_message.py"
    source_text = model_path.read_text(encoding="utf-8")
    for required_text in ("input_tokens", "output_tokens", "cost", '"metadata"'):
        assert required_text in source_text

    migration_file = (
        project_root
        / "alembic"
        / "versions"
        / "20260305_agent_runtime_closed_loop_contract.py"
    )
    assert migration_file.exists()
@@ -0,0 +1,20 @@
from __future__ import annotations
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_social_prefixed_llm_provider_keys_populates_settings(
    monkeypatch: MonkeyPatch,
) -> None:
    """SOCIAL_LLM__PROVIDER_KEYS__* variables end up in settings.llm.provider_keys."""
    env_to_value = (
        ("SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE", "dash-key"),
        ("SOCIAL_LLM__PROVIDER_KEYS__DEEPSEEK", "deep-key"),
        ("SOCIAL_LLM__PROVIDER_KEYS__ARK", "ark-key"),
    )
    for env_name, env_value in env_to_value:
        monkeypatch.setenv(env_name, env_value)

    settings = Settings()

    # Compare on lower-cased provider names so key casing does not matter.
    keys = dict(
        (name.lower(), value) for name, value in settings.llm.provider_keys.items()
    )
    assert keys["dashscope"] == "dash-key"
    assert keys["deepseek"] == "deep-key"
    assert keys["ark"] == "ark-key"
@@ -0,0 +1,16 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from fastapi import HTTPException
from core.auth.models import CurrentUser
from v1.agent.service import ensure_session_owner
def test_owner_guard_denies_non_owner() -> None:
    """ensure_session_owner() raises HTTPException when the owner id differs."""
    caller = CurrentUser(id=uuid4(), email="self@example.com")
    with pytest.raises(HTTPException):
        ensure_session_owner(owner_id="other-user", current_user=caller)
+125
View File
@@ -0,0 +1,125 @@
from __future__ import annotations
from uuid import UUID
from core.auth.models import CurrentUser
from v1.agent.service import AgentService
class _FakeRepository:
    """Repository double returning fixed ids and recording what was called."""

    def __init__(self) -> None:
        # Flags and the last deleted id let tests see which calls happened.
        self.committed = False
        self.rolled_back = False
        self.deleted_session_id: str | None = None

    async def get_session_owner(self, *, session_id: str) -> str:
        """Return the same fixed owner id for every session."""
        del session_id
        fixed_owner_id = "00000000-0000-0000-0000-000000000001"
        return fixed_owner_id

    async def create_session_for_user(self, *, user_id: str) -> str:
        """Return the same fixed new-session id for every user."""
        del user_id
        fixed_session_id = "00000000-0000-0000-0000-000000000999"
        return fixed_session_id

    async def commit(self) -> None:
        """Mark that a commit was requested."""
        self.committed = True

    async def rollback(self) -> None:
        """Mark that a rollback was requested."""
        self.rolled_back = True

    async def delete_session(self, *, session_id: str) -> None:
        """Remember which session was asked to be deleted."""
        self.deleted_session_id = session_id
class _FakeQueue:
    """Queue double that accepts anything and always reports task id "task-1"."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Ignore both arguments and return the fixed task id."""
        del command
        del dedup_key
        task_id = "task-1"
        return task_id
class _FailingQueue:
    """Queue double whose enqueue() always fails."""

    async def enqueue(
        self, *, command: dict[str, object], dedup_key: str | None
    ) -> str:
        """Ignore both arguments and raise ``RuntimeError("enqueue failed")``."""
        del command
        del dedup_key
        failure = RuntimeError("enqueue failed")
        raise failure
class _FakeStream:
    """Stream double that always yields a single RUN_STARTED entry."""

    async def read(
        self, *, session_id: str, last_event_id: str | None
    ) -> list[dict[str, object]]:
        """Return one entry with id "2-0", echoing ``last_event_id`` as the cursor."""
        del session_id
        entry: dict[str, object] = {
            "id": "2-0",
            "event": {"type": "RUN_STARTED"},
            "cursor": last_event_id,
        }
        return [entry]
def _user() -> CurrentUser:
    """Build the user whose id equals the owner id ``_FakeRepository`` reports."""
    owner_uuid = UUID("00000000-0000-0000-0000-000000000001")
    return CurrentUser(id=owner_uuid, email="user@example.com")
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
    """Two identical resume requests are answered with the same task id."""
    # NOTE(review): `_FakeQueue.enqueue` returns "task-1" unconditionally, so
    # this equality holds whether or not a dedup key is used — consider
    # capturing `dedup_key` in the fake and asserting on it.
    service = AgentService(
        repository=_FakeRepository(),
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )
    user = _user()
    resume_request = {
        "session_id": "session-1",
        "tool_call_id": "call-1",
        "current_user": user,
    }

    first = await service.enqueue_resume(**resume_request)
    second = await service.enqueue_resume(**resume_request)

    assert first.task_id == second.task_id
async def test_enqueue_run_without_session_creates_new_session() -> None:
    """With no session id, enqueue_run() creates a session and commits it."""
    fake_repo = _FakeRepository()
    service = AgentService(
        repository=fake_repo,
        queue=_FakeQueue(),
        stream=_FakeStream(),
    )

    accepted = await service.enqueue_run(
        session_id=None,
        prompt="hello",
        current_user=_user(),
    )

    # The id is the one the fake repository hands out for new sessions.
    assert accepted.session_id == "00000000-0000-0000-0000-000000000999"
    assert accepted.created is True
    assert fake_repo.committed is True
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
    """A queue failure propagates and the new session is not deleted."""
    fake_repo = _FakeRepository()
    service = AgentService(
        repository=fake_repo,
        queue=_FailingQueue(),
        stream=_FakeStream(),
    )

    try:
        await service.enqueue_run(
            session_id=None,
            prompt="hello",
            current_user=_user(),
        )
    except RuntimeError as exc:
        assert str(exc) == "enqueue failed"
    else:
        # Reaching here means the queue error was swallowed somewhere.
        raise AssertionError("expected RuntimeError")

    assert fake_repo.deleted_session_id is None