# social-app/backend/tests/integration/v1/agent/test_sse_flow_live.py
from __future__ import annotations
from pathlib import Path
from uuid import UUID, uuid4
import httpx
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.config.settings import config
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from schemas.enums import AgentChatMessageRole
# Root URL of the locally running backend that these live tests talk to.
BASE_URL = "http://localhost:5775"
# PNG fixture uploaded by the image-input test; parents[3] walks up from this
# file's directory to the tests root before descending into fixtures/images.
FIXTURE_IMAGE_PATH = Path(__file__).resolve().parents[3].joinpath(
    "fixtures", "images", "calendar_text_cn.png"
)
def _make_session() -> AsyncSession:
    """Open a standalone ``AsyncSession`` bound to ``config.database_url``.

    Each call builds its own async engine that nobody disposes afterwards.
    ``NullPool`` therefore keeps no connections pooled. Closing the session
    (e.g. on ``async with`` exit) closes the underlying DB connection, so
    none is left parked in an orphaned pool after the test's event loop has
    shut down.

    Returns:
        A new ``AsyncSession`` with ``expire_on_commit=False`` and
        ``autoflush=False``; callers use it as ``async with _make_session() as s``.
    """
    from sqlalchemy.pool import NullPool

    engine = create_async_engine(
        config.database_url,
        poolclass=NullPool,
        pool_pre_ping=True,
    )
    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    return session_factory()
def _require_test_phone() -> str:
    """Return the configured live-test phone number.

    Fails the current test (``pytest.fail`` raises) when ``config.test.phone``
    is empty or unset.
    """
    configured_phone = config.test.phone
    if configured_phone:
        return configured_phone
    pytest.fail("SOCIAL_TEST__PHONE is required for live integration tests")
