Files
social-app/backend/tests/integration/test_cli_skills_live.py
T

404 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import asyncio
import json
import os
import subprocess
import time
from pathlib import Path
from uuid import uuid4

import httpx
import jwt
import pytest
# Base URL of the backend under test; override with AGENT_LIVE_BASE_URL.
# NOTE(review): this is evaluated before _load_env() is called, so a value
# defined only in the .env file is not picked up here — confirm that is intended.
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
def _load_env() -> None:
    """Copy ``KEY=VALUE`` pairs from the ``.env`` file into ``os.environ``.

    The file is looked up three directories above the directory that contains
    this module. Blank lines, lines starting with ``#`` and lines without an
    ``=`` are ignored. Surrounding whitespace and quote characters are
    stripped from each value. Variables that already exist in the environment
    are never overwritten. Nothing happens when the file does not exist.
    """
    env_path = Path(__file__).resolve().parents[3] / ".env"
    if not env_path.exists():
        return

    # Decode explicitly as UTF-8: the default encoding depends on the platform
    # locale and can fail (or mis-decode) on non-ASCII content.
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        value = value.strip().strip('"').strip("'")
        # The real environment wins over the file.
        if key and key not in os.environ:
            os.environ[key] = value
# Load the .env file at import time so the os.getenv-based helpers below can see its values.
_load_env()
def _get_jwt_secret() -> str | None:
    """Return the first non-empty JWT secret found in the environment.

    Checks ``SOCIAL_SUPABASE__JWT_SECRET``, ``SUPABASE_JWT_SECRET`` and
    ``JWT_SECRET`` in that order. When none holds a non-empty value, the
    lookup result for the last variable is returned as-is.
    """
    secret = None
    for variable in (
        "SOCIAL_SUPABASE__JWT_SECRET",
        "SUPABASE_JWT_SECRET",
        "JWT_SECRET",
    ):
        secret = os.getenv(variable)
        if secret:
            break
    return secret
def _get_supabase_url() -> str:
    """Return the Supabase URL configured in the environment.

    ``SOCIAL_SUPABASE__URL`` takes precedence over ``SUPABASE_URL``; when
    neither holds a non-empty value, ``http://localhost:54321`` is returned.
    """
    for variable in ("SOCIAL_SUPABASE__URL", "SUPABASE_URL"):
        configured = os.getenv(variable)
        if configured:
            return configured
    return "http://localhost:54321"
def _get_test_user_id() -> str | None:
    """Return the ``TEST_USER_ID`` environment variable, or ``None`` when it is unset."""
    return os.environ.get("TEST_USER_ID")
def _create_test_jwt(user_id: str) -> str:
    """Sign an HS256 token identifying *user_id* as an authenticated user.

    The claims carry ``user_id`` as ``sub``, ``"authenticated"`` as both
    ``role`` and ``aud``, the configured Supabase URL as ``iss``, the current
    time as ``iat`` and an expiry one hour later.

    Raises:
        RuntimeError: no JWT secret is available from ``_get_jwt_secret()``.
    """
    secret = _get_jwt_secret()
    if not secret:
        raise RuntimeError("JWT_SECRET not found in environment")

    issuer = _get_supabase_url()
    issued_at = int(time.time())
    claims = dict(
        sub=user_id,
        role="authenticated",
        aud="authenticated",
        iss=issuer,
        iat=issued_at,
        exp=issued_at + 3600,  # valid for one hour
    )
    return jwt.encode(claims, secret, algorithm="HS256")
async def _get_test_user_token() -> str:
    """Return a signed JWT for a user the live tests can act as.

    ``TEST_USER_ID`` is used when set. Otherwise ``psql`` is asked for one
    user id from ``auth.users`` on the local Postgres instance
    (localhost:54322, database/user/password ``postgres``). When neither
    source yields an id, the calling test is skipped via ``pytest.skip``.
    """
    user_id = _get_test_user_id()
    if user_id:
        return _create_test_jwt(user_id)

    try:
        result = subprocess.run(
            ["psql", "-t", "-A", "-c", "SELECT id FROM auth.users LIMIT 1;"],
            capture_output=True,
            text=True,
            env={
                **os.environ,
                "PGHOST": "localhost",
                "PGPORT": "54322",
                "PGDATABASE": "postgres",
                "PGUSER": "postgres",
                "PGPASSWORD": "postgres",
            },
        )
    except OSError:
        # psql is missing or not executable: treat it like an unreachable
        # database so the test is skipped instead of erroring out.
        result = None

    if result is not None and result.returncode == 0 and result.stdout.strip():
        return _create_test_jwt(result.stdout.strip())

    pytest.skip("Could not find test user. Set TEST_USER_ID or ensure database is accessible")
