Files
social-app/backend/tests/integration/test_cli_tool_flow.py
T

296 lines
10 KiB
Python
Raw Normal View History

from __future__ import annotations
import json
import os
import time
from pathlib import Path
from uuid import uuid4
import httpx
import jwt
import pytest
# Base URL that every request in this module is sent to; override it with the
# AGENT_LIVE_BASE_URL environment variable.
# NOTE(review): this lookup happens before _load_env() is called further down,
# so a value defined only in .env is not picked up here - confirm that is intended.
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
def _load_env() -> None:
    """Copy ``KEY=VALUE`` pairs from the project ``.env`` file into ``os.environ``.

    The file is looked up three directory levels above the directory that
    contains this module. Blank lines, lines starting with ``#`` and lines
    without ``=`` are ignored. Surrounding whitespace and quote characters are
    stripped from each value. Variables that already exist in the process
    environment are never overwritten. A missing ``.env`` file is not an error.
    """
    env_path = Path(__file__).resolve().parents[3] / ".env"
    # is_file() rather than exists(): a directory that happens to be called
    # ".env" must not reach read_text().
    if not env_path.is_file():
        return

    # Decode explicitly as UTF-8 so parsing does not depend on the host locale.
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        value = value.strip().strip('"').strip("'")
        if key:
            # Real environment variables take precedence over .env entries.
            os.environ.setdefault(key, value)
# Load .env at import time so the environment lookups performed by the helpers
# defined below can see its values.
_load_env()
def _get_jwt_secret() -> str | None:
    """Return the JWT signing secret configured in the environment, if any.

    ``SOCIAL_SUPABASE__JWT_SECRET``, ``SUPABASE_JWT_SECRET`` and ``JWT_SECRET``
    are consulted in that order; the first non-empty value wins.
    """
    secret = None
    for name in ("SOCIAL_SUPABASE__JWT_SECRET", "SUPABASE_JWT_SECRET", "JWT_SECRET"):
        secret = os.getenv(name)
        if secret:
            return secret
    # Nothing non-empty was found: hand back whatever the last lookup produced.
    return secret
def _get_supabase_url() -> str:
    """Return the Supabase URL from the environment.

    ``SOCIAL_SUPABASE__URL`` is preferred over ``SUPABASE_URL``; when neither
    holds a non-empty value, ``http://localhost:54321`` is returned.
    """
    for name in ("SOCIAL_SUPABASE__URL", "SUPABASE_URL"):
        configured = os.getenv(name)
        if configured:
            return configured
    return "http://localhost:54321"
def _get_test_user_id() -> str | None:
    """Return the value of the ``TEST_USER_ID`` environment variable, or ``None`` if unset."""
    return os.environ.get("TEST_USER_ID")
def _create_test_jwt(user_id: str) -> str:
    """Sign an HS256 token for ``user_id`` that expires one hour from now.

    The token carries ``sub`` (the user id), ``role`` and ``aud`` (both
    ``"authenticated"``), ``iss`` (the configured Supabase URL), ``iat`` and
    ``exp``.

    Raises:
        RuntimeError: if no JWT secret is configured in the environment.
    """
    secret = _get_jwt_secret()
    if not secret:
        raise RuntimeError("JWT_SECRET not found in environment")

    issuer = _get_supabase_url()
    issued_at = int(time.time())
    # Claim order is kept stable so the encoded payload does not change.
    claims = dict(
        sub=user_id,
        role="authenticated",
        aud="authenticated",
        iss=issuer,
        iat=issued_at,
        exp=issued_at + 3600,
    )
    return jwt.encode(claims, secret, algorithm="HS256")
async def _get_test_user_token() -> str:
    """Return a signed bearer token for a user to run the live tests as.

    ``TEST_USER_ID`` is used when it is set. Otherwise ``psql`` is invoked with
    hard-coded local connection settings (``localhost:54322``, database, user
    and password ``postgres``) to read one id from ``auth.users``.

    The calling test is skipped when no user id can be obtained, including
    when ``psql`` is not installed or does not answer in time.

    Raises:
        RuntimeError: propagated from ``_create_test_jwt`` when no JWT secret
            is configured.
    """
    user_id = _get_test_user_id()
    if user_id:
        return _create_test_jwt(user_id)

    import subprocess

    psql_env = {
        **os.environ,
        "PGHOST": "localhost",
        "PGPORT": "54322",
        "PGDATABASE": "postgres",
        "PGUSER": "postgres",
        "PGPASSWORD": "postgres",
    }
    try:
        result = subprocess.run(
            ["psql", "-t", "-A", "-c", "SELECT id FROM auth.users LIMIT 1;"],
            capture_output=True,
            text=True,
            env=psql_env,
            # A local single-row query is near-instant; never let an
            # unreachable database hang the whole test session.
            timeout=15,
        )
    except (OSError, subprocess.TimeoutExpired) as exc:
        # OSError covers a missing/unrunnable psql binary. Both cases mean
        # "database not accessible", which is a skip rather than an error.
        pytest.skip(
            f"Could not query database for a test user ({exc}). "
            "Set TEST_USER_ID or ensure database is accessible"
        )

    found_id = result.stdout.strip()
    if result.returncode == 0 and found_id:
        return _create_test_jwt(found_id)
    pytest.skip("Could not find test user. Set TEST_USER_ID or ensure database is accessible")
