# File: social-app/backend/tests/integration/v1/agent/test_sse_flow_live.py
# (Python, 305 lines, 11 KiB)

from __future__ import annotations
import base64
from pathlib import Path
from uuid import UUID, uuid4
import httpx
import pytest
from sqlalchemy import select
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from schemas.enums import AgentChatMessageRole
# Local port of the backend API that the live tests below talk to.
LIVE_API_PORT = 5775
# Root URL prefixed to every ``/api/v1/...`` request in this module.
BASE_URL = f"http://localhost:{LIVE_API_PORT}"
FIXTURE_IMAGE_PATH = (
Path(__file__).resolve().parents[3] / "fixtures" / "images" / "calendar_text_cn.png"
)
def _require_test_phone() -> str:
    """Return the live-test phone number from ``config.test.phone``.

    A missing or empty value aborts the current test with ``pytest.fail``
    (an explicit failure, not a skip), because the live login needs it.
    """
    configured_phone = config.test.phone
    if configured_phone:
        return configured_phone
    pytest.fail("SOCIAL_TEST__PHONE is required for live integration tests")
def _summarize_response_text(text: str, limit: int = 200) -> str:
    """Flatten all whitespace in *text* and cap it at *limit* characters.

    Keeps assertion messages on one line. A trailing ``...`` marks that the
    text was truncated.
    """
    flattened = " ".join(text.split())
    if len(flattened) > limit:
        return flattened[:limit] + "..."
    return flattened


async def _live_access_token(client: httpx.AsyncClient) -> str:
    """Log in with the configured test phone and return the access token.

    POSTs ``{"phone": ..., "token": ...}`` to ``/api/v1/auth/phone-session``.
    The phone is given a leading ``+`` if it lacks one. The verification code
    is ``config.test.code``, falling back to ``"000000"`` when that is empty.

    Args:
        client: HTTP client used for the login request.

    Returns:
        The non-empty ``access_token`` string from the JSON response.

    Raises:
        AssertionError: if the response status is not 200, or the body is not
            a JSON object with a non-empty string ``access_token``.
    """
    phone = _require_test_phone()
    if not phone.startswith("+"):
        phone = f"+{phone}"
    code = config.test.code or "000000"
    response = await client.post(
        f"{BASE_URL}/api/v1/auth/phone-session",
        json={"phone": phone, "token": code},
    )
    assert response.status_code == 200, (
        f"live login failed: status={response.status_code}, "
        f"response={_summarize_response_text(response.text)!r}"
    )
    payload = response.json()
    # Guard the shape so a non-object body yields a readable assertion
    # message instead of an AttributeError on ``.get``.
    token = payload.get("access_token") if isinstance(payload, dict) else None
    assert isinstance(token, str) and token, (
        "live login returned no access_token: "
        f"response={_summarize_response_text(response.text)!r}"
    )
    return token
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_sse_closed_loop_live() -> None:
    """Run a text-only agent turn end to end and check it was persisted.

    Steps:
    1. Log in.
    2. POST a plain-text user message to ``/api/v1/agent/runs`` and expect 202.
    3. Stream ``/api/v1/agent/runs/{threadId}/events`` and collect the
       ``event:`` names until a terminal event.
    4. Load the ``AgentChatSession`` row and its ``AgentChatMessage`` rows
       for the thread.

    The test is skipped unless ``config.runtime.environment`` is ``dev`` or
    ``test``.
    """
    if config.runtime.environment not in {"dev", "test"}:
        pytest.skip("live integration tests require dev or test environment")
    async with httpx.AsyncClient(timeout=30.0) as client:
        token = await _live_access_token(client)
        headers = {"Authorization": f"Bearer {token}"}
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": str(uuid4()),
                "runId": "run-live-1",
                "state": {},
                "messages": [
                    # Prompt asks the agent to introduce itself in one sentence.
                    {"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {"runtime_mode": "chat"},
            },
        )
        assert run_resp.status_code == 202, (
            f"run not accepted: status={run_resp.status_code}, "
            f"response={run_resp.text[:200]!r}"
        )
        # Use the thread id echoed back by the server for all follow-up calls.
        thread_id = str(run_resp.json()["threadId"])
        assert thread_id
        events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
        event_names: list[str] = []
        async with client.stream(
            "GET", events_url, headers=headers, timeout=20.0
        ) as sse_resp:
            assert sse_resp.status_code == 200
            assert sse_resp.headers.get("content-type", "").startswith(
                "text/event-stream"
            )
            async for line in sse_resp.aiter_lines():
                if not line.startswith("event:"):
                    continue
                event_name = line.split(":", 1)[1].strip()
                event_names.append(event_name)
                # Stop at the terminal event, as the other live tests in this
                # module do. Otherwise reading only ends when the server closes
                # the stream or a read times out.
                if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
                    break
        assert "RUN_STARTED" in event_names, (
            f"missing RUN_STARTED, got: {event_names}"
        )
        assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names, (
            f"no terminal event, got: {event_names}"
        )
    session_id = UUID(thread_id)
    async with AsyncSessionLocal() as session:
        session_row = await session.get(AgentChatSession, session_id)
        assert session_row is not None
        assert session_row.message_count >= 1
        assert session_row.total_tokens >= 0
        assert session_row.total_cost >= 0
        rows = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == session_id
            )
        )
        persisted = rows.scalars().all()
        assert len(persisted) >= 1, f"no messages persisted for {thread_id}"
