Files
social-app/backend/tests/integration/test_cli_skills_live.py
T
qzl d060962a5f feat(agent): redesign project_cli with module/method/input protocol
- Replace command/subcommand/args with module/method/input envelope
- Calendar handler uses discriminated union (mode) for read operations
- Strict Pydantic models with extra='forbid' for all calendar methods
- Worker max_iters=7, router prompt simplified (removed project_cli_defaults)
- Skill index cards + per-action files for progressive disclosure
- Frontend/AG-UI aligned to module/method dispatch
- Protocol docs updated to module/method/input contract

WIP: action cards need envelope fix, 2 tests need update, memory
handler needs Pydantic models.
2026-04-24 13:24:13 +08:00

428 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
import subprocess
import time
import asyncio
from pathlib import Path
from uuid import uuid4
import httpx
import jwt
import pytest
def _load_env() -> None:
    """Populate ``os.environ`` from the repository-level ``.env`` file, if present.

    The file is looked up at ``Path(__file__).resolve().parents[3] / ".env"``,
    i.e. three levels above the directory containing this module. Blank lines,
    ``#`` comments and lines without ``=`` are skipped. Variables already set in
    the process environment are never overwritten, so shell exports win over
    the file.
    """
    env_path = Path(__file__).resolve().parents[3] / ".env"
    if not env_path.exists():
        return
    # .env files may contain non-ASCII values; don't depend on the locale encoding.
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        value = value.strip()
        # Remove exactly one pair of matching surrounding quotes ("x" or 'x');
        # unmatched quote characters are part of the value.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        if key and key not in os.environ:
            os.environ[key] = value


_load_env()

# Must be read AFTER _load_env(): otherwise an AGENT_LIVE_BASE_URL that is only
# defined in .env would be ignored and the default used instead.
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
def _get_jwt_secret() -> str | None:
    """Return the HS256 secret used to sign locally minted test JWTs.

    The variables ``SOCIAL_SUPABASE__JWT_SECRET``, ``SUPABASE_JWT_SECRET`` and
    ``JWT_SECRET`` are consulted in that order and the first non-empty value is
    returned. If none is non-empty, the raw value of ``JWT_SECRET`` (``None``
    or ``""``) is returned.
    """
    secret: str | None = None
    for var_name in ("SOCIAL_SUPABASE__JWT_SECRET", "SUPABASE_JWT_SECRET", "JWT_SECRET"):
        secret = os.getenv(var_name)
        if secret:
            break
    return secret
def _get_supabase_url() -> str:
    """Return the Supabase base URL for the test environment.

    The first non-empty value of ``SOCIAL_SUPABASE__URL`` then ``SUPABASE_URL``
    wins; otherwise the local default ``http://localhost:54321`` is returned.
    """
    # Generator keeps the lookups lazy: later variables are read only if needed.
    configured = (os.getenv(name) for name in ("SOCIAL_SUPABASE__URL", "SUPABASE_URL"))
    return next((url for url in configured if url), "http://localhost:54321")
def _get_test_user_id() -> str | None:
    """Return the explicitly configured test user id (``TEST_USER_ID``), or ``None``."""
    return os.environ.get("TEST_USER_ID")
def _create_test_jwt(user_id: str) -> str:
    """Mint an HS256-signed JWT for ``user_id`` that expires one hour after issue.

    Claims: ``sub`` = ``user_id``, ``role`` and ``aud`` = ``"authenticated"``,
    ``iss`` = the configured Supabase URL, plus ``iat``/``exp`` 3600 s apart.

    Raises:
        RuntimeError: if no signing secret is configured.
    """
    secret = _get_jwt_secret()
    if not secret:
        raise RuntimeError("JWT_SECRET not found in environment")
    issuer = _get_supabase_url()
    issued_at = int(time.time())
    claims = {
        "sub": user_id,
        "role": "authenticated",
        "aud": "authenticated",
        "iss": issuer,
        "iat": issued_at,
        "exp": issued_at + 3600,
    }
    return jwt.encode(claims, secret, algorithm="HS256")