async def _live_access_token(client: httpx.AsyncClient) -> str:
    """Log in via ``/api/v1/auth/phone-session`` and return the access token.

    The configured test phone gets a leading ``+`` if it lacks one. The
    verification code comes from ``config.test.code``, falling back to
    ``"000000"``.

    Args:
        client: Open HTTP client used for the login request.

    Returns:
        The non-empty ``access_token`` string from the JSON response.
    """
    raw_phone = _require_test_phone()
    normalized_phone = raw_phone if raw_phone.startswith("+") else f"+{raw_phone}"
    login_resp = await client.post(
        f"{BASE_URL}/api/v1/auth/phone-session",
        json={"phone": normalized_phone, "token": config.test.code or "000000"},
    )
    # One-line, length-capped preview of the body for the failure message.
    flat_body = login_resp.text.strip().replace("\n", " ")
    preview = flat_body if len(flat_body) <= 200 else flat_body[:200] + "..."
    assert login_resp.status_code == 200, (
        f"live login failed: status={login_resp.status_code}, response={preview!r}"
    )
    access_token = login_resp.json().get("access_token")
    assert isinstance(access_token, str) and access_token
    return access_token
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_sse_closed_loop_live() -> None:
    """Drive a plain-text chat run end to end: submit, stream SSE, check the DB.

    Submits a run, reads its SSE stream until a terminal event
    (``RUN_FINISHED`` or ``RUN_ERROR``) arrives, then checks that the
    ``AgentChatSession`` row and at least one ``AgentChatMessage`` row exist
    for the thread.
    """
    if config.runtime.environment not in {"dev", "test"}:
        pytest.skip("live integration tests require dev or test environment")
    async with httpx.AsyncClient(timeout=30.0) as client:
        token = await _live_access_token(client)
        headers = {"Authorization": f"Bearer {token}"}
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": str(uuid4()),
                "runId": "run-live-1",
                "state": {},
                "messages": [
                    # Prompt: "Please introduce yourself in one sentence."
                    {"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {"runtime_mode": "chat"},
            },
        )
        assert run_resp.status_code == 202, (
            f"run submit failed: {run_resp.status_code} {run_resp.text[:200]}"
        )
        accepted = run_resp.json()
        thread_id = str(accepted["threadId"])
        assert thread_id

        events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events?runId=run-live-1"
        event_names: list[str] = []
        # 90 s per-read timeout, matching the other live SSE tests: model
        # latency between events can exceed the previous 20 s.
        async with client.stream(
            "GET", events_url, headers=headers, timeout=90.0
        ) as sse_resp:
            assert sse_resp.status_code == 200
            assert sse_resp.headers.get("content-type", "").startswith(
                "text/event-stream"
            )
            async for line in sse_resp.aiter_lines():
                if line.startswith("event:"):
                    event_name = line.split(":", 1)[1].strip()
                    event_names.append(event_name)
                    # Stop at the terminal event rather than relying on the
                    # server to close the stream (same as the sibling tests).
                    if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
                        break
        assert "RUN_STARTED" in event_names, (
            f"missing RUN_STARTED, got: {event_names}"
        )
        assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names, (
            f"no terminal event, got: {event_names}"
        )

    async with _make_session() as session:
        session_row = await session.get(AgentChatSession, UUID(thread_id))
        assert session_row is not None, f"no AgentChatSession row for {thread_id}"
        assert session_row.message_count >= 1
        assert session_row.total_tokens >= 0
        assert session_row.total_cost >= 0
        rows = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == UUID(thread_id)
            )
        )
        assert rows.scalars().all(), f"no AgentChatMessage rows for {thread_id}"
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_runs_events_history_live_with_image_input() -> None:
    """Run a chat turn with an uploaded image, then check history and the DB.

    Uploads the fixture PNG as an attachment and submits a run whose user
    message combines a text part with a ``binary`` part pointing at the
    uploaded URL. It streams SSE events until a terminal event, then checks:

    * ``/api/v1/agent/history`` has ``scope == "history_day"`` and the first
      user message carries a non-empty ``attachments`` list whose first entry
      has a string ``url``;
    * the stored user ``AgentChatMessage`` row keeps the attachment under
      ``metadata_json["user_message_attachments"]`` with a string ``path``.
    """
    if config.runtime.environment not in {"dev", "test"}:
        pytest.skip("live integration tests require dev or test environment")
    async with httpx.AsyncClient(timeout=30.0) as client:
        token = await _live_access_token(client)
        headers = {"Authorization": f"Bearer {token}"}
        thread_id = str(uuid4())

        upload_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/attachments",
            headers=headers,
            data={"threadId": thread_id},
            files={
                "file": (
                    "calendar_text_cn.png",
                    FIXTURE_IMAGE_PATH.read_bytes(),
                    "image/png",
                )
            },
        )
        assert upload_resp.status_code == 200, (
            f"upload failed: {upload_resp.status_code} {upload_resp.text[:200]}"
        )
        attachment = upload_resp.json()["attachment"]
        image_url = attachment["url"]
        assert isinstance(image_url, str) and image_url

        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": thread_id,
                "runId": "run-live-image-1",
                "state": {},
                "messages": [
                    {
                        "id": "u1",
                        "role": "user",
                        "content": [
                            # Prompt: "Please describe what is in the image."
                            {"type": "text", "text": "请描述图片里的内容"},
                            {
                                "type": "binary",
                                "url": image_url,
                                "mimeType": "image/png",
                            },
                        ],
                    }
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {"runtime_mode": "chat"},
            },
        )
        assert run_resp.status_code == 202, (
            f"run submit failed: {run_resp.status_code} {run_resp.text[:200]}"
        )

        events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events?runId=run-live-image-1"
        event_names: list[str] = []
        async with client.stream(
            "GET", events_url, headers=headers, timeout=90.0
        ) as sse_resp:
            assert sse_resp.status_code == 200
            assert sse_resp.headers.get("content-type", "").startswith(
                "text/event-stream"
            )
            async for line in sse_resp.aiter_lines():
                if line.startswith("event:"):
                    event_name = line.split(":", 1)[1].strip()
                    event_names.append(event_name)
                    if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
                        break
        assert "RUN_STARTED" in event_names, (
            f"missing RUN_STARTED, got: {event_names}"
        )
        assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names, (
            f"no terminal event, got: {event_names}"
        )

        history_resp = await client.get(
            f"{BASE_URL}/api/v1/agent/history",
            headers=headers,
            params={"threadId": thread_id},
        )
        assert history_resp.status_code == 200, (
            f"history failed: {history_resp.status_code} {history_resp.text[:200]}"
        )
        history = history_resp.json()

    assert history.get("scope") == "history_day"
    history_messages = history.get("messages", [])
    history_user_messages = [
        item
        for item in history_messages
        if isinstance(item, dict) and item.get("role") == "user"
    ]
    assert history_user_messages, f"no user message in history: {history_messages!r}"

    # Attachments as exposed by the HTTP history API (keyed by ``url``).
    history_attachments = history_user_messages[0].get("attachments")
    assert isinstance(history_attachments, list)
    assert history_attachments
    assert isinstance(history_attachments[0], dict)
    assert isinstance(history_attachments[0].get("url"), str)

    async with _make_session() as session:
        session_row = await session.get(AgentChatSession, UUID(thread_id))
        assert session_row is not None, f"no AgentChatSession row for {thread_id}"
        assert session_row.message_count >= 1
        assert session_row.total_tokens >= 0
        assert session_row.total_cost >= 0
        rows = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == UUID(thread_id)
            )
        )
        all_messages = rows.scalars().all()
        assert all_messages
        # ``role`` may be an enum member or a plain string; compare on the
        # underlying value either way.
        user_rows = [
            row
            for row in all_messages
            if str(getattr(row.role, "value", row.role)) == "user"
        ]
        assert user_rows, f"no user-role AgentChatMessage row for {thread_id}"

        # Attachments as persisted in the DB row's metadata (keyed by ``path``).
        metadata = user_rows[0].metadata_json or {}
        stored_attachments = metadata.get("user_message_attachments")
        assert isinstance(stored_attachments, list)
        assert stored_attachments
        assert isinstance(stored_attachments[0], dict)
        assert isinstance(stored_attachments[0].get("path"), str)
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_tool_call_result_persisted_live() -> None:
    """Check that a run which should call a tool persists ``role='tool'`` rows.

    The prompt asks for tomorrow's schedule. If the SSE stream ends with
    ``RUN_FINISHED``, at least one ``AgentChatMessageRole.TOOL`` message must
    be stored for the thread. A run ending in ``RUN_ERROR`` only has to
    produce the terminal event.
    """
    if config.runtime.environment not in {"dev", "test"}:
        pytest.skip("live integration tests require dev or test environment")
    thread_id = str(uuid4())
    run_id = "run-tool-verify-1"
    terminal_events = {"RUN_FINISHED", "RUN_ERROR"}
    payload = {
        "threadId": thread_id,
        "runId": run_id,
        "state": {},
        "messages": [
            # Prompt: "Please check what I have scheduled for tomorrow."
            {"id": "u1", "role": "user", "content": "帮我查一下明天有哪些日程安排"},
        ],
        "tools": [],
        "context": [],
        "forwardedProps": {"runtime_mode": "chat"},
    }
    seen_events: list[str] = []

    async with httpx.AsyncClient(timeout=30.0) as client:
        bearer = await _live_access_token(client)
        auth_headers = {"Authorization": f"Bearer {bearer}"}
        submit_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs", headers=auth_headers, json=payload
        )
        assert submit_resp.status_code == 202
        assert str(submit_resp.json()["threadId"]) == thread_id

        stream_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events?runId={run_id}"
        async with client.stream(
            "GET", stream_url, headers=auth_headers, timeout=90.0
        ) as stream_resp:
            assert stream_resp.status_code == 200
            async for raw_line in stream_resp.aiter_lines():
                if not raw_line.startswith("event:"):
                    continue
                current = raw_line.partition(":")[2].strip()
                seen_events.append(current)
                if current in terminal_events:
                    break

    assert "RUN_STARTED" in seen_events, f"missing RUN_STARTED, got: {seen_events}"
    finished_ok = "RUN_FINISHED" in seen_events
    assert finished_ok or "RUN_ERROR" in seen_events, (
        f"no terminal event, got: {seen_events}"
    )

    async with _make_session() as session:
        result = await session.execute(
            select(AgentChatMessage).where(
                AgentChatMessage.session_id == UUID(thread_id),
                AgentChatMessage.role == AgentChatMessageRole.TOOL,
            )
        )
        tool_messages = list(result.scalars().all())
        if finished_ok:
            assert tool_messages, (
                f"expected >=1 role='tool' message but found {len(tool_messages)}. "
                f"SSE events: {seen_events}"
            )