162 lines
5.6 KiB
Python
162 lines
5.6 KiB
Python
#!/usr/bin/env python3
|
|
"""Live diagnostic script for Agent Run -> SSE closed loop."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
from uuid import UUID
|
|
|
|
import httpx
|
|
import jwt
|
|
from sqlalchemy import select
|
|
|
|
# Make ``backend/src`` (located next to this script) importable so that the
# ``core`` / ``models`` imports below resolve when the script is run directly.
backend_src = Path(__file__).parent.joinpath("backend", "src")
sys.path[:0] = [str(backend_src)]
# PYTHONPATH is only read at interpreter start-up, so this affects child
# processes only; the sys.path tweak above is what matters for this process.
if "PYTHONPATH" not in os.environ:
    os.environ["PYTHONPATH"] = str(backend_src)
|
|
|
|
from core.config import config # noqa: E402
|
|
from core.db.session import AsyncSessionLocal # noqa: E402
|
|
from models.agent_chat_message import AgentChatMessage # noqa: E402
|
|
from models.agent_chat_session import AgentChatSession # noqa: E402
|
|
from models.profile import Profile # noqa: E402
|
|
|
|
# Root URL of the web service this diagnostic talks to. Defaults to the local
# runtime; set AGENT_DIAG_BASE_URL to target another host/port. A trailing
# slash is stripped because request paths are appended as "/api/...".
BASE_URL = os.environ.get("AGENT_DIAG_BASE_URL", "http://localhost:5775").rstrip("/")
|
|
|
|
|
|
def _print_step(title: str) -> None:
    """Print a ``=== title ===`` section banner, preceded by an empty line."""
    rule = "==="
    print("\n{rule} {title} {rule}".format(rule=rule, title=title))
|
|
|
|
|
|
async def get_owner_id() -> UUID:
    """Return the id of one ``Profile`` row to act as the run owner.

    The query has no ORDER BY, so which profile is chosen is unspecified when
    several exist.

    Raises:
        RuntimeError: if the profile table is empty. The message says what to
            fix instead of surfacing SQLAlchemy's generic ``NoResultFound``.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(Profile.id).limit(1))
        owner_id = result.scalar_one_or_none()
    if owner_id is None:
        raise RuntimeError(
            "no profile rows found; create a user before running this diagnostic"
        )
    return owner_id
|
|
|
|
|
|
def create_jwt_token(
    user_id: UUID, *, expires_in: timedelta = timedelta(hours=1)
) -> str:
    """Mint an HS256 JWT for *user_id*, signed with ``config.supabase.jwt_secret``.

    Claims set: ``sub`` (the user id), ``role`` and ``aud`` (both
    ``"authenticated"``), ``iss`` (``<config.supabase.public_url>/auth/v1``),
    ``iat`` and ``exp``.

    Args:
        user_id: Id placed in the ``sub`` claim.
        expires_in: Token lifetime added to ``iat``; defaults to one hour.

    Returns:
        The encoded token string.

    Raises:
        ValueError: if the JWT secret is not configured (empty/falsy).
    """
    jwt_secret = config.supabase.jwt_secret
    if not jwt_secret:
        raise ValueError("JWT secret not configured")

    supabase_url = config.supabase.public_url.rstrip("/")
    # Take a single timestamp so that exp - iat is exactly expires_in.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "role": "authenticated",
        "aud": "authenticated",
        "iss": f"{supabase_url}/auth/v1",
        "iat": issued_at,
        "exp": issued_at + expires_in,
    }
    return jwt.encode(payload, jwt_secret, algorithm="HS256")
|
|
|
|
|
|
async def assert_db_state(session_id: str) -> None:
    """Print the persisted state of one agent chat session.

    Loads the ``AgentChatSession`` row and its ``AgentChatMessage`` rows,
    ordered by ``seq``, and prints summary fields. Despite the name, the only
    hard check is that the session row exists.

    Args:
        session_id: Session id in string form; parsed with ``UUID()``.

    Raises:
        ValueError: if *session_id* is not a valid UUID string.
        RuntimeError: if no session row exists for the id.
    """
    _print_step("DB Assertions")
    sid = UUID(session_id)
    async with AsyncSessionLocal() as db:
        chat_session = await db.get(AgentChatSession, sid)
        if chat_session is None:
            raise RuntimeError("session row not found")

        for field in ("status", "message_count", "total_tokens", "total_cost"):
            print(f"session.{field}={getattr(chat_session, field)}")

        stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == sid)
            .order_by(AgentChatMessage.seq.asc())
        )
        result = await db.execute(stmt)
        messages = list(result.scalars().all())
        print(f"messages.count={len(messages)}")
        if not messages:
            return

        print(f"messages.first_role={messages[0].role}")
        print(f"messages.last_role={messages[-1].role}")
|
|
|
|
|
|
# SSE event names that end a run. Once one is seen, the stream is read to the
# end of that event's frame and then closed.
_TERMINAL_EVENTS = frozenset({"RUN_FINISHED", "RUN_ERROR"})