async def _get_test_user_token() -> str:
    """Return a bearer token for the live test user, or skip the calling test.

    ``TEST_USER_ID`` is used when set. Otherwise the id of an arbitrary row of
    ``auth.users`` (``LIMIT 1``) is fetched with the ``psql`` CLI from the local
    Postgres at localhost:54322. If no user can be determined, the test is
    skipped via ``pytest.skip``. This includes the cases where ``psql`` is not
    installed or the query does not finish within 15 s; previously those
    crashed with FileNotFoundError or hung.
    """
    user_id = _get_test_user_id()
    if user_id:
        return _create_test_jwt(user_id)
    try:
        result = subprocess.run(
            ["psql", "-t", "-A", "-c", "SELECT id FROM auth.users LIMIT 1;"],
            capture_output=True,
            text=True,
            timeout=15,
            env={
                **os.environ,
                "PGHOST": "localhost",
                "PGPORT": "54322",
                "PGDATABASE": "postgres",
                "PGUSER": "postgres",
                "PGPASSWORD": "postgres",
            },
        )
    except (OSError, subprocess.TimeoutExpired):
        # psql missing from PATH or DB unresponsive: treat as "no user found".
        result = None
    if result is not None and result.returncode == 0 and result.stdout.strip():
        return _create_test_jwt(result.stdout.strip())
    pytest.skip("Could not find test user. Set TEST_USER_ID or ensure database is accessible")
async def _run_agent_and_collect_events(
    client: httpx.AsyncClient,
    headers: dict,
    thread_id: str,
    run_id: str,
    user_message: str,
    runtime_mode: str = "chat",
) -> tuple[list[dict], bool, str]:
    """Start an agent run and collect the ``TOOL_CALL_RESULT`` events it emits.

    POSTs a run request to ``{BASE_URL}/api/v1/agent/runs`` (expects HTTP 202),
    then streams ``/api/v1/agent/runs/{threadId}/events?runId=...`` as
    server-sent events. Streaming stops at the first ``RUN_FINISHED`` or
    ``RUN_ERROR`` event, or when the stream ends. A run that ends with
    ``RUN_ERROR`` code ``AGENT_UPSTREAM_CONNECTION_ERROR`` is retried under a
    new run id (``{run_id}-retry-{n}``), up to ``max_attempts`` attempts in
    total.

    Returns:
        ``(tool_call_results, run_finished, thread_id)``:

        - ``tool_call_results`` holds the raw TOOL_CALL_RESULT event dicts.
        - ``run_finished`` is True once a terminal event was seen. Note that
          RUN_ERROR also counts as terminal.
        - ``thread_id`` is the id echoed by the server, falling back to the
          requested one.

    Fails the calling test via ``pytest.fail`` on a non-202 run response or a
    non-200 event-stream response.
    """
    max_attempts = 3
    last_thread_id = thread_id
    for attempt in range(max_attempts):
        attempt_run_id = run_id if attempt == 0 else f"{run_id}-retry-{attempt}"
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": thread_id,
                "runId": attempt_run_id,
                "state": {},
                "messages": [
                    {
                        "id": "u1",
                        "role": "user",
                        "content": user_message,
                    }
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {"runtime_mode": runtime_mode},
            },
        )
        if run_resp.status_code != 202:
            pytest.fail(f"Run request failed: {run_resp.status_code} - {run_resp.text}")
        run_data = run_resp.json()
        effective_thread_id = str(run_data.get("threadId", thread_id))
        effective_run_id = run_data.get("runId", attempt_run_id)
        last_thread_id = effective_thread_id
        events_url = f"{BASE_URL}/api/v1/agent/runs/{effective_thread_id}/events"
        tool_call_results: list[dict] = []
        run_finished = False
        run_error_code: str | None = None
        async with client.stream(
            "GET",
            events_url,
            headers=headers,
            params={"runId": effective_run_id},  # let httpx URL-encode the id
            timeout=120.0,
        ) as sse_resp:
            if sse_resp.status_code != 200:
                error_body = await sse_resp.aread()
                pytest.fail(
                    f"SSE request failed: {sse_resp.status_code} - {error_body.decode()}"
                )
            # SSE framing: an event's payload is all "data:" lines since the last
            # blank line, joined with "\n"; the blank line dispatches the event.
            # Keeping only the last data line (as before) broke multi-line payloads.
            data_lines: list[str] = []
            async for line in sse_resp.aiter_lines():
                if line.startswith("data:"):
                    value = line[len("data:"):]
                    data_lines.append(value[1:] if value.startswith(" ") else value)
                    continue
                if line or not data_lines:
                    # Other SSE fields (event:, id:, retry:, ": comment") or an
                    # event without data: nothing to dispatch.
                    continue
                payload = "\n".join(data_lines)
                data_lines = []
                try:
                    event_data = json.loads(payload)
                except json.JSONDecodeError:
                    continue  # non-JSON payloads (e.g. keep-alives) are ignored
                if not isinstance(event_data, dict):
                    continue  # only object-shaped events carry a "type"
                event_type = event_data.get("type")
                if event_type == "TOOL_CALL_RESULT":
                    tool_call_results.append(event_data)
                elif event_type == "RUN_ERROR":
                    run_finished = True
                    run_error_code = event_data.get("code")
                    print(f"RUN_ERROR: {event_data}")
                    break
                elif event_type == "RUN_FINISHED":
                    run_finished = True
                    break
        if run_error_code == "AGENT_UPSTREAM_CONNECTION_ERROR" and attempt < (max_attempts - 1):
            # Transient upstream failure: brief back-off, then retry with a new run id.
            await asyncio.sleep(0.4)
            continue
        return tool_call_results, run_finished, effective_thread_id
    # Only reachable if max_attempts < 1.
    return [], False, last_thread_id