async def _run_agent_and_collect_events(
    client: httpx.AsyncClient,
    headers: dict,
    thread_id: str,
    run_id: str,
    user_message: str,
    runtime_mode: str = "chat",
) -> tuple[list[dict], bool, str]:
    """Start an agent run and read its SSE stream until the run ends.

    Posts a single user message to ``/api/v1/agent/runs`` and then streams
    ``/api/v1/agent/runs/{threadId}/events``. The thread and run ids returned
    by the server are preferred over the ones passed in.

    Args:
        client: HTTP client used for both requests.
        headers: Request headers sent with both requests.
        thread_id: Thread id sent with the run request.
        run_id: Run id sent with the run request.
        user_message: Content of the single user message.
        runtime_mode: Value forwarded as ``forwardedProps.runtime_mode``.

    Returns:
        ``(tool_call_results, run_finished, effective_thread_id)`` where
        ``tool_call_results`` holds every ``TOOL_CALL_RESULT`` event seen,
        ``run_finished`` is True once a ``RUN_FINISHED`` or ``RUN_ERROR``
        event was received, and ``effective_thread_id`` is the thread id the
        server reported (falling back to *thread_id*).

    The current test is failed via ``pytest.fail`` when the run request does
    not return 202 or the event stream does not return 200.
    """
    run_resp = await client.post(
        f"{BASE_URL}/api/v1/agent/runs",
        headers=headers,
        json={
            "threadId": thread_id,
            "runId": run_id,
            "state": {},
            "messages": [
                {
                    "id": "u1",
                    "role": "user",
                    "content": user_message,
                }
            ],
            "tools": [],
            "context": [],
            "forwardedProps": {"runtime_mode": runtime_mode},
        },
    )
    if run_resp.status_code != 202:
        pytest.fail(f"Run request failed: {run_resp.status_code} - {run_resp.text}")

    run_data = run_resp.json()
    effective_thread_id = str(run_data.get("threadId", thread_id))
    effective_run_id = run_data.get("runId", run_id)
    events_url = f"{BASE_URL}/api/v1/agent/runs/{effective_thread_id}/events?runId={effective_run_id}"

    tool_call_results: list[dict] = []
    run_finished = False

    async with client.stream(
        "GET", events_url, headers=headers, timeout=120.0
    ) as sse_resp:
        if sse_resp.status_code != 200:
            error_body = await sse_resp.aread()
            pytest.fail(f"SSE request failed: {sse_resp.status_code} - {error_body.decode()}")

        # An SSE event may spread its payload over several "data:" lines;
        # collect them and join with newlines when the blank line arrives.
        data_lines: list[str] = []
        async for line in sse_resp.aiter_lines():
            if line.startswith("data:"):
                data_str = line[len("data:"):].strip()
                if data_str:
                    data_lines.append(data_str)
                continue
            # Only a blank line with pending data dispatches an event; other
            # fields (event:, id:, comments) are ignored.
            if line != "" or not data_lines:
                continue

            payload = "\n".join(data_lines)
            data_lines = []
            try:
                event_data = json.loads(payload)
            except json.JSONDecodeError:
                # Best effort: keep reading, but leave a trace for debugging.
                print(f"Skipping non-JSON SSE payload: {payload!r}")
                continue
            if not isinstance(event_data, dict):
                continue

            event_type = event_data.get("type")
            if event_type == "TOOL_CALL_RESULT":
                tool_call_results.append(event_data)
            elif event_type == "RUN_ERROR":
                # An errored run still counts as ended; callers inspect the
                # collected tool results.
                run_finished = True
                print(f"RUN_ERROR: {event_data}")
                break
            elif event_type == "RUN_FINISHED":
                run_finished = True
                break

    return tool_call_results, run_finished, effective_thread_id
