feat(agent): complete closed-loop runtime and pricing fallback
This commit is contained in:
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.agent.application.resume_service import ResumeService
|
||||
from core.agent.application.run_service import RunService
|
||||
from core.agent.infrastructure.persistence.session_repository import SessionRepository
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
from core.db import AsyncSessionLocal, engine
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.profile import Profile
|
||||
from models.system_agents import SystemAgents
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_then_resume_persists_messages_and_session_state(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
def _fake_execute(self, *, user_input: str) -> dict[str, object]:
|
||||
del user_input
|
||||
return {
|
||||
"assistant_text": "Mocked answer",
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
"cost": 0.0025,
|
||||
"agui_events": [
|
||||
{"type": "TEXT_MESSAGE_START", "data": {"session_id": "__TBD__"}},
|
||||
{
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"data": {"session_id": "__TBD__", "text": "Mocked answer"},
|
||||
},
|
||||
{"type": "TEXT_MESSAGE_END", "data": {"session_id": "__TBD__"}},
|
||||
],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.runtime.CrewAIRuntime.execute",
|
||||
_fake_execute,
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
existing_owner = await lookup_session.execute(
|
||||
select(AgentChatSession.user_id).limit(1)
|
||||
)
|
||||
owner_id = existing_owner.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No existing session owner available in local database")
|
||||
factory_id = uuid.uuid4()
|
||||
session_uuid = uuid.uuid4()
|
||||
agent_type = f"AAA_TEST_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
llm_row = await seed_session.execute(select(Llm.id).limit(1))
|
||||
llm_id = llm_row.scalar_one_or_none()
|
||||
if llm_id is None:
|
||||
seed_session.add(
|
||||
LlmFactory(
|
||||
id=factory_id,
|
||||
name=f"dashscope-test-{uuid.uuid4().hex[:8]}",
|
||||
request_url="https://dashscope.example",
|
||||
)
|
||||
)
|
||||
llm_id = uuid.uuid4()
|
||||
seed_session.add(
|
||||
Llm(
|
||||
id=llm_id,
|
||||
factory_id=factory_id,
|
||||
model_code=f"qwen3.5-flash-test-{uuid.uuid4().hex[:6]}",
|
||||
)
|
||||
)
|
||||
seed_session.add(
|
||||
SystemAgents(agent_type=agent_type, llm_id=llm_id, status="active")
|
||||
)
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.commit()
|
||||
|
||||
published: list[str] = []
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
|
||||
try:
|
||||
run_result = run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": str(session_uuid),
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
pending_tool_call_id = str(run_result["pending_tool_call_id"])
|
||||
|
||||
run_agent_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"session_id": str(session_uuid),
|
||||
"tool_call_id": pending_tool_call_id,
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=RunService(),
|
||||
resume_service=ResumeService(),
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
async with AsyncSessionLocal() as verify_session:
|
||||
db_session = await verify_session.get(AgentChatSession, session_uuid)
|
||||
assert db_session is not None
|
||||
assert db_session.status == AgentChatSessionStatus.COMPLETED
|
||||
assert db_session.message_count == 4
|
||||
assert db_session.total_tokens == 18
|
||||
assert db_session.total_cost == Decimal("0.002500")
|
||||
assert db_session.state_snapshot == {
|
||||
"status": "completed",
|
||||
"pending_tool_call_id": None,
|
||||
}
|
||||
|
||||
rows = await verify_session.execute(
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == session_uuid)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = list(rows.scalars().all())
|
||||
assert [item.role for item in messages] == [
|
||||
AgentChatMessageRole.USER,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
AgentChatMessageRole.TOOL,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
]
|
||||
assert messages[1].input_tokens == 11
|
||||
assert messages[1].output_tokens == 7
|
||||
assert messages[1].cost == Decimal("0.002500")
|
||||
|
||||
assert "RUN_STARTED" in published
|
||||
assert "RUN_RESUMED" in published
|
||||
assert "TEXT_MESSAGE_CONTENT" in published
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(SystemAgents).where(SystemAgents.agent_type == agent_type)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_session_cascades_to_messages() -> None:
|
||||
session_uuid = uuid.uuid4()
|
||||
await engine.dispose()
|
||||
|
||||
async with AsyncSessionLocal() as lookup_session:
|
||||
owner = await lookup_session.execute(select(Profile.id).limit(1))
|
||||
owner_id = owner.scalar_one_or_none()
|
||||
if owner_id is None:
|
||||
pytest.skip("No profile owner available in local database")
|
||||
|
||||
async with AsyncSessionLocal() as seed_session:
|
||||
seed_session.add(AgentChatSession(id=session_uuid, user_id=owner_id))
|
||||
await seed_session.flush()
|
||||
seed_session.add(
|
||||
AgentChatMessage(
|
||||
session_id=session_uuid,
|
||||
seq=1,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
await seed_session.commit()
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as mutate_session:
|
||||
repo = SessionRepository(mutate_session)
|
||||
affected = await repo.soft_delete_session_with_messages(
|
||||
session_id=session_uuid
|
||||
)
|
||||
await mutate_session.commit()
|
||||
assert affected == 1
|
||||
|
||||
async with AsyncSessionLocal() as verify_session:
|
||||
db_session = await verify_session.get(AgentChatSession, session_uuid)
|
||||
assert db_session is not None
|
||||
assert db_session.deleted_at is not None
|
||||
rows = await verify_session.execute(
|
||||
select(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == session_uuid
|
||||
)
|
||||
)
|
||||
messages = list(rows.scalars().all())
|
||||
assert len(messages) == 1
|
||||
assert messages[0].deleted_at is not None
|
||||
finally:
|
||||
async with AsyncSessionLocal() as cleanup_session:
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == session_uuid
|
||||
)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(AgentChatSession).where(AgentChatSession.id == session_uuid)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.application.session_state_persistence import persist_tool_result_payload
|
||||
from core.agent.domain.tool_correlation import reconstruct_tool_call_result_event
|
||||
from core.agent.infrastructure.queue.tasks import run_agent_task
|
||||
|
||||
|
||||
class _FakeStorage:
|
||||
def __init__(self) -> None:
|
||||
self.writes: dict[str, dict[str, object]] = {}
|
||||
|
||||
async def upload_json(
|
||||
self, *, bucket: str, path: str, payload: dict[str, object]
|
||||
) -> str:
|
||||
self.writes[f"{bucket}/{path}"] = payload
|
||||
return "etag-1"
|
||||
|
||||
|
||||
def test_closed_loop_run_flow_frontend_to_sse() -> None:
|
||||
session_id = "00000000-0000-0000-0000-000000000001"
|
||||
published: list[str] = []
|
||||
|
||||
class _FakeRunService:
|
||||
async def run(self, *, session_id: str, user_input: str) -> dict[str, object]:
|
||||
return {"session_id": session_id, "user_input": user_input}
|
||||
|
||||
def _publish(event_type: str, payload: dict[str, object]) -> None:
|
||||
del payload
|
||||
published.append(event_type)
|
||||
|
||||
result = run_agent_task(
|
||||
{
|
||||
"command": "run",
|
||||
"session_id": session_id,
|
||||
"user_input": "hello",
|
||||
},
|
||||
publish_event=_publish,
|
||||
run_service=_FakeRunService(),
|
||||
)
|
||||
|
||||
assert result["session_id"] == session_id
|
||||
assert published[0] == "RUN_STARTED"
|
||||
assert published[-1] == "RUN_FINISHED"
|
||||
|
||||
|
||||
async def test_tool_result_full_payload_persist_and_reconstruct() -> None:
|
||||
storage = _FakeStorage()
|
||||
payload = {
|
||||
"schema": "ui.v1",
|
||||
"components": [{"type": "card", "title": "Weather"}],
|
||||
}
|
||||
|
||||
metadata = await persist_tool_result_payload(
|
||||
storage=storage,
|
||||
run_id="run-1",
|
||||
turn_id="turn-1",
|
||||
tool_call_id="call-1",
|
||||
tool_name="weather",
|
||||
payload=payload,
|
||||
bucket="private",
|
||||
path="tool-results/run-1/call-1.json",
|
||||
)
|
||||
|
||||
event = reconstruct_tool_call_result_event(metadata=metadata, payload=payload)
|
||||
|
||||
assert metadata["type"] == "tool_result"
|
||||
assert metadata["storage_bucket"] == "private"
|
||||
assert event["type"] == "TOOL_CALL_RESULT"
|
||||
assert event["data"]["schema"] == "ui.v1"
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class _FakeAgentService:
|
||||
def __init__(self) -> None:
|
||||
self._stream_called = False
|
||||
|
||||
async def enqueue_run(
|
||||
self, *, session_id: str | None, prompt: str, current_user: CurrentUser
|
||||
):
|
||||
del prompt, current_user
|
||||
resolved_session = session_id or "auto-created-session"
|
||||
return SimpleNamespace(
|
||||
task_id="task-run-1",
|
||||
session_id=resolved_session,
|
||||
created=session_id is None,
|
||||
)
|
||||
|
||||
async def enqueue_resume(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
current_user: CurrentUser,
|
||||
):
|
||||
del tool_call_id, current_user
|
||||
return SimpleNamespace(
|
||||
task_id="task-resume-1", session_id=session_id, created=False
|
||||
)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
current_user: CurrentUser,
|
||||
) -> list[dict[str, object]]:
|
||||
del session_id, current_user
|
||||
if self._stream_called:
|
||||
return []
|
||||
self._stream_called = True
|
||||
return [
|
||||
{"id": "2-0", "event": {"type": "RUN_STARTED"}, "cursor": last_event_id}
|
||||
]
|
||||
|
||||
|
||||
def test_run_requires_auth_and_returns_202_task_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
unauthorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
)
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
authorized = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"session_id": "session-1", "prompt": "hello"},
|
||||
)
|
||||
assert authorized.status_code == 202
|
||||
assert authorized.json()["task_id"] == "task-run-1"
|
||||
assert authorized.json()["created"] is False
|
||||
|
||||
first_chat = client.post(
|
||||
"/api/v1/agent/runs",
|
||||
json={"prompt": "hello"},
|
||||
)
|
||||
assert first_chat.status_code == 202
|
||||
assert first_chat.json()["session_id"] == "auto-created-session"
|
||||
assert first_chat.json()["created"] is True
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_reads_from_last_event_id() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: _FakeAgentService()
|
||||
app.dependency_overrides[get_current_user] = lambda: CurrentUser(
|
||||
id=uuid4(), email="user@example.com"
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/agent/runs/session-1/events?idle_limit=1",
|
||||
headers={"Last-Event-ID": "1-0"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert "id: 2-0" in response.text
|
||||
assert "event: RUN_STARTED" in response.text
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
Reference in New Issue
Block a user