def _assert_user_attachments_present(metadata: object) -> None:
    """Assert *metadata* carries a non-empty ``user_message_attachments`` list.

    This checks the same shape on both the history snapshot message and the
    persisted ``metadata_json``. The first attachment must be a dict whose
    ``path`` is a string.
    """
    assert isinstance(metadata, dict), f"metadata is not a dict: {metadata!r}"
    attachments = metadata.get("user_message_attachments")
    assert isinstance(attachments, list), (
        f"user_message_attachments missing or not a list: {metadata!r}"
    )
    assert attachments, "user_message_attachments is empty"
    first = attachments[0]
    assert isinstance(first, dict), f"attachment is not a dict: {first!r}"
    assert isinstance(first.get("path"), str), (
        f"attachment has no string path: {first!r}"
    )


@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_runs_events_history_live_with_image_input() -> None:
    """Send a text + base64 PNG user message and verify attachment metadata.

    The run is started with a client-chosen ``threadId``. Its SSE events are
    read until ``RUN_FINISHED``/``RUN_ERROR``. The test then checks that both
    ``/api/v1/agent/history`` (a ``STATE_SNAPSHOT`` with scope
    ``history_day``) and the persisted ``AgentChatMessage`` user row expose
    ``user_message_attachments`` metadata.

    The test is skipped unless ``config.runtime.environment`` is ``dev`` or
    ``test``.
    """
    if config.runtime.environment not in {"dev", "test"}:
        pytest.skip("live integration tests require dev or test environment")
    image_data = base64.b64encode(FIXTURE_IMAGE_PATH.read_bytes()).decode("ascii")
    async with httpx.AsyncClient(timeout=30.0) as client:
        token = await _live_access_token(client)
        headers = {"Authorization": f"Bearer {token}"}
        thread_id = str(uuid4())
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": thread_id,
                "runId": "run-live-image-1",
                "state": {},
                "messages": [
                    {
                        "id": "u1",
                        "role": "user",
                        "content": [
                            # Text part asks the agent to describe the image.
                            {"type": "text", "text": "请描述图片里的内容"},
                            {
                                "type": "binary",
                                "data": image_data,
                                "mimeType": "image/png",
                            },
                        ],
                    }
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {"runtime_mode": "chat"},
            },
        )
        assert run_resp.status_code == 202, (
            f"run not accepted: status={run_resp.status_code}, "
            f"response={run_resp.text[:200]!r}"
        )
        events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
        event_names: list[str] = []
        async with client.stream(
            "GET", events_url, headers=headers, timeout=90.0
        ) as sse_resp:
            assert sse_resp.status_code == 200
            assert sse_resp.headers.get("content-type", "").startswith(
                "text/event-stream"
            )
            async for line in sse_resp.aiter_lines():
                if not line.startswith("event:"):
                    continue
                event_name = line.split(":", 1)[1].strip()
                event_names.append(event_name)
                if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
                    break
        assert "RUN_STARTED" in event_names, (
            f"missing RUN_STARTED, got: {event_names}"
        )
        assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names, (
            f"no terminal event, got: {event_names}"
        )

        history_resp = await client.get(
            f"{BASE_URL}/api/v1/agent/history",
            headers=headers,
            params={"threadId": thread_id},
        )
        assert history_resp.status_code == 200
        history = history_resp.json()
        assert history.get("type") == "STATE_SNAPSHOT"
        snapshot = history.get("snapshot", {})
        assert snapshot.get("scope") == "history_day"
        user_messages = [
            item
            for item in snapshot.get("messages", [])
            if isinstance(item, dict) and item.get("role") == "user"
        ]
        assert user_messages, "history snapshot has no user messages"
        _assert_user_attachments_present(user_messages[0].get("metadata"))

    session_id = UUID(thread_id)
    async with AsyncSessionLocal() as session:
        session_row = await session.get(AgentChatSession, session_id)
        assert session_row is not None
        assert session_row.message_count >= 1
        assert session_row.total_tokens >= 0
        assert session_row.total_cost >= 0
        rows = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == session_id
            )
        )
        all_messages = rows.scalars().all()
        assert all_messages, f"no messages persisted for {thread_id}"
        # ``role`` may be an enum member or a plain string; compare its
        # string value either way.
        user_rows = [
            row
            for row in all_messages
            if str(getattr(row.role, "value", row.role)) == "user"
        ]
        assert user_rows, "no persisted user message"
        _assert_user_attachments_present(user_rows[0].metadata_json or {})
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_tool_call_result_persisted_live() -> None:
    """Check that a tool-using run persists at least one ``role='tool'`` row.

    The user message asks about tomorrow's schedule. The run's SSE events are
    read (filtered by ``runId``) until ``RUN_FINISHED``/``RUN_ERROR``.
    ``AgentChatMessage`` rows with role ``AgentChatMessageRole.TOOL`` are then
    queried for the thread. They are only required when the run finished
    successfully; after ``RUN_ERROR`` their count is not asserted.

    The test is skipped unless ``config.runtime.environment`` is ``dev`` or
    ``test``.
    """
    if config.runtime.environment not in {"dev", "test"}:
        pytest.skip("live integration tests require dev or test environment")
    thread_id = str(uuid4())
    # Single source for the run id: sent in the POST body and used to filter
    # the event stream, so the two cannot drift apart.
    run_id = "run-tool-verify-1"
    async with httpx.AsyncClient(timeout=30.0) as client:
        token = await _live_access_token(client)
        headers = {"Authorization": f"Bearer {token}"}
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": thread_id,
                "runId": run_id,
                "state": {},
                "messages": [
                    {
                        "id": "u1",
                        "role": "user",
                        "content": "帮我查一下明天有哪些日程安排",
                    }
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {"runtime_mode": "chat"},
            },
        )
        assert run_resp.status_code == 202, (
            f"run not accepted: status={run_resp.status_code}, "
            f"response={run_resp.text[:200]!r}"
        )
        accepted = run_resp.json()
        assert str(accepted["threadId"]) == thread_id
        events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
        event_names: list[str] = []
        async with client.stream(
            "GET",
            events_url,
            headers=headers,
            params={"runId": run_id},
            timeout=90.0,
        ) as sse_resp:
            assert sse_resp.status_code == 200
            assert sse_resp.headers.get("content-type", "").startswith(
                "text/event-stream"
            )
            async for line in sse_resp.aiter_lines():
                if not line.startswith("event:"):
                    continue
                event_name = line.split(":", 1)[1].strip()
                event_names.append(event_name)
                if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
                    break
    assert "RUN_STARTED" in event_names, (
        f"missing RUN_STARTED, got: {event_names}"
    )
    finished_ok = "RUN_FINISHED" in event_names
    finished_err = "RUN_ERROR" in event_names
    assert finished_ok or finished_err, (
        f"no terminal event, got: {event_names}"
    )
    async with AsyncSessionLocal() as session:
        rows = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == UUID(thread_id),
                AgentChatMessage.role == AgentChatMessageRole.TOOL,
            )
        )
        tool_messages = rows.scalars().all()
    # Tool output is only required when the run completed normally.
    if finished_ok:
        assert len(tool_messages) >= 1, (
            f"expected >=1 role='tool' message but found {len(tool_messages)}. "
            f"SSE events: {event_names}"
        )