def _check_db_record(table: str, user_id: str, extra_condition: str = "") -> bool:
    """Return True if ``table`` has at least one row with ``owner_id = user_id``.

    Executes ``SELECT COUNT(*)`` with the ``psql`` CLI against the local
    Postgres (localhost:54322, postgres/postgres). ``extra_condition`` is
    appended verbatim after the owner filter (e.g. ``" AND id = '...'"``), so
    it, ``table`` and ``user_id`` must be trusted: they are spliced into SQL.
    A non-zero psql exit status yields False.
    """
    query = f"SELECT COUNT(*) FROM {table} WHERE owner_id = '{user_id}'{extra_condition};"
    pg_env = {
        **os.environ,
        "PGHOST": "localhost",
        "PGPORT": "54322",
        "PGDATABASE": "postgres",
        "PGUSER": "postgres",
        "PGPASSWORD": "postgres",
    }
    proc = subprocess.run(
        ["psql", "-t", "-A", "-c", query],
        capture_output=True,
        text=True,
        env=pg_env,
    )
    if proc.returncode != 0:
        return False
    # -t -A prints the bare count; empty output is treated as zero rows.
    row_count = int(proc.stdout.strip() or "0")
    return row_count > 0
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_calendar_create_skill_creates_db_record() -> None:
    """Creating an event must go through ``project_cli`` calendar.create.

    The tool result must report the new event id. When ``TEST_USER_ID`` is set
    and Supabase is local, that id must also exist in ``schedule_items``.
    """
    token = await _get_test_user_token()
    user_id = _get_test_user_id()
    async with httpx.AsyncClient(timeout=120.0) as client:
        headers = {"Authorization": f"Bearer {token}"}
        thread_id = str(uuid4())
        tomorrow = time.strftime("%Y-%m-%d", time.localtime(time.time() + 86400))
        user_message = (
            f"请帮我创建一个日程测试事件,标题为'CLI集成测试-{thread_id[:8]}'"
            f"开始时间是明天{tomorrow}上午10点,持续1小时。"
            "严格按技能说明执行,不要猜测结果。"
        )
        tool_call_results, run_finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers=headers,
            thread_id=thread_id,
            run_id="run-calendar-create-test",
            user_message=user_message,
        )
        assert run_finished, "Run did not finish"
        project_cli_results = [
            r for r in tool_call_results if r.get("tool_name") == "project_cli"
        ]
        assert project_cli_results, "No project_cli tool call found"
        cli_result = project_cli_results[0]
        assert cli_result.get("status") == "success", f"Tool call failed: {cli_result}"
        args = cli_result.get("tool_call_args", {})
        assert args.get("module") == "calendar"
        assert args.get("method") == "create"
        result_payload = cli_result.get("result")
        assert isinstance(result_payload, dict), f"Unexpected result payload: {cli_result}"
        data_payload = result_payload.get("data")
        assert isinstance(data_payload, dict), f"Missing result data payload: {cli_result}"
        created_ids = data_payload.get("ids")
        assert isinstance(created_ids, list) and created_ids, f"No created event ids returned: {cli_result}"
        created_event_id = str(created_ids[0])
        # The DB check uses psql against the local Postgres, so it is only
        # meaningful with a known user id and a local Supabase API.
        if user_id and _get_supabase_url().startswith("http://localhost"):
            # Let the write land without blocking the event loop.
            await asyncio.sleep(1)
            # Previously the result was discarded, so this check could never fail.
            assert _check_db_record(
                "schedule_items",
                user_id,
                f" AND id = '{created_event_id}'",
            ), f"Created event {created_event_id} not found in schedule_items for user {user_id}"
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_calendar_read_skill_queries_db() -> None:
    """Asking for today's schedule must route through ``project_cli`` calendar.read.

    Every collected tool-call result is printed first, so a failing run can be
    diagnosed from the captured stdout.
    """
    bearer = await _get_test_user_token()
    async with httpx.AsyncClient(timeout=120.0) as client:
        events, finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers={"Authorization": f"Bearer {bearer}"},
            thread_id=str(uuid4()),
            run_id="run-calendar-read-test",
            user_message="请查询我今天的日程安排,严格按技能说明执行,不要猜测结果。",
        )
    assert finished, "Run did not finish"

    print(f"\n=== Tool call results: {len(events)} ===")
    for event in events:
        print(f" - tool_name: {event.get('tool_name')}")
        print(f" status: {event.get('status')}")
        print(f" tool_call_args: {json.dumps(event.get('tool_call_args', {}), ensure_ascii=False)}")
        payload = event.get("result")
        if isinstance(payload, dict):
            print(f" result keys: {list(payload.keys())}")
        elif isinstance(payload, str) and len(payload) < 200:
            print(f" result: {payload}")

    cli_calls = [event for event in events if event.get("tool_name") == "project_cli"]
    assert cli_calls, "No project_cli tool call found"
    first_call = cli_calls[0]
    assert first_call.get("status") in {"success", "partial"}, f"Tool call failed: {first_call}"
    call_args = first_call.get("tool_call_args", {})
    assert call_args.get("module") == "calendar"
    assert call_args.get("method") in {"read"}
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_contacts_read_skill_queries_db() -> None:
    """Asking for the contact list must route through ``project_cli`` contacts.read."""
    bearer = await _get_test_user_token()
    async with httpx.AsyncClient(timeout=120.0) as client:
        events, finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers={"Authorization": f"Bearer {bearer}"},
            thread_id=str(uuid4()),
            run_id="run-contacts-read-test",
            user_message="请帮我查找我的联系人列表,严格按技能说明执行,不要猜测结果。",
        )
    assert finished, "Run did not finish"

    # Event dicts are never None, so "no match" is unambiguous here.
    first_call = next(
        (event for event in events if event.get("tool_name") == "project_cli"),
        None,
    )
    assert first_call is not None, "No project_cli tool call found"
    assert first_call.get("status") in {"success", "partial"}, f"Tool call failed: {first_call}"
    call_args = first_call.get("tool_call_args", {})
    assert call_args.get("module") == "contacts"
    assert call_args.get("method") == "read"
