feat(agent): complete closed-loop runtime and pricing fallback
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.config import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from models.profile import Profile
|
||||
|
||||
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
||||
|
||||
|
||||
async def _owner_id() -> UUID:
|
||||
async with AsyncSessionLocal() as session:
|
||||
owner_id = (
|
||||
await session.execute(select(Profile.id).limit(1))
|
||||
).scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
raise RuntimeError("profile owner not found")
|
||||
return owner_id
|
||||
|
||||
|
||||
def _jwt_for(user_id: UUID) -> str:
|
||||
secret = config.supabase.jwt_secret
|
||||
if not secret:
|
||||
raise RuntimeError("JWT secret not configured")
|
||||
issuer = f"{config.supabase.public_url.rstrip('/')}/auth/v1"
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"role": "authenticated",
|
||||
"aud": "authenticated",
|
||||
"iss": issuer,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
async def test_agent_closed_loop_live() -> None:
|
||||
if os.getenv("AGENT_LIVE_E2E") != "1":
|
||||
pytest.skip("set AGENT_LIVE_E2E=1 to run live closed-loop test")
|
||||
|
||||
owner_id = await _owner_id()
|
||||
token = _jwt_for(owner_id)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
run_resp = await client.post(
|
||||
f"{BASE_URL}/api/v1/agent/runs",
|
||||
headers=headers,
|
||||
json={"prompt": "请用一句话介绍你自己"},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
accepted = run_resp.json()
|
||||
session_id = str(accepted["session_id"])
|
||||
assert session_id
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
|
||||
event_names: list[str] = []
|
||||
async with client.stream(
|
||||
"GET", events_url, headers=headers, timeout=20.0
|
||||
) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
assert sse_resp.headers.get("content-type", "").startswith(
|
||||
"text/event-stream"
|
||||
)
|
||||
async for line in sse_resp.aiter_lines():
|
||||
if line.startswith("event:"):
|
||||
event_names.append(line.split(":", 1)[1].strip())
|
||||
|
||||
assert "RUN_STARTED" in event_names
|
||||
assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
session_row = await session.get(AgentChatSession, UUID(session_id))
|
||||
assert session_row is not None
|
||||
assert session_row.message_count >= 1
|
||||
assert session_row.total_tokens >= 0
|
||||
assert session_row.total_cost >= 0
|
||||
|
||||
rows = await session.execute(
|
||||
select(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == UUID(session_id)
|
||||
)
|
||||
)
|
||||
assert len(list(rows.scalars().all())) >= 1
|
||||
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.persistence.session_repository import SessionRepository
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
from core.db import AsyncSessionLocal, engine
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.profile import Profile
|
||||
from models.system_agents import SystemAgents
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_then_resume_persists_messages_and_session_state(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
def _fake_execute(self, *, user_input: str) -> dict[str, object]:
|
||||
del user_input
|
||||
return {
|
||||
"assistant_text": "Mocked answer",
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
"cost": 0.0025,
|
||||
"agui_events": [
|
||||
{"type": "TEXT_MESSAGE_START", "data": {"session_id": "__TBD__"}},
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"data": {"session_id": "__TBD__", "text": "Mocked answer"},
|
||||
},
|
||||
{"type": "TEXT_MESSAGE_END", "data": {"session_id": "__TBD__"}},
|
||||
],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
|
||||
_fake_execute,
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
existing_owner = await lookup_session.execute(
|
||||
select(AgentChatSession.user_id).limit(1)
|
||||
)
|
||||
owner_id = existing_owner.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No existing session owner available in local database")
|
||||
factory_id = uuid.uuid4()
|
||||
session_uuid = uuid.uuid4()
|
||||
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
llm_row = await seed_session.execute(select(Llm.id).limit(1))
|
||||
llm_id = llm_row.scalar_one_or_none()
|
||||
if llm_id is None:
|
||||
seed_session.add(
|
||||
LlmFactory(
|
||||
id=factory_id,
|
||||
name=f"dashscope-test-{uuid.uuid4().hex[:8]}",
|
||||
request_url="https://dashscope.example",
|
||||
)
|
||||
)
|
||||
llm_id = uuid.uuid4()
|
||||
seed_session.add(
|
||||
Llm(
|
||||
id=llm_id,
|
||||
factory_id=factory_id,
|
||||
model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
|
||||
)
|
||||
)
|
||||
seed_session.add(
|
||||
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
|
||||
)
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
published: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
|
||||
try:
|
||||
run_result = run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": str(session_uuid),
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
pending_tool_call_id = str(run_result["pending_tool_call_id"])
|
||||
|
||||
run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": str(session_uuid),
|
||||
"tool_call_id": pending_tool_call_id,
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as verify_session:
|
||||
db_session = await verify_session.get(AgentChatSession, session_uuid)
|
||||
assert db_session is not None
|
||||
assert db_session.status == AgentChatSessionStatus.COMPLETED
|
||||
assert db_session.message_count == 4
|
||||
assert db_session.total_tokens == 18
|
||||
assert db_session.total_cost == Decimal("0.002500")
|
||||
assert db_session.state_snapshot == {
|
||||
"status": "completed",
|
||||
"pending_tool_call_id": None,
|
||||
}
|
||||
|
||||
rows = await verify_session.execute(
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = list(rows.scalars().all())
|
||||
assert [item.role for item in messages] == [
|
||||
AgentChatMessageRole.USER,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
AgentChatMessageRole.TOOL,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
]
|
||||
assert messages[1].input_tokens == 11
|
||||
assert messages[1].output_tokens == 7
|
||||
assert messages[1].cost == Decimal("0.002500")
|
||||
|
||||
assert "RUN_STARTED" in published
|
||||
assert "RUN_RESUMED" in published
|
||||
assert "TEXT_MESSAGE_CONTENT" in published
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_session_cascades_to_messages() -> None:
|
||||
session_uuid = uuid.uuid4()
|
||||
await engine.dispose()
|
||||
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
owner = await lookup_session.execute(select(Profile.id).limit(1))
|
||||
owner_id = owner.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No profile owner available in local database")
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.flush()
|
||||
seed_session.add(
|
||||
AgentChatMessage(
|
||||
session_id=session_uuid,
|
||||
seq=1,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
await seed_session.commit()
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as mutate_session:
|
||||
repo = SessionRepository(mutate_session)
|
||||
affected = await repo.soft_delete_session_with_messages(
|
||||
session_id=session_uuid
|
||||
)
|
||||
await mutate_session.commit()
|
||||
assert affected == 1
|
||||
|
||||
async with AsyncSessionLocal() as verify_session:
|
||||
db_session = await verify_session.get(AgentChatSession, session_uuid)
|
||||
assert db_session is not None
|
||||
assert db_session.deleted_at is not None
|
||||
rows = await verify_session.execute(
|
||||
select(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == session_uuid
|
||||
)
|
||||
)
|
||||
messages = list(rows.scalars().all())
|
||||
assert len(messages) == 1
|
||||
assert messages[0].deleted_at is not None
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == session_uuid
|
||||
)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.application.session_state_persistence import persist_tool_result_payload
|
||||
from core.agent.domain.tool_correlation import reconstruct_tool_call_result_event
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
|
||||
|
||||
class _FakeStorage:
|
||||
def __init__(self) -> None:
|
||||
self.writes: dict[str, dict[str, object]] = {}
|
||||
|
||||
async def upload_json(
|
||||
self, *, bucket: str, path: str, payload: dict[str, object]
|
||||
) -> str:
|
||||
self.writes[f"{bucket}/{path}"] = payload
|
||||
return "etag-1"
|
||||
|
||||
|
||||
def test_closed_loop_run_flow_frontend_to_sse() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
published: list[str] = []
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
|
||||
result = run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert published[0] == "RUN_STARTED"
|
||||
assert published[-1] == "RUN_FINISHED"
|
||||
|
||||
|
||||
async def test_tool_result_full_payload_persist_and_reconstruct() -> None:
|
||||
storage = _FakeStorage()
|
||||
payload = {
|
||||
"schema": "ui.v1",
|
||||
"components": [{"type": "card", "title": "Weather"}],
|
||||
}
|
||||
|
||||
metadata = await persist_tool_result_payload(
|
||||
storage=storage,
|
||||
run_id="run-1",
|
||||
turn_id="turn-1",
|
||||
tool_call_id="call-1",
|
||||
tool_name="weather",
|
||||
payload=payload,
|
||||
bucket="private",
|
||||
path="tool-results/run-1/call-1.json",
|
||||
)
|
||||
|
||||
event = reconstruct_tool_call_result_event(metadata=metadata, payload=payload)
|
||||
|
||||
assert metadata["type"] == "tool_result"
|
||||
assert metadata["storage_bucket"] == "private"
|
||||
assert event["type"] == "TOOL_CALL_RESULT"
|
||||
assert event["data"]["schema"] == "ui.v1"
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class _FakeAgentService:
|
||||
def __init__(self) -> None:
|
||||
self._stream_called = False
|
||||
|
||||
async def enqueue_run(
|
||||
self, *, session_id: str | None, prompt: str, current_user: CurrentUser
|
||||
):
|
||||
del prompt, current_user
|
||||
resolved_session = session_id or "auto-created-session"
|
||||
return SimpleNamespace(
|
||||
task_id="task-run-1",
|
||||
session_id=resolved_session,
|
||||
created=session_id is None,
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
current_user: CurrentUser,
|
||||
):
|
||||
del tool_call_id, current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-resume-1", session_id=session_id, created=False
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id, current_user
|
||||
if self._stream_called:
|
||||
return []
|
||||
self._stream_called = True
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
]
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
unauthorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
authorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
)
|
||||
assert authorized.status_code == 202
|
||||
assert authorized.json()["task_id"] == "task-run-1"
|
||||
assert authorized.json()["created"] is False
|
||||
|
||||
first_chat = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"prompt": "hello"},
|
||||
)
|
||||
assert first_chat.status_code == 202
|
||||
assert first_chat.json()["session_id"] == "auto-created-session"
|
||||
assert first_chat.json()["created"] is True
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_reads_from_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/session-1/events?idle_limit=1",
|
||||
headers={"Last-Event-ID": "1-0"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert "id: 2-0" in response.text
|
||||
assert "event: RUN_STARTED" in response.text
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.agui.bridge import to_agui_events
|
||||
from core.agent.infrastructure.agui.stream import to_sse_event
|
||||
|
||||
|
||||
def test_bridge_normalizes_event_type_to_upper_snake() -> None:
|
||||
events = [{"type": "runStarted", "data": {"ok": True}}]
|
||||
|
||||
out = to_agui_events(events)
|
||||
|
||||
assert out[0]["type"] == "RUN_STARTED"
|
||||
|
||||
|
||||
def test_bridge_supports_core_agui_event_taxonomy() -> None:
|
||||
events = [
|
||||
{"type": "runStarted", "data": {}},
|
||||
{"type": "runFinished", "data": {}},
|
||||
{"type": "stepStarted", "data": {}},
|
||||
{"type": "stepFinished", "data": {}},
|
||||
{"type": "textMessageStart", "data": {}},
|
||||
{"type": "textMessageContent", "data": {}},
|
||||
{"type": "textMessageEnd", "data": {}},
|
||||
{"type": "toolCallStart", "data": {}},
|
||||
{"type": "toolCallArgs", "data": {}},
|
||||
{"type": "toolCallEnd", "data": {}},
|
||||
{"type": "toolCallResult", "data": {}},
|
||||
{"type": "stateSnapshot", "data": {}},
|
||||
{"type": "stateDelta", "data": {}},
|
||||
{"type": "reasoningMessageStart", "data": {}},
|
||||
{"type": "reasoningMessageContent", "data": {}},
|
||||
{"type": "reasoningMessageEnd", "data": {}},
|
||||
]
|
||||
|
||||
out = to_agui_events(events)
|
||||
|
||||
assert [event["type"] for event in out] == [
|
||||
"RUN_STARTED",
|
||||
"RUN_FINISHED",
|
||||
"STEP_STARTED",
|
||||
"STEP_FINISHED",
|
||||
"TEXT_MESSAGE_START",
|
||||
"TEXT_MESSAGE_CONTENT",
|
||||
"TEXT_MESSAGE_END",
|
||||
"TOOL_CALL_START",
|
||||
"TOOL_CALL_ARGS",
|
||||
"TOOL_CALL_END",
|
||||
"TOOL_CALL_RESULT",
|
||||
"STATE_SNAPSHOT",
|
||||
"STATE_DELTA",
|
||||
"REASONING_MESSAGE_START",
|
||||
"REASONING_MESSAGE_CONTENT",
|
||||
"REASONING_MESSAGE_END",
|
||||
]
|
||||
|
||||
|
||||
def test_bridge_preserves_common_agui_fields() -> None:
|
||||
events = [
|
||||
{
|
||||
"type": "toolCallResult",
|
||||
"id": "evt-1",
|
||||
"run_id": "run-1",
|
||||
"timestamp": "2026-03-05T12:00:00Z",
|
||||
"parent_message_id": "msg-1",
|
||||
"data": {"ok": True},
|
||||
}
|
||||
]
|
||||
|
||||
out = to_agui_events(events)
|
||||
|
||||
assert out[0]["type"] == "TOOL_CALL_RESULT"
|
||||
assert out[0]["id"] == "evt-1"
|
||||
assert out[0]["run_id"] == "run-1"
|
||||
assert out[0]["timestamp"] == "2026-03-05T12:00:00Z"
|
||||
assert out[0]["parent_message_id"] == "msg-1"
|
||||
|
||||
|
||||
def test_bridge_rejects_empty_event_type() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
to_agui_events([{"type": "", "data": {}}])
|
||||
|
||||
|
||||
def test_bridge_rejects_non_object_data() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
to_agui_events([{"type": "runStarted", "data": "not-object"}])
|
||||
|
||||
|
||||
def test_bridge_redacts_sensitive_fields_in_data() -> None:
|
||||
out = to_agui_events(
|
||||
[
|
||||
{
|
||||
"type": "toolCallArgs",
|
||||
"data": {
|
||||
"api_key": "k-1",
|
||||
"payload": {"authorization": "Bearer x"},
|
||||
"safe": "ok",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert out[0]["data"]["api_key"] == "***REDACTED***"
|
||||
assert out[0]["data"]["payload"]["authorization"] == "***REDACTED***"
|
||||
assert out[0]["data"]["safe"] == "ok"
|
||||
|
||||
|
||||
def test_bridge_redacts_sensitive_key_variants() -> None:
|
||||
out = to_agui_events(
|
||||
[
|
||||
{
|
||||
"type": "toolCallArgs",
|
||||
"data": {
|
||||
"x-api-key": "k-2",
|
||||
"auth_token": "t-1",
|
||||
"openaiApiKey": "k-3",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert out[0]["data"]["x-api-key"] == "***REDACTED***"
|
||||
assert out[0]["data"]["auth_token"] == "***REDACTED***"
|
||||
assert out[0]["data"]["openaiApiKey"] == "***REDACTED***"
|
||||
|
||||
|
||||
def test_bridge_rejects_unknown_event_type() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
to_agui_events([{"type": "NOT_A_REAL_EVENT", "data": {}}])
|
||||
|
||||
|
||||
def test_sse_format_includes_id_event_data() -> None:
|
||||
payload = to_sse_event(
|
||||
stream_id="1-0", event={"type": "RUN_STARTED", "data": {"a": 1}}
|
||||
)
|
||||
|
||||
assert payload.startswith("id: 1-0\nevent: RUN_STARTED\ndata: {")
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.agent.infrastructure.config.resolver import AgentConfigResolver
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_runtime_raises_if_model_or_api_key_missing() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="", streaming_enabled=True
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={}),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
resolver.resolve(model_code="", provider_name="dashscope")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
resolver.resolve(model_code="gpt-4o-mini", provider_name="dashscope")
|
||||
|
||||
|
||||
def test_runtime_reads_provider_api_key_from_settings() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="gpt-4o-mini",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-like-api-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
resolved = resolver.resolve(model_code="", provider_name="dashscope")
|
||||
|
||||
assert resolved.model_code == "gpt-4o-mini"
|
||||
assert resolved.provider_api_key == "env-like-api-key"
|
||||
|
||||
|
||||
def test_runtime_reads_provider_api_key_from_env(monkeypatch: MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE", "env-key")
|
||||
resolver = AgentConfigResolver(settings=Settings())
|
||||
|
||||
resolved = resolver.resolve(model_code="gpt-4o-mini", provider_name="dashscope")
|
||||
|
||||
assert resolved.provider_api_key == "env-key"
|
||||
|
||||
|
||||
def test_runtime_supports_provider_alias_to_env_key() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="deepseek-v3.2",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"ark": "ark-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
resolved = resolver.resolve(model_code="", provider_name="volcengine-ark")
|
||||
|
||||
assert resolved.provider_api_key == "ark-key"
|
||||
|
||||
|
||||
def test_runtime_rejects_unsupported_provider() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="qwen3.5-flash", streaming_enabled=True
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "dash-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
resolver.resolve(model_code="", provider_name="unknown-provider")
|
||||
|
||||
|
||||
def test_runtime_config_repr_does_not_expose_api_key() -> None:
|
||||
resolver = AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="qwen3.5-flash", streaming_enabled=True
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "very-secret-key"}),
|
||||
)
|
||||
)
|
||||
|
||||
resolved = resolver.resolve(model_code="", provider_name="dashscope")
|
||||
|
||||
assert "very-secret-key" not in repr(resolved)
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.agent.infrastructure.config.resolver import AgentConfigResolver
|
||||
from core.agent.infrastructure.crewai.runtime import CrewAIRuntime
|
||||
|
||||
|
||||
def test_runtime_emits_text_tool_reasoning_events() -> None:
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
)
|
||||
),
|
||||
model_code="gpt-4o-mini",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
events = runtime.map_events(
|
||||
[
|
||||
{"type": "textMessageContent", "data": {"text": "hello"}},
|
||||
{"type": "toolCallStart", "data": {"tool_name": "weather"}},
|
||||
{"type": "toolCallResult", "data": {"ok": True}},
|
||||
{"type": "reasoningMessageContent", "data": {"text": "thinking"}},
|
||||
{"type": "runFinished", "data": {"status": "completed"}},
|
||||
]
|
||||
)
|
||||
|
||||
assert [event["type"] for event in events] == [
|
||||
"TEXT_MESSAGE_CONTENT",
|
||||
"TOOL_CALL_START",
|
||||
"TOOL_CALL_RESULT",
|
||||
"REASONING_MESSAGE_CONTENT",
|
||||
"RUN_FINISHED",
|
||||
]
|
||||
|
||||
|
||||
def test_runtime_execute_uses_provider_prefixed_litellm_model(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_completion(
|
||||
*, model: str, api_key: str, messages: list[dict[str, object]]
|
||||
):
|
||||
captured["model"] = model
|
||||
captured["api_key"] = api_key
|
||||
captured["messages"] = messages
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "hello",
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.run_completion",
|
||||
_fake_completion,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.extract_usage_and_cost",
|
||||
lambda _response: SimpleNamespace(
|
||||
prompt_tokens=1,
|
||||
completion_tokens=2,
|
||||
total_tokens=3,
|
||||
cost=0.001,
|
||||
),
|
||||
)
|
||||
|
||||
runtime = CrewAIRuntime(
|
||||
resolver=AgentConfigResolver(
|
||||
settings=SimpleNamespace(
|
||||
agent_runtime=SimpleNamespace(
|
||||
default_model_code="",
|
||||
streaming_enabled=True,
|
||||
),
|
||||
llm=SimpleNamespace(provider_keys={"dashscope": "env-api-key"}),
|
||||
)
|
||||
),
|
||||
model_code="qwen3.5-flash",
|
||||
provider_name="dashscope",
|
||||
)
|
||||
|
||||
result = runtime.execute(user_input="hi")
|
||||
|
||||
assert captured["model"] == "dashscope/qwen3.5-flash"
|
||||
assert captured["api_key"] == "env-api-key"
|
||||
assert result["assistant_text"] == "hello"
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.litellm.usage_tracker import extract_usage_and_cost
|
||||
|
||||
|
||||
def test_usage_tracker_extracts_tokens_and_cost(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
||||
lambda completion_response: 0.123,
|
||||
)
|
||||
response = {
|
||||
"usage": {"prompt_tokens": 11, "completion_tokens": 7, "total_tokens": 18},
|
||||
}
|
||||
|
||||
usage = extract_usage_and_cost(response)
|
||||
|
||||
assert usage.prompt_tokens == 11
|
||||
assert usage.completion_tokens == 7
|
||||
assert usage.total_tokens == 18
|
||||
assert usage.cost == 0.123
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("prompt_tokens", "completion_tokens", "expected_cost"),
|
||||
[
|
||||
(128000, 1000, 0.0276),
|
||||
(200000, 1000, 0.168),
|
||||
(300000, 1000, 0.372),
|
||||
],
|
||||
)
|
||||
def test_usage_tracker_falls_back_to_local_qwen35_pricing_when_model_unmapped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
expected_cost: float,
|
||||
) -> None:
|
||||
def _raise_unmapped(*, completion_response): # type: ignore[no-untyped-def]
|
||||
del completion_response
|
||||
raise Exception("This model isn't mapped yet")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.litellm.usage_tracker.completion_cost",
|
||||
_raise_unmapped,
|
||||
)
|
||||
response = {
|
||||
"model": "dashscope/qwen3.5-flash",
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
usage = extract_usage_and_cost(response)
|
||||
|
||||
assert usage.cost == pytest.approx(expected_cost)
|
||||
assert usage.cost_source == "custom_pricing"
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
|
||||
|
||||
class _FakeResumeService:
|
||||
async def resume(self, *, session_id: str, tool_call_id: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "tool_call_id": tool_call_id}
|
||||
|
||||
|
||||
def test_run_agent_task_emits_started_runtime_and_finished_events() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
events: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
|
||||
result = run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert events == ["RUN_STARTED", "RUNTIME_EVENT", "RUN_FINISHED"]
|
||||
|
||||
|
||||
def test_run_agent_task_emits_error_event_on_exception() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
class _BrokenRunService(_FakeRunService):
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
del session_id, user_input
|
||||
raise RuntimeError("boom")
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
events.append(event_type)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_BrokenRunService(),
|
||||
resume_service=_FakeResumeService(),
|
||||
)
|
||||
|
||||
assert events == ["RUN_STARTED", "RUN_ERROR"]
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
|
||||
|
||||
class _FakeRedisClient:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[tuple[str, dict[str, str]]] = []
|
||||
|
||||
def xadd(self, stream: str, fields: dict[str, str]) -> str:
|
||||
self.calls.append((stream, fields))
|
||||
return "1-0"
|
||||
|
||||
async def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
key, start_id = next(iter(streams.items()))
|
||||
if start_id == "$":
|
||||
return [(key, [("11-0", {"event": '{"type":"RUN_STARTED"}'})])]
|
||||
return [(key, [("12-0", {"event": '{"type":"RUN_FINISHED"}'})])]
|
||||
|
||||
|
||||
def test_append_event_writes_json_payload() -> None:
|
||||
client = _FakeRedisClient()
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(client=client, stream_prefix="agent:events")
|
||||
|
||||
stream_id = store.append_event_sync(
|
||||
session_id=session_id, event={"type": "RUN_STARTED"}
|
||||
)
|
||||
|
||||
assert stream_id == "1-0"
|
||||
assert len(client.calls) == 1
|
||||
stream, fields = client.calls[0]
|
||||
assert stream == f"agent:events:{session_id}"
|
||||
assert fields["event"] == '{"type":"RUN_STARTED"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_events_respects_last_event_id() -> None:
|
||||
client = _FakeRedisClient()
|
||||
session_id = uuid4()
|
||||
store = RedisStreamEventStore(client=client, stream_prefix="agent:events")
|
||||
|
||||
from_start = await store.read_events(session_id=session_id, last_event_id=None)
|
||||
from_last = await store.read_events(session_id=session_id, last_event_id="11-0")
|
||||
|
||||
assert from_start[0]["id"] == "11-0"
|
||||
assert from_last[0]["id"] == "12-0"
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_service_rejects_invalid_session_id() -> None:
|
||||
run_service = RunService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await run_service.run(session_id="session-1", user_input="hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_service_requires_pending_tool_call() -> None:
|
||||
resume_service = ResumeService()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await resume_service.resume(session_id="session-1", tool_call_id="call-1")
|
||||
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.domain.state_snapshot import AgentStateSnapshot
|
||||
|
||||
|
||||
def test_state_snapshot_serialization_round_trip() -> None:
|
||||
snapshot = AgentStateSnapshot(status="running", pending_tool_call_id="call-1")
|
||||
|
||||
payload = snapshot.model_dump()
|
||||
|
||||
assert payload["status"] == "running"
|
||||
assert payload["pending_tool_call_id"] == "call-1"
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.domain.tool_correlation import build_tool_result_metadata
|
||||
|
||||
|
||||
def test_tool_correlation_builds_tool_result_metadata() -> None:
|
||||
metadata = build_tool_result_metadata(
|
||||
run_id="run-1",
|
||||
turn_id="turn-1",
|
||||
tool_call_id="call-1",
|
||||
tool_name="weather",
|
||||
storage_bucket="private",
|
||||
storage_path="tool-results/run-1/call-1.json",
|
||||
payload_sha256="sha256",
|
||||
payload_bytes=128,
|
||||
payload_format="json",
|
||||
)
|
||||
|
||||
assert metadata["type"] == "tool_result"
|
||||
assert metadata["tool_call_id"] == "call-1"
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_session_has_state_snapshot_and_status_contract() -> None:
|
||||
model_path = (
|
||||
Path(__file__).resolve().parents[3] / "src" / "models" / "agent_chat_session.py"
|
||||
)
|
||||
content = model_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "state_snapshot" in content
|
||||
assert "AgentChatSessionStatus" in content
|
||||
|
||||
|
||||
def test_message_has_token_cost_and_metadata_contract() -> None:
|
||||
model_path = (
|
||||
Path(__file__).resolve().parents[3] / "src" / "models" / "agent_chat_message.py"
|
||||
)
|
||||
content = model_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "input_tokens" in content
|
||||
assert "output_tokens" in content
|
||||
assert "cost" in content
|
||||
assert '"metadata"' in content
|
||||
|
||||
versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
|
||||
migration_file = versions_dir / "20260305_agent_runtime_closed_loop_contract.py"
|
||||
assert migration_file.exists()
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_social_prefixed_llm_provider_keys_populates_settings(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DASHSCOPE", "dash-key")
|
||||
monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__DEEPSEEK", "deep-key")
|
||||
monkeypatch.setenv("SOCIAL_LLM__PROVIDER_KEYS__ARK", "ark-key")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
keys = {key.lower(): value for key, value in settings.llm.provider_keys.items()}
|
||||
assert keys["dashscope"] == "dash-key"
|
||||
assert keys["deepseek"] == "deep-key"
|
||||
assert keys["ark"] == "ark-key"
|
||||
@@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.service import ensure_session_owner
|
||||
|
||||
|
||||
def test_owner_guard_denies_non_owner() -> None:
|
||||
user = CurrentUser(id=uuid4(), email="self@example.com")
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
ensure_session_owner(owner_id="other-user", current_user=user)
|
||||
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.service import AgentService
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self) -> None:
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
self.deleted_session_id: str | None = None
|
||||
|
||||
async def get_session_owner(self, *, session_id: str) -> str:
|
||||
del session_id
|
||||
return "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
async def create_session_for_user(self, *, user_id: str) -> str:
|
||||
del user_id
|
||||
return "00000000-0000-0000-0000-000000000999"
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
async def delete_session(self, *, session_id: str) -> None:
|
||||
self.deleted_session_id = session_id
|
||||
|
||||
|
||||
class _FakeQueue:
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
del command, dedup_key
|
||||
return "task-1"
|
||||
|
||||
|
||||
class _FailingQueue:
|
||||
async def enqueue(
|
||||
self, *, command: dict[str, object], dedup_key: str | None
|
||||
) -> str:
|
||||
del command, dedup_key
|
||||
raise RuntimeError("enqueue failed")
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
async def read(
|
||||
self, *, session_id: str, last_event_id: str | None
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
]
|
||||
|
||||
|
||||
def _user() -> CurrentUser:
|
||||
return CurrentUser(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
email="user@example.com",
|
||||
)
|
||||
|
||||
|
||||
async def test_resume_idempotency_uses_redis_lock_and_task_key() -> None:
|
||||
service = AgentService(
|
||||
repository=_FakeRepository(),
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
user = _user()
|
||||
|
||||
first = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
current_user=user,
|
||||
)
|
||||
second = await service.enqueue_resume(
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
current_user=user,
|
||||
)
|
||||
|
||||
assert first.task_id == second.task_id
|
||||
|
||||
|
||||
async def test_enqueue_run_without_session_creates_new_session() -> None:
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FakeQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
|
||||
accepted = await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
current_user=_user(),
|
||||
)
|
||||
|
||||
assert accepted.session_id == "00000000-0000-0000-0000-000000000999"
|
||||
assert accepted.created is True
|
||||
assert repository.committed is True
|
||||
|
||||
|
||||
async def test_enqueue_run_keeps_created_session_when_enqueue_fails() -> None:
|
||||
repository = _FakeRepository()
|
||||
service = AgentService(
|
||||
repository=repository,
|
||||
queue=_FailingQueue(),
|
||||
stream=_FakeStream(),
|
||||
)
|
||||
|
||||
try:
|
||||
await service.enqueue_run(
|
||||
session_id=None,
|
||||
prompt="hello",
|
||||
current_user=_user(),
|
||||
)
|
||||
raise AssertionError("expected RuntimeError")
|
||||
except RuntimeError as exc:
|
||||
assert str(exc) == "enqueue failed"
|
||||
|
||||
assert repository.deleted_session_id is None
|
||||
Reference in New Issue
Block a user