refactor: unify skills+cli runtime and streamline ag-ui flow
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.formatter import DashScopeChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
|
||||
from core.agentscope.tools.toolkit import build_toolkit
|
||||
|
||||
|
||||
class _DummyModel:
|
||||
stream = False
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError("dummy model should not be called in this test")
|
||||
|
||||
|
||||
def test_react_agent_sys_prompt_includes_registered_skill_prompt() -> None:
|
||||
toolkit = build_toolkit(enabled_skill_names={"calendar", "contacts"})
|
||||
agent = ReActAgent(
|
||||
name="tester",
|
||||
sys_prompt="base prompt",
|
||||
model=_DummyModel(),
|
||||
formatter=DashScopeChatFormatter(),
|
||||
toolkit=toolkit,
|
||||
memory=InMemoryMemory(),
|
||||
)
|
||||
|
||||
prompt = agent.sys_prompt
|
||||
assert "base prompt" in prompt
|
||||
assert "# Agent Skills" in prompt
|
||||
assert "## calendar" in prompt
|
||||
assert "## contacts" in prompt
|
||||
assert "SKILL.md" in prompt
|
||||
|
||||
|
||||
def test_view_skill_file_tool_reads_registered_skill_content() -> None:
|
||||
toolkit = build_toolkit(enabled_skill_names={"calendar"})
|
||||
tool = toolkit.tools["view_skill_file"].original_func
|
||||
|
||||
response = asyncio.run(
|
||||
tool(file_path="calendar/SKILL.md", ranges=[1, 20]),
|
||||
)
|
||||
|
||||
assert response.content
|
||||
block = response.content[0]
|
||||
text = block["text"] if isinstance(block, dict) else block.text
|
||||
assert "Calendar Skill" in text or "name: calendar" in text
|
||||
@@ -0,0 +1,295 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
||||
|
||||
|
||||
def _load_env() -> None:
|
||||
env_path = Path(__file__).resolve().parents[3] / ".env"
|
||||
if env_path.exists():
|
||||
for line in env_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key = key.strip()
|
||||
value = value.strip().strip('"').strip("'")
|
||||
if key and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
_load_env()
|
||||
|
||||
|
||||
def _get_jwt_secret() -> str | None:
|
||||
return (
|
||||
os.getenv("SOCIAL_SUPABASE__JWT_SECRET")
|
||||
or os.getenv("SUPABASE_JWT_SECRET")
|
||||
or os.getenv("JWT_SECRET")
|
||||
)
|
||||
|
||||
|
||||
def _get_supabase_url() -> str:
|
||||
return (
|
||||
os.getenv("SOCIAL_SUPABASE__URL")
|
||||
or os.getenv("SUPABASE_URL")
|
||||
or "http://localhost:54321"
|
||||
)
|
||||
|
||||
|
||||
def _get_test_user_id() -> str | None:
|
||||
return os.getenv("TEST_USER_ID")
|
||||
|
||||
|
||||
def _create_test_jwt(user_id: str) -> str:
|
||||
jwt_secret = _get_jwt_secret()
|
||||
if not jwt_secret:
|
||||
raise RuntimeError("JWT_SECRET not found in environment")
|
||||
|
||||
supabase_url = _get_supabase_url()
|
||||
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"role": "authenticated",
|
||||
"aud": "authenticated",
|
||||
"iss": supabase_url,
|
||||
"iat": now,
|
||||
"exp": now + 3600,
|
||||
}
|
||||
|
||||
return jwt.encode(payload, jwt_secret, algorithm="HS256")
|
||||
|
||||
|
||||
async def _get_test_user_token() -> str:
|
||||
user_id = _get_test_user_id()
|
||||
if user_id:
|
||||
return _create_test_jwt(user_id)
|
||||
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["psql", "-t", "-A", "-c", "SELECT id FROM auth.users LIMIT 1;"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env={**os.environ, "PGHOST": "localhost", "PGPORT": "54322", "PGDATABASE": "postgres", "PGUSER": "postgres", "PGPASSWORD": "postgres"},
|
||||
)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
user_id = result.stdout.strip()
|
||||
return _create_test_jwt(user_id)
|
||||
|
||||
pytest.skip("Could not find test user. Set TEST_USER_ID or ensure database is accessible")
|
||||
|
||||
|
||||
async def _run_agent_and_collect_events(
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
user_message: str,
|
||||
) -> tuple[list[dict], bool]:
|
||||
run_resp = await client.post(
|
||||
f"{BASE_URL}/api/v1/agent/runs",
|
||||
headers=headers,
|
||||
json={
|
||||
"threadId": thread_id,
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": user_message,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"runtime_mode": "chat"},
|
||||
},
|
||||
)
|
||||
if run_resp.status_code != 202:
|
||||
pytest.fail(f"Run request failed: {run_resp.status_code} - {run_resp.text}")
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
run_data = run_resp.json()
|
||||
effective_thread_id = str(run_data.get("threadId", thread_id))
|
||||
effective_run_id = run_data.get("runId", run_id)
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{effective_thread_id}/events?runId={effective_run_id}"
|
||||
tool_call_results: list[dict] = []
|
||||
run_finished = False
|
||||
|
||||
async with client.stream(
|
||||
"GET", events_url, headers=headers, timeout=60.0
|
||||
) as sse_resp:
|
||||
if sse_resp.status_code != 200:
|
||||
error_body = await sse_resp.aread()
|
||||
pytest.fail(f"SSE request failed: {sse_resp.status_code} - {error_body.decode()}")
|
||||
assert sse_resp.status_code == 200
|
||||
buffer = ""
|
||||
async for line in sse_resp.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data_str = line.split(":", 1)[1].strip()
|
||||
if data_str:
|
||||
buffer = data_str
|
||||
elif line == "" and buffer:
|
||||
try:
|
||||
event_data = json.loads(buffer)
|
||||
event_type = event_data.get("type")
|
||||
if event_type == "TOOL_CALL_RESULT":
|
||||
tool_call_results.append(event_data)
|
||||
elif event_type in {"RUN_FINISHED", "RUN_ERROR"}:
|
||||
run_finished = True
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
buffer = ""
|
||||
|
||||
assert run_finished, "RUN_FINISHED or RUN_ERROR not received"
|
||||
return tool_call_results, effective_thread_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("CLI_TOOL_LIVE_INTEGRATION") != "1",
|
||||
reason="set CLI_TOOL_LIVE_INTEGRATION=1 to run live CLI tool integration test",
|
||||
)
|
||||
async def test_agent_calendar_read_via_cli() -> None:
|
||||
token = await _get_test_user_token()
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
thread_id = str(uuid4())
|
||||
|
||||
tool_call_results, _ = await _run_agent_and_collect_events(
|
||||
client=client,
|
||||
headers=headers,
|
||||
thread_id=thread_id,
|
||||
run_id="run-cli-calendar-read",
|
||||
user_message="请查询我今天的日程安排,不要猜测结果,按你的技能说明执行。",
|
||||
)
|
||||
|
||||
assert tool_call_results, "expected at least one TOOL_CALL_RESULT event"
|
||||
tool_names = [result.get("tool_name") for result in tool_call_results]
|
||||
assert "view_skill_file" in tool_names
|
||||
assert "project_cli" in tool_names
|
||||
assert tool_names.index("view_skill_file") < tool_names.index("project_cli")
|
||||
|
||||
view_result = next(
|
||||
result for result in tool_call_results if result.get("tool_name") == "view_skill_file"
|
||||
)
|
||||
assert view_result.get("status") in {"success", "failure", "partial"}
|
||||
view_args = view_result.get("tool_call_args")
|
||||
assert isinstance(view_args, dict)
|
||||
assert view_args.get("file_path") == "calendar/SKILL.md"
|
||||
|
||||
result = next(
|
||||
result for result in tool_call_results if result.get("tool_name") == "project_cli"
|
||||
)
|
||||
assert result.get("status") in {"success", "failure", "partial"}
|
||||
|
||||
tool_call_args = result.get("tool_call_args")
|
||||
assert isinstance(tool_call_args, dict)
|
||||
assert tool_call_args.get("command") == "calendar"
|
||||
assert tool_call_args.get("subcommand") == "read"
|
||||
|
||||
raw_result = result.get("result")
|
||||
if isinstance(raw_result, str):
|
||||
raw_result = json.loads(raw_result)
|
||||
assert isinstance(raw_result, dict), f"result should be dict, got {type(raw_result)}"
|
||||
assert raw_result.get("command") == "calendar"
|
||||
assert raw_result.get("subcommand") == "read"
|
||||
|
||||
if "ui_schema" in result:
|
||||
ui_schema = result["ui_schema"]
|
||||
assert isinstance(ui_schema, dict)
|
||||
assert "version" in ui_schema
|
||||
|
||||
assert "ui_hints" not in result, "ui_hints should not appear in SSE wire (replaced by ui_schema)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("CLI_TOOL_LIVE_INTEGRATION") != "1",
|
||||
reason="set CLI_TOOL_LIVE_INTEGRATION=1 to run live CLI tool integration test",
|
||||
)
|
||||
async def test_tool_ui_schema_in_history() -> None:
|
||||
token = await _get_test_user_token()
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
thread_id = str(uuid4())
|
||||
|
||||
_, effective_thread_id = await _run_agent_and_collect_events(
|
||||
client=client,
|
||||
headers=headers,
|
||||
thread_id=thread_id,
|
||||
run_id="run-cli-history-test",
|
||||
user_message="请查询我今天的日程安排,不要猜测结果,按你的技能说明执行。",
|
||||
)
|
||||
|
||||
history_resp = await client.get(
|
||||
f"{BASE_URL}/api/v1/agent/history",
|
||||
headers=headers,
|
||||
params={"threadId": effective_thread_id},
|
||||
)
|
||||
assert history_resp.status_code == 200
|
||||
history = history_resp.json()
|
||||
|
||||
assert "scope" in history
|
||||
assert "messages" in history
|
||||
|
||||
messages = history.get("messages", [])
|
||||
tool_messages = [
|
||||
m
|
||||
for m in messages
|
||||
if isinstance(m, dict) and m.get("role") == "tool"
|
||||
]
|
||||
|
||||
assert tool_messages, "expected at least one tool message in history"
|
||||
found_project_cli = False
|
||||
found_view_skill_file = False
|
||||
for tool_msg in tool_messages:
|
||||
metadata = tool_msg.get("metadata", {})
|
||||
tool_agent_output = metadata.get("tool_agent_output")
|
||||
if not tool_agent_output:
|
||||
continue
|
||||
tool_name = tool_agent_output.get("tool_name")
|
||||
assert tool_name in {"project_cli", "view_skill_file"}
|
||||
assert "result" in tool_agent_output
|
||||
assert "status" in tool_agent_output
|
||||
|
||||
if tool_name == "view_skill_file":
|
||||
tool_call_args = tool_agent_output.get("tool_call_args")
|
||||
assert isinstance(tool_call_args, dict)
|
||||
assert tool_call_args.get("file_path") == "calendar/SKILL.md"
|
||||
found_view_skill_file = True
|
||||
continue
|
||||
|
||||
result = tool_agent_output.get("result")
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
result = json.loads(result)
|
||||
tool_agent_output["result"] = result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
assert isinstance(result, dict), f"result in DB should be dict, got {type(result)}: {result!r}"
|
||||
assert result.get("command") == "calendar"
|
||||
assert result.get("subcommand") == "read"
|
||||
|
||||
ui_hints = tool_agent_output.get("ui_hints")
|
||||
assert isinstance(ui_hints, dict), f"ui_hints should be dict, got {type(ui_hints)}"
|
||||
found_project_cli = True
|
||||
assert found_view_skill_file, "expected persisted view_skill_file tool output"
|
||||
assert found_project_cli, "expected persisted project_cli tool output"
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -9,27 +8,34 @@ import httpx
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from schemas.enums import AgentChatMessageRole
|
||||
|
||||
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
|
||||
BASE_URL = f"http://localhost:{5775}"
|
||||
FIXTURE_IMAGE_PATH = (
|
||||
Path(__file__).resolve().parents[3] / "fixtures" / "images" / "calendar_text_cn.png"
|
||||
)
|
||||
|
||||
|
||||
def _require_test_phone() -> str:
|
||||
phone = config.test.phone
|
||||
if not phone:
|
||||
pytest.fail("SOCIAL_TEST__PHONE is required for live integration tests")
|
||||
return phone
|
||||
|
||||
|
||||
async def _live_access_token(client: httpx.AsyncClient) -> str:
|
||||
phone = os.getenv("AGENT_LIVE_PHONE")
|
||||
password = os.getenv("AGENT_LIVE_PASSWORD")
|
||||
if not phone or not password:
|
||||
pytest.fail(
|
||||
"AGENT_LIVE_INTEGRATION=1 requires AGENT_LIVE_PHONE and AGENT_LIVE_PASSWORD"
|
||||
)
|
||||
phone = _require_test_phone()
|
||||
if not phone.startswith("+"):
|
||||
phone = f"+{phone}"
|
||||
code = config.test.code or "000000"
|
||||
|
||||
response = await client.post(
|
||||
f"{BASE_URL}/api/v1/auth/sessions",
|
||||
json={"phone": phone, "password": password},
|
||||
f"{BASE_URL}/api/v1/auth/phone-session",
|
||||
json={"phone": phone, "token": code},
|
||||
)
|
||||
response_text = response.text.strip().replace("\n", " ")
|
||||
truncated_text = response_text[:200]
|
||||
@@ -48,8 +54,8 @@ async def _live_access_token(client: httpx.AsyncClient) -> str:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
async def test_agent_sse_closed_loop_live() -> None:
|
||||
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
|
||||
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
|
||||
if config.runtime.environment not in {"dev", "test"}:
|
||||
pytest.skip("live integration tests require dev or test environment")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
token = await _live_access_token(client)
|
||||
@@ -67,7 +73,7 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
"forwardedProps": {"runtime_mode": "chat"},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
@@ -110,8 +116,8 @@ async def test_agent_sse_closed_loop_live() -> None:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
if os.getenv("AGENT_LIVE_INTEGRATION") != "1":
|
||||
pytest.skip("set AGENT_LIVE_INTEGRATION=1 to run live integration test")
|
||||
if config.runtime.environment not in {"dev", "test"}:
|
||||
pytest.skip("live integration tests require dev or test environment")
|
||||
|
||||
image_data = base64.b64encode(FIXTURE_IMAGE_PATH.read_bytes()).decode("ascii")
|
||||
|
||||
@@ -143,7 +149,7 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"agent_type": "worker"},
|
||||
"forwardedProps": {"runtime_mode": "chat"},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
@@ -221,3 +227,78 @@ async def test_agent_runs_events_history_live_with_image_input() -> None:
|
||||
assert user_attachments
|
||||
assert isinstance(user_attachments[0], dict)
|
||||
assert isinstance(user_attachments[0].get("path"), str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.live
|
||||
async def test_agent_tool_call_result_persisted_live() -> None:
|
||||
if config.runtime.environment not in {"dev", "test"}:
|
||||
pytest.skip("live integration tests require dev or test environment")
|
||||
|
||||
thread_id = str(uuid4())
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
token = await _live_access_token(client)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
run_resp = await client.post(
|
||||
f"{BASE_URL}/api/v1/agent/runs",
|
||||
headers=headers,
|
||||
json={
|
||||
"threadId": thread_id,
|
||||
"runId": "run-tool-verify-1",
|
||||
"state": {},
|
||||
"messages": [
|
||||
{
|
||||
"id": "u1",
|
||||
"role": "user",
|
||||
"content": "帮我查一下明天有哪些日程安排",
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {"runtime_mode": "chat"},
|
||||
},
|
||||
)
|
||||
assert run_resp.status_code == 202
|
||||
|
||||
accepted = run_resp.json()
|
||||
assert str(accepted["threadId"]) == thread_id
|
||||
|
||||
events_url = f"{BASE_URL}/api/v1/agent/runs/{thread_id}/events?runId=run-tool-verify-1"
|
||||
event_names: list[str] = []
|
||||
async with client.stream(
|
||||
"GET", events_url, headers=headers, timeout=90.0
|
||||
) as sse_resp:
|
||||
assert sse_resp.status_code == 200
|
||||
async for line in sse_resp.aiter_lines():
|
||||
if line.startswith("event:"):
|
||||
event_name = line.split(":", 1)[1].strip()
|
||||
event_names.append(event_name)
|
||||
if event_name in {"RUN_FINISHED", "RUN_ERROR"}:
|
||||
break
|
||||
|
||||
assert "RUN_STARTED" in event_names, (
|
||||
f"missing RUN_STARTED, got: {event_names}"
|
||||
)
|
||||
|
||||
finished_ok = "RUN_FINISHED" in event_names
|
||||
finished_err = "RUN_ERROR" in event_names
|
||||
assert finished_ok or finished_err, (
|
||||
f"no terminal event, got: {event_names}"
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
rows = await session.execute(
|
||||
select(AgentChatMessage).where(
|
||||
AgentChatMessage.session_id == UUID(thread_id),
|
||||
AgentChatMessage.role == AgentChatMessageRole.TOOL,
|
||||
)
|
||||
)
|
||||
tool_messages = list(rows.scalars().all())
|
||||
|
||||
if finished_ok:
|
||||
assert len(tool_messages) >= 1, (
|
||||
f"expected >=1 role='tool' message but found {len(tool_messages)}. "
|
||||
f"SSE events: {event_names}"
|
||||
)
|
||||
|
||||
@@ -33,13 +33,13 @@ def _make_job_response(
|
||||
status=overrides.get("status", "active"),
|
||||
is_system=overrides.get("is_system", False),
|
||||
config=overrides.get(
|
||||
"config",
|
||||
{
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"config",
|
||||
{
|
||||
"input_template": "Hello",
|
||||
"enabled_skills": [],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
"window_count": 2,
|
||||
},
|
||||
"schedule": {
|
||||
@@ -118,7 +118,7 @@ def test_create_automation_job_requires_auth() -> None:
|
||||
"timezone": "Asia/Shanghai",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"enabled_skills": [],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
@@ -161,7 +161,7 @@ def test_create_automation_job_succeeds() -> None:
|
||||
"status": "active",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"enabled_skills": [],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
@@ -205,7 +205,7 @@ def test_create_automation_job_respects_limit() -> None:
|
||||
"status": "active",
|
||||
"config": {
|
||||
"input_template": "Hello",
|
||||
"enabled_tools": [],
|
||||
"enabled_skills": [],
|
||||
"context": {
|
||||
"source": "latest_chat",
|
||||
"window_mode": "day",
|
||||
|
||||
Reference in New Issue
Block a user