@pytest.mark.asyncio
@pytest.mark.live
@pytest.mark.skipif(
    os.getenv("CLI_SKILLS_LIVE_TEST") != "1",
    reason="set CLI_SKILLS_LIVE_TEST=1 to run live CLI + skills integration test",
)
async def test_memory_update_skill_via_automation() -> None:
    """An automation-mode "remember this" request must call ``project_cli`` memory.update.

    When ``TEST_USER_ID`` is set and Supabase is local, the newest ``user``
    memory row is also read back and checked for the marker value.
    """
    token = await _get_test_user_token()
    user_id = _get_test_user_id()
    async with httpx.AsyncClient(timeout=120.0) as client:
        headers = {"Authorization": f"Bearer {token}"}
        thread_id = str(uuid4())
        marker = f"test-value-{thread_id[:8]}"
        user_message = (
            f"请将以下信息写入我的记忆:用户偏好测试字段值为'{marker}'"
            "严格按技能说明执行,不要猜测结果。"
        )
        tool_call_results, run_finished, _ = await _run_agent_and_collect_events(
            client=client,
            headers=headers,
            thread_id=thread_id,
            run_id="run-memory-update-test",
            user_message=user_message,
            runtime_mode="automation",
        )
        assert run_finished, "Run did not finish"
        project_cli_results = [
            r for r in tool_call_results if r.get("tool_name") == "project_cli"
        ]
        assert project_cli_results, "No project_cli tool call found"
        cli_result = project_cli_results[0]
        assert cli_result.get("status") in {"success", "partial"}, f"Tool call failed: {cli_result}"
        args = cli_result.get("tool_call_args", {})
        assert args.get("module") == "memory"
        assert args.get("method") == "update"
        # psql below is hard-wired to the local Postgres. As in the
        # calendar-create test, only inspect the DB when the API is local too.
        if user_id and _get_supabase_url().startswith("http://localhost"):
            # Let the write land without blocking the event loop.
            await asyncio.sleep(1)
            result = subprocess.run(
                [
                    "psql",
                    "-t",
                    "-A",
                    "-c",
                    f"SELECT content FROM memories WHERE owner_id = '{user_id}' AND memory_type = 'user' ORDER BY updated_at DESC LIMIT 1;",
                ],
                capture_output=True,
                text=True,
                env={
                    **os.environ,
                    "PGHOST": "localhost",
                    "PGPORT": "54322",
                    "PGDATABASE": "postgres",
                    "PGUSER": "postgres",
                    "PGPASSWORD": "postgres",
                },
            )
            # NOTE(review): a psql failure or a missing row is tolerated (no
            # assertion), and "测试" is accepted because the agent may rephrase.
            if result.returncode == 0 and result.stdout.strip():
                content = result.stdout.strip()
                assert marker in content or "测试" in content, (
                    f"Unexpected latest user memory content: {content!r}"
                )