async def _submit_run(
    client: httpx.AsyncClient,
    headers: dict[str, str],
    *,
    prompt: str,
    reuse_session: str | None,
) -> str:
    """POST the prompt to ``/api/v1/agent/runs`` and return the session id.

    Raises:
        RuntimeError: if the service is unreachable or does not answer 202.
    """
    _print_step("Submit Run")
    payload: dict[str, object] = {"prompt": prompt}
    if reuse_session:
        payload["session_id"] = reuse_session

    try:
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs", headers=headers, json=payload
        )
    except (httpx.ConnectError, httpx.ConnectTimeout) as exc:
        raise RuntimeError(
            "web service unreachable; start runtime via infra/scripts/app.sh start"
        ) from exc
    print(f"run.status={run_resp.status_code}")
    if run_resp.status_code != 202:
        raise RuntimeError(f"run failed: {run_resp.text}")

    accepted = run_resp.json()
    session_id = str(accepted["session_id"])
    task_id = str(accepted["task_id"])
    created = bool(accepted.get("created", False))
    print(f"task_id={task_id}")
    print(f"session_id={session_id}")
    print(f"created={created}")
    return session_id


async def _collect_sse_events(
    client: httpx.AsyncClient, headers: dict[str, str], session_id: str
) -> list[str]:
    """Stream the run's SSE feed, echo each non-blank line, return event names.

    Reading stops at the blank line that ends the first terminal event's frame,
    or earlier if the server closes the stream. Without this, a server that
    keeps the connection open after the run ends would leave the diagnostic
    waiting until the read timeout fires.

    Raises:
        RuntimeError: if the events endpoint does not answer 200.
    """
    _print_step("Subscribe SSE")
    events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
    events: list[str] = []
    async with client.stream(
        "GET", events_url, headers=headers, timeout=20.0
    ) as sse_resp:
        print(f"events.status={sse_resp.status_code}")
        print(f"events.content_type={sse_resp.headers.get('content-type')}")
        if sse_resp.status_code != 200:
            # Read the body so .text is available; report decoded text rather
            # than a raw bytes repr.
            await sse_resp.aread()
            raise RuntimeError(f"events failed: {sse_resp.text}")

        terminal_seen = False
        async for line in sse_resp.aiter_lines():
            if not line.strip():
                # A blank line terminates an SSE frame; once the terminal
                # event's frame (including its data lines) is printed, stop.
                if terminal_seen:
                    break
                continue
            print(line)
            if line.startswith("event:"):
                event_name = line.split(":", 1)[1].strip()
                events.append(event_name)
                if event_name in _TERMINAL_EVENTS:
                    terminal_seen = True
    return events


def _check_events(events: list[str]) -> None:
    """Require a RUN_STARTED event and at least one terminal event.

    Raises:
        RuntimeError: naming the missing event.
    """
    _print_step("Event Checks")
    print(f"events={events}")
    if "RUN_STARTED" not in events:
        raise RuntimeError("missing RUN_STARTED")
    if _TERMINAL_EVENTS.isdisjoint(events):
        raise RuntimeError("missing final event")


async def run_closed_loop(*, prompt: str, reuse_session: str | None) -> None:
    """Drive one agent run end to end and report each stage.

    Stages: mint a token for a profile, submit the run, consume its SSE
    stream, check the event sequence, then print the persisted DB state.

    Args:
        prompt: Prompt text sent with the run.
        reuse_session: Optional existing session id to send with the run.

    Raises:
        RuntimeError: for the checked failure modes (unreachable service,
            non-202/200 responses, missing events, missing session row).
            Other errors, such as a malformed response body, propagate as-is.
    """
    _print_step("Prepare Auth")
    owner_id = await get_owner_id()
    token = create_jwt_token(owner_id)
    headers = {"Authorization": f"Bearer {token}"}
    print(f"owner_id={owner_id}")

    async with httpx.AsyncClient(timeout=30.0) as client:
        session_id = await _submit_run(
            client, headers, prompt=prompt, reuse_session=reuse_session
        )
        events = await _collect_sse_events(client, headers, session_id)

    _check_events(events)
    await assert_db_state(session_id)
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the diagnostic's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse use
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        Namespace with ``prompt`` (str) and ``reuse_session`` (str or None).
    """
    parser = argparse.ArgumentParser(description="Agent closed-loop live diagnostic")
    parser.add_argument(
        "--prompt",
        default="你好,请介绍一下你自己",
        help="prompt sent with the agent run (default: %(default)s)",
    )
    parser.add_argument(
        "--reuse-session",
        default=None,
        metavar="SESSION_ID",
        help="existing session id to send with the run request",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    try:
        asyncio.run(
            run_closed_loop(prompt=args.prompt, reuse_session=args.reuse_session)
        )
    except KeyboardInterrupt:
        # Exit cleanly on Ctrl-C instead of dumping a traceback; 130 is the
        # conventional exit status for SIGINT.
        print("\nInterrupted")
        sys.exit(130)
    except Exception as exc:  # noqa: BLE001 - top-level boundary of a CLI script
        # Include the exception type: some messages are meaningless on their
        # own (a KeyError prints just the missing key, e.g. 'session_id').
        print(f"\nERROR: {type(exc).__name__}: {exc}")
        sys.exit(1)
|