def _check_db_record(table: str, user_id: str, extra_condition: str = "") -> bool:
    """Report whether *table* holds at least one row owned by *user_id*.

    Runs a ``SELECT COUNT(*)`` through ``psql`` against the local Postgres
    instance (localhost:54322, database/user/password ``postgres``).

    Args:
        table: Table name, interpolated into the SQL text as-is.
        user_id: Value compared against the ``owner_id`` column.
        extra_condition: Raw SQL appended after the ``owner_id`` comparison
            (e.g. ``" AND title LIKE '...'"``).

    Returns:
        True when the count is greater than zero; False when it is zero or
        when ``psql`` exits with a non-zero status.

    All arguments are spliced into the query unescaped, so pass trusted
    values only.
    """
    query = f"SELECT COUNT(*) FROM {table} WHERE owner_id = '{user_id}'{extra_condition};"

    psql_env = dict(os.environ)
    psql_env.update(
        PGHOST="localhost",
        PGPORT="54322",
        PGDATABASE="postgres",
        PGUSER="postgres",
        PGPASSWORD="postgres",
    )

    completed = subprocess.run(
        ["psql", "-t", "-A", "-c", query],
        capture_output=True,
        text=True,
        env=psql_env,
    )
    if completed.returncode != 0:
        return False
    # Empty output is treated as a count of zero.
    return int(completed.stdout.strip() or "0") > 0
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_calendar_write_skill_creates_db_record() -> None:
    """Ask the agent to create a calendar event; check the CLI call and the DB row.

    The prompt embeds a title derived from the fresh thread id. The test
    passes when the run ends, the first ``project_cli`` tool result reports
    ``success`` for ``calendar write`` and, only when ``TEST_USER_ID`` is set,
    a ``schedule_items`` row with that title exists for the user.
    """
    token = await _get_test_user_token()
    user_id = _get_test_user_id()

    async with httpx.AsyncClient(timeout=120.0) as client:
        headers = {"Authorization": f"Bearer {token}"}
        thread_id = str(uuid4())
        tomorrow = time.strftime("%Y-%m-%d", time.localtime(time.time() + 86400))
        user_message = (
            f"请帮我创建一个日程测试事件,标题为'CLI集成测试-{thread_id[:8]}'"
            f"开始时间是明天{tomorrow}上午10点,持续1小时。"
            f"严格按技能说明执行,不要猜测结果。"
        )

        tool_call_results, run_finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers=headers,
            thread_id=thread_id,
            run_id="run-calendar-write-test",
            user_message=user_message,
        )
        assert run_finished, "Run did not finish"

        project_cli_results = [
            r for r in tool_call_results if r.get("tool_name") == "project_cli"
        ]
        assert project_cli_results, "No project_cli tool call found"

        cli_result = project_cli_results[0]
        assert cli_result.get("status") == "success", f"Tool call failed: {cli_result}"

        args = cli_result.get("tool_call_args", {})
        assert args.get("command") == "calendar"
        assert args.get("subcommand") == "write"

        # The DB check needs a known owner id, which is only available from
        # TEST_USER_ID.
        if user_id:
            # Short pause before querying; awaited so the event loop is not
            # blocked the way time.sleep() would block it.
            await asyncio.sleep(1)
            has_record = _check_db_record(
                "schedule_items",
                user_id,
                f" AND title LIKE '%CLI集成测试-{thread_id[:8]}%'",
            )
            assert has_record, f"No schedule_items record found for user {user_id}"
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_calendar_read_skill_queries_db() -> None:
    """Ask the agent for today's schedule and check it used ``calendar read``.

    Every collected tool result is printed for debugging. The test passes
    when the run ends and the first ``project_cli`` tool result has status
    ``success`` or ``partial`` with command ``calendar`` / subcommand ``read``.
    """
    token = await _get_test_user_token()

    async with httpx.AsyncClient(timeout=120.0) as client:
        auth_headers = {"Authorization": f"Bearer {token}"}
        conversation_id = str(uuid4())
        prompt = "请查询我今天的日程安排,严格按技能说明执行,不要猜测结果。"

        events, finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers=auth_headers,
            thread_id=conversation_id,
            run_id="run-calendar-read-test",
            user_message=prompt,
        )
        assert finished, "Run did not finish"

        # Debug dump of everything the agent's tools returned.
        print(f"\n=== Tool call results: {len(events)} ===")
        for event in events:
            print(f" - tool_name: {event.get('tool_name')}")
            print(f" status: {event.get('status')}")
            print(f" tool_call_args: {json.dumps(event.get('tool_call_args', {}), ensure_ascii=False)}")
            payload = event.get("result")
            if isinstance(payload, dict):
                print(f" result keys: {list(payload.keys())}")
            elif isinstance(payload, str) and len(payload) < 200:
                # Only short string results are echoed in full.
                print(f" result: {payload}")

        cli_result = next(
            (event for event in events if event.get("tool_name") == "project_cli"),
            None,
        )
        assert cli_result is not None, "No project_cli tool call found"
        assert cli_result.get("status") in {"success", "partial"}, f"Tool call failed: {cli_result}"

        call_args = cli_result.get("tool_call_args", {})
        assert call_args.get("command") == "calendar"
        assert call_args.get("subcommand") == "read"
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_contacts_lookup_skill_queries_db() -> None:
    """Ask the agent for the contact list and check it used ``contacts lookup``.

    The test passes when the run ends and the first ``project_cli`` tool
    result has status ``success`` or ``partial`` with command ``contacts`` /
    subcommand ``lookup``.
    """
    token = await _get_test_user_token()

    async with httpx.AsyncClient(timeout=120.0) as client:
        auth_headers = {"Authorization": f"Bearer {token}"}
        conversation_id = str(uuid4())
        prompt = "请帮我查找我的联系人列表,严格按技能说明执行,不要猜测结果。"

        events, finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers=auth_headers,
            thread_id=conversation_id,
            run_id="run-contacts-lookup-test",
            user_message=prompt,
        )
        assert finished, "Run did not finish"

        cli_result = next(
            (event for event in events if event.get("tool_name") == "project_cli"),
            None,
        )
        assert cli_result is not None, "No project_cli tool call found"
        assert cli_result.get("status") in {"success", "partial"}, f"Tool call failed: {cli_result}"

        call_args = cli_result.get("tool_call_args", {})
        assert call_args.get("command") == "contacts"
        assert call_args.get("subcommand") == "lookup"
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_memory_write_skill_via_automation() -> None:
    """Ask the agent (automation mode) to store a memory; check the CLI call and DB.

    The test passes when the run ends and the first ``project_cli`` tool
    result has status ``success`` or ``partial`` with command ``memory`` and
    subcommand ``write`` or ``update``. When ``TEST_USER_ID`` is set, the most
    recently updated ``user`` memory of that user is also inspected; that
    check is lenient: it is skipped when the query fails or returns nothing,
    and it accepts either the unique test value or the text "测试".
    """
    token = await _get_test_user_token()
    user_id = _get_test_user_id()

    async with httpx.AsyncClient(timeout=120.0) as client:
        headers = {"Authorization": f"Bearer {token}"}
        thread_id = str(uuid4())
        user_message = (
            f"请将以下信息写入我的记忆:用户偏好测试字段值为'test-value-{thread_id[:8]}'"
            f"严格按技能说明执行,不要猜测结果。"
        )

        tool_call_results, run_finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers=headers,
            thread_id=thread_id,
            run_id="run-memory-write-test",
            user_message=user_message,
            runtime_mode="automation",
        )
        assert run_finished, "Run did not finish"

        project_cli_results = [
            r for r in tool_call_results if r.get("tool_name") == "project_cli"
        ]
        assert project_cli_results, "No project_cli tool call found"

        cli_result = project_cli_results[0]
        assert cli_result.get("status") in {"success", "partial"}, f"Tool call failed: {cli_result}"

        args = cli_result.get("tool_call_args", {})
        assert args.get("command") == "memory"
        assert args.get("subcommand") in {"write", "update"}

        # The DB check needs a known owner id, which is only available from
        # TEST_USER_ID.
        if user_id:
            # Short pause before querying; awaited so the event loop is not
            # blocked the way time.sleep() would block it.
            await asyncio.sleep(1)
            result = subprocess.run(
                [
                    "psql",
                    "-t",
                    "-A",
                    "-c",
                    f"SELECT content FROM memories WHERE owner_id = '{user_id}' AND memory_type = 'user' ORDER BY updated_at DESC LIMIT 1;",
                ],
                capture_output=True,
                text=True,
                env={
                    **os.environ,
                    "PGHOST": "localhost",
                    "PGPORT": "54322",
                    "PGDATABASE": "postgres",
                    "PGUSER": "postgres",
                    "PGPASSWORD": "postgres",
                },
            )
            if result.returncode == 0 and result.stdout.strip():
                content = result.stdout.strip()
                assert f"test-value-{thread_id[:8]}" in content or "测试" in content