async def _run_agent_and_collect_events(
    client: httpx.AsyncClient,
    headers: dict,
    thread_id: str,
    run_id: str,
    user_message: str,
) -> tuple[list[dict], str]:
    """Start an agent run and read its SSE stream until the run terminates.

    A single user message is posted to ``/api/v1/agent/runs``; the events
    endpoint for the resulting thread/run is then streamed until a
    ``RUN_FINISHED`` or ``RUN_ERROR`` event arrives.

    Args:
        client: HTTP client used for both the POST and the streaming GET.
        headers: Headers sent with both requests.
        thread_id: Thread id to request; the server's ``threadId`` wins if returned.
        run_id: Run id to request; the server's ``runId`` wins if returned.
        user_message: Content of the single user message.

    Returns:
        ``(tool_call_results, effective_thread_id)``: every ``TOOL_CALL_RESULT``
        event seen before the terminal event, in arrival order, and the thread
        id that was actually used (as a string).

    The current test is failed when the run request does not return 202, when
    the events request does not return 200, or when the stream ends without a
    terminal event.
    """
    run_resp = await client.post(
        f"{BASE_URL}/api/v1/agent/runs",
        headers=headers,
        json={
            "threadId": thread_id,
            "runId": run_id,
            "state": {},
            "messages": [
                {
                    "id": "u1",
                    "role": "user",
                    "content": user_message,
                }
            ],
            "tools": [],
            "context": [],
            "forwardedProps": {"runtime_mode": "chat"},
        },
    )
    if run_resp.status_code != 202:
        pytest.fail(f"Run request failed: {run_resp.status_code} - {run_resp.text}")

    run_data = run_resp.json()
    effective_thread_id = str(run_data.get("threadId", thread_id))
    effective_run_id = run_data.get("runId", run_id)
    events_url = f"{BASE_URL}/api/v1/agent/runs/{effective_thread_id}/events?runId={effective_run_id}"

    tool_call_results: list[dict] = []
    run_finished = False
    async with client.stream(
        "GET", events_url, headers=headers, timeout=60.0
    ) as sse_resp:
        if sse_resp.status_code != 200:
            error_body = await sse_resp.aread()
            pytest.fail(f"SSE request failed: {sse_resp.status_code} - {error_body.decode()}")

        # An SSE event may spread its payload over several "data:" lines; they
        # are collected here and joined once the blank separator line arrives.
        data_lines: list[str] = []
        async for line in sse_resp.aiter_lines():
            if line.startswith("data:"):
                payload = line.split(":", 1)[1].strip()
                if payload:
                    data_lines.append(payload)
                continue
            if line != "" or not data_lines:
                # Other SSE fields/comments, or a separator with nothing pending.
                continue

            raw_event = "\n".join(data_lines)
            data_lines = []
            try:
                event_data = json.loads(raw_event)
            except json.JSONDecodeError:
                # Best effort: payloads that are not JSON are skipped.
                continue
            if not isinstance(event_data, dict):
                # Only JSON objects can carry a "type"; ignore anything else.
                continue

            event_type = event_data.get("type")
            if event_type == "TOOL_CALL_RESULT":
                tool_call_results.append(event_data)
            elif event_type in {"RUN_FINISHED", "RUN_ERROR"}:
                run_finished = True
                break

    assert run_finished, "RUN_FINISHED or RUN_ERROR not received"
    return tool_call_results, effective_thread_id
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_TOOL_LIVE_INTEGRATION") != "1",
    reason="set CLI_TOOL_LIVE_INTEGRATION=1 to run live CLI tool integration test",
)
async def test_agent_calendar_read_via_cli() -> None:
    """Live check of the tool events emitted for a calendar question.

    The agent is asked (in Chinese) for today's schedule. The streamed
    ``TOOL_CALL_RESULT`` events must contain a ``view_skill_file`` call for
    ``calendar/SKILL.md`` followed by a ``project_cli`` call for
    ``calendar read``. The ``project_cli`` event must not carry ``ui_hints``.
    """
    allowed_statuses = {"success", "failure", "partial"}

    token = await _get_test_user_token()
    async with httpx.AsyncClient(timeout=60.0) as client:
        events, _ = await _run_agent_and_collect_events(
            client=client,
            headers={"Authorization": f"Bearer {token}"},
            thread_id=str(uuid4()),
            run_id="run-cli-calendar-read",
            user_message="请查询我今天的日程安排,不要猜测结果,按你的技能说明执行。",
        )
        assert events, "expected at least one TOOL_CALL_RESULT event"

        # Both tools must have been called, the skill file before the CLI.
        names = [event.get("tool_name") for event in events]
        assert "view_skill_file" in names
        assert "project_cli" in names
        view_index = names.index("view_skill_file")
        cli_index = names.index("project_cli")
        assert view_index < cli_index

        # First view_skill_file event.
        view_event = events[view_index]
        assert view_event.get("status") in allowed_statuses
        view_args = view_event.get("tool_call_args")
        assert isinstance(view_args, dict)
        assert view_args.get("file_path") == "calendar/SKILL.md"

        # First project_cli event.
        cli_event = events[cli_index]
        assert cli_event.get("status") in allowed_statuses
        cli_args = cli_event.get("tool_call_args")
        assert isinstance(cli_args, dict)
        assert cli_args.get("command") == "calendar"
        assert cli_args.get("subcommand") == "read"

        # The result may arrive either as a JSON string or as an object.
        payload = cli_event.get("result")
        if isinstance(payload, str):
            payload = json.loads(payload)
        assert isinstance(payload, dict), f"result should be dict, got {type(payload)}"
        assert payload.get("command") == "calendar"
        assert payload.get("subcommand") == "read"

        if "ui_schema" in cli_event:
            schema = cli_event["ui_schema"]
            assert isinstance(schema, dict)
            assert "version" in schema
        assert "ui_hints" not in cli_event, "ui_hints should not appear in SSE wire (replaced by ui_schema)"
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_TOOL_LIVE_INTEGRATION") != "1",
    reason="set CLI_TOOL_LIVE_INTEGRATION=1 to run live CLI tool integration test",
)
async def test_tool_ui_schema_in_history() -> None:
    """Live check of the tool outputs returned by the history endpoint.

    After a calendar run completes, ``/api/v1/agent/history`` for the same
    thread must contain tool messages whose ``metadata.tool_agent_output``
    describes a ``view_skill_file`` call for ``calendar/SKILL.md`` and a
    ``project_cli`` call whose result is a ``calendar read`` object with a
    dict-valued ``ui_hints``.
    """
    token = await _get_test_user_token()
    async with httpx.AsyncClient(timeout=60.0) as client:
        auth_headers = {"Authorization": f"Bearer {token}"}
        requested_thread = str(uuid4())
        _, effective_thread_id = await _run_agent_and_collect_events(
            client=client,
            headers=auth_headers,
            thread_id=requested_thread,
            run_id="run-cli-history-test",
            user_message="请查询我今天的日程安排,不要猜测结果,按你的技能说明执行。",
        )

        history_resp = await client.get(
            f"{BASE_URL}/api/v1/agent/history",
            headers=auth_headers,
            params={"threadId": effective_thread_id},
        )
        assert history_resp.status_code == 200

        history = history_resp.json()
        assert "scope" in history
        assert "messages" in history

        tool_messages = []
        for message in history.get("messages", []):
            if isinstance(message, dict) and message.get("role") == "tool":
                tool_messages.append(message)
        assert tool_messages, "expected at least one tool message in history"

        # Names of the tools for which a valid persisted output was found.
        verified_tools = set()
        for tool_message in tool_messages:
            output = tool_message.get("metadata", {}).get("tool_agent_output")
            if not output:
                continue

            tool_name = output.get("tool_name")
            assert tool_name in {"project_cli", "view_skill_file"}
            assert "result" in output
            assert "status" in output

            if tool_name == "view_skill_file":
                call_args = output.get("tool_call_args")
                assert isinstance(call_args, dict)
                assert call_args.get("file_path") == "calendar/SKILL.md"
                verified_tools.add("view_skill_file")
            else:
                stored = output.get("result")
                if isinstance(stored, str):
                    try:
                        stored = json.loads(stored)
                    except ValueError:
                        # json.JSONDecodeError is a ValueError; the isinstance
                        # assertion below reports the undecodable value.
                        pass
                    else:
                        output["result"] = stored
                assert isinstance(stored, dict), f"result in DB should be dict, got {type(stored)}: {stored!r}"
                assert stored.get("command") == "calendar"
                assert stored.get("subcommand") == "read"

                hints = output.get("ui_hints")
                assert isinstance(hints, dict), f"ui_hints should be dict, got {type(hints)}"
                verified_tools.add("project_cli")

        assert "view_skill_file" in verified_tools, "expected persisted view_skill_file tool output"
        assert "project_cli" in verified_tools, "expected persisted project_cli tool output"