Files
social-app/backend/tests/integration/v1/agent/test_sse_flow_live.py
T
zl-q 120df903d2 feat: AG-UI 协议对齐与路由导航功能
- 前端: 添加 SSE 流式支持、stateSnapshot 事件、路由导航工具
- 前端: 实现工具调用审批流程,支持 pending 状态展示
- 后端: Agent 状态管理与会话持久化相关重构
- 文档: 新增 agent-agui-full-alignance 设计文档
- 测试: 补充相关单元测试和集成测试
2026-03-07 17:30:20 +08:00

100 lines
3.4 KiB
Python

from __future__ import annotations
import os
from datetime import datetime, timedelta, timezone
from uuid import UUID, uuid4
import httpx
import jwt
import pytest
from sqlalchemy import select
from core.config import config
from core.db.session import AsyncSessionLocal
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.profile import Profile
# Base URL of the backend under test. Trailing slashes are stripped so that
# paths built as f"{BASE_URL}/api/..." never contain a double slash.
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775").rstrip("/")
async def _owner_id() -> UUID:
    """Return the id of some existing ``Profile`` to act as the run owner.

    No ordering is applied, so whichever row the database yields first is
    used. Skips the calling test when no ``Profile`` row exists.
    """
    first_profile = select(Profile.id).limit(1)
    async with AsyncSessionLocal() as session:
        result = await session.execute(first_profile)
        found = result.scalar_one_or_none()
    if found is None:
        pytest.skip("profile owner not found")
    return found
def _jwt_for(user_id: UUID, *, ttl: timedelta = timedelta(minutes=30)) -> str:
    """Mint an HS256 bearer token for ``user_id`` signed with the configured secret.

    The token carries ``role``/``aud`` of ``"authenticated"`` and an issuer of
    ``<config.supabase.public_url>/auth/v1``, signed with
    ``config.supabase.jwt_secret``.

    Args:
        user_id: Value placed (as a string) in the ``sub`` claim.
        ttl: Token lifetime; ``exp`` is exactly ``iat + ttl``. Defaults to
            30 minutes.

    Returns:
        The encoded JWT string produced by ``jwt.encode``.

    Skips the calling test (via ``pytest.skip``) when the JWT secret or the
    public URL is not configured.
    """
    secret = config.supabase.jwt_secret
    if not secret:
        pytest.skip("JWT secret not configured")
    public_url = config.supabase.public_url
    if not public_url:
        # Without this guard an unset URL would fail with an AttributeError on
        # ``.rstrip``; skip instead, consistent with the secret check above.
        pytest.skip("Supabase public URL not configured")
    issuer = f"{public_url.rstrip('/')}/auth/v1"
    # Take a single timestamp so iat and exp are exactly ``ttl`` apart.
    now = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "role": "authenticated",
        "aud": "authenticated",
        "iss": issuer,
        "iat": now,
        "exp": now + ttl,
    }
    return jwt.encode(payload, secret, algorithm="HS256")
@pytest.mark.asyncio
@pytest.mark.live
async def test_agent_sse_closed_loop_live() -> None:
    """End-to-end check: start a run, consume its SSE stream, verify persistence.

    Opt-in only: runs when ``AGENT_LIVE_INTEGRATION=1`` and a backend is
    reachable at ``BASE_URL``. The test:

    1. POSTs a single user message to ``/api/v1/agent/runs`` and expects 202.
    2. Streams ``/api/v1/agent/runs/{threadId}/events`` until the server closes
       it, collecting the names from SSE ``event:`` lines.
    3. Asserts the stream contained ``RUN_STARTED`` and a terminal
       ``RUN_FINISHED`` or ``RUN_ERROR`` event.
    4. Checks that an ``AgentChatSession`` row keyed by the thread id and at
       least one ``AgentChatMessage`` row for it exist in the database.
    """
    if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
        pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
    owner_id = await _owner_id()
    token = _jwt_for(owner_id)
    headers = {"Authorization": f"Bearer {token}"}

    async with httpx.AsyncClient(timeout=30.0) as client:
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": str(uuid4()),
                "runId": "run-live-1",
                "state": {},
                "messages": [
                    {"id": "u1", "role": "user", "content": "请用一句话介绍你自己"}
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {},
            },
        )
        # Include the body so a rejected run explains itself in the failure output.
        assert run_resp.status_code == 202, run_resp.text
        accepted = run_resp.json()
        thread_id = str(accepted["threadId"])
        assert thread_id
        # Parse once. This fails fast on a non-UUID thread id, and the parsed
        # value is reused for the DB lookups below.
        thread_uuid = UUID(thread_id)

        events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events"
        event_names: list[str] = []
        # The stream is consumed until the server closes it. The per-request
        # timeout limits each wait for data, not the total stream duration.
        async with client.stream("GET", events_url, headers=headers, timeout=20.0) as sse_resp:
            assert sse_resp.status_code == 200
            content_type = sse_resp.headers.get("content-type", "")
            assert content_type.startswith("text/event-stream"), content_type
            async for line in sse_resp.aiter_lines():
                if line.startswith("event:"):
                    event_names.append(line.split(":", 1)[1].strip())

    # Show the observed event sequence when either protocol check fails.
    assert "RUN_STARTED" in event_names, event_names
    assert "RUN_FINISHED" in event_names or "RUN_ERROR" in event_names, event_names

    async with AsyncSessionLocal() as session:
        session_row = await session.get(AgentChatSession, thread_uuid)
        assert session_row is not None, f"no AgentChatSession for thread {thread_id}"
        assert session_row.message_count >= 1
        assert session_row.total_tokens >= 0
        assert session_row.total_cost >= 0
        rows = await session.execute(
            select(AgentChatMessage).where(AgentChatMessage.session_id == thread_uuid)
        )
        # ``ScalarResult.all()`` already returns a list; no extra copy needed.
        assert len(rows.scalars().all()) >= 1