feat(agent): complete closed-loop runtime and pricing fallback

This commit is contained in:
qzl
2026-03-05 15:34:37 +08:00
parent b02a322bf3
commit b486e78ff3
67 changed files with 3832 additions and 7 deletions
+161
View File
@@ -0,0 +1,161 @@
#!/usr/bin/env python3
"""Live diagnostic script for Agent Run -> SSE closed loop."""
from __future__ import annotations
import argparse
import asyncio
import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
from uuid import UUID
import httpx
import jwt
from sqlalchemy import select
# Prepend <script dir>/backend/src to sys.path before the project imports
# that follow (`core.*`, `models.*`) are executed.
backend_src = Path(__file__).parent / "backend" / "src"
sys.path.insert(0, str(backend_src))
# setdefault: an already-defined PYTHONPATH in the environment is left as is.
os.environ.setdefault("PYTHONPATH", str(backend_src))
from core.config import config # noqa: E402
from core.db.session import AsyncSessionLocal # noqa: E402
from models.agent_chat_message import AgentChatMessage # noqa: E402
from models.agent_chat_session import AgentChatSession # noqa: E402
from models.profile import Profile # noqa: E402
# Base URL that the run-submission and event-stream requests are sent to.
BASE_URL = "http://localhost:5775"
def _print_step(title: str) -> None:
    """Print a blank line followed by a ``=== title ===`` section banner."""
    banner = "=== {} ===".format(title)
    print("\n" + banner)
async def get_owner_id() -> UUID:
    """Return the id of one ``Profile`` row to act as the run owner.

    The query has no ORDER BY, so which profile is returned is up to the
    database.

    Returns:
        The ``Profile.id`` of the selected row.

    Raises:
        RuntimeError: If the query returns no profile row.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(Profile.id).limit(1))
        # scalar_one_or_none() instead of scalar_one(): an empty table is
        # reported below with an actionable message rather than a bare
        # NoResultFound.
        owner_id = result.scalar_one_or_none()
    if owner_id is None:
        raise RuntimeError("no profile rows found; cannot pick an owner id")
    return owner_id
def create_jwt_token(user_id: UUID) -> str:
    """Create a one-hour HS256 JWT for ``user_id``.

    Claims written: ``sub`` (the user id as a string), ``role`` and ``aud``
    (both ``"authenticated"``), ``iss`` (``<public_url>/auth/v1`` with any
    trailing slash on the configured URL removed), ``iat`` and ``exp``.

    Args:
        user_id: Id placed in the ``sub`` claim.

    Returns:
        The encoded token as produced by ``jwt.encode``.

    Raises:
        ValueError: If ``config.supabase.jwt_secret`` is empty or unset.
    """
    # Fail before building any claims when there is nothing to sign with.
    jwt_secret = config.supabase.jwt_secret
    if not jwt_secret:
        raise ValueError("JWT secret not configured")

    # Read the clock once so that exp is exactly one hour after iat.
    issued_at = datetime.now(timezone.utc)
    supabase_url = config.supabase.public_url.rstrip("/")
    payload = {
        "sub": str(user_id),
        "role": "authenticated",
        "aud": "authenticated",
        "iss": f"{supabase_url}/auth/v1",
        "iat": issued_at,
        "exp": issued_at + timedelta(hours=1),
    }
    return jwt.encode(payload, jwt_secret, algorithm="HS256")
async def assert_db_state(session_id: str) -> None:
    """Print the stored chat session and a summary of its messages.

    Args:
        session_id: String form of the chat session UUID to look up.

    Raises:
        RuntimeError: If no ``AgentChatSession`` row exists for the id.
    """
    _print_step("DB Assertions")
    session_uuid = UUID(session_id)

    async with AsyncSessionLocal() as db:
        record = await db.get(AgentChatSession, session_uuid)
        if record is None:
            raise RuntimeError("session row not found")

        print(f"session.status={record.status}")
        print(f"session.message_count={record.message_count}")
        print(f"session.total_tokens={record.total_tokens}")
        print(f"session.total_cost={record.total_cost}")

        # Messages of this session, ordered by ascending seq.
        stmt = (
            select(AgentChatMessage)
            .where(AgentChatMessage.session_id == session_uuid)
            .order_by(AgentChatMessage.seq.asc())
        )
        result = await db.execute(stmt)
        messages = list(result.scalars().all())

        print(f"messages.count={len(messages)}")
        if not messages:
            return

        first, last = messages[0], messages[-1]
        print(f"messages.first_role={first.role}")
        print(f"messages.last_role={last.role}")
# Event names that end a run's event stream.
_TERMINAL_EVENTS = frozenset({"RUN_FINISHED", "RUN_ERROR"})
# Timeout (seconds) passed to the SSE stream request.
_SSE_TIMEOUT_SECONDS = 20.0


async def run_closed_loop(*, prompt: str, reuse_session: str | None) -> None:
    """Submit an agent run, follow its SSE stream, then check events and DB rows.

    Steps, each announced with a banner:
      1. Build a bearer token for the owner returned by ``get_owner_id``.
      2. POST the prompt to ``/api/v1/agent/runs`` and require HTTP 202.
      3. Stream ``/api/v1/agent/runs/{session_id}/events`` and collect the
         names found on ``event:`` lines.
      4. Require ``RUN_STARTED`` plus one of ``RUN_FINISHED`` / ``RUN_ERROR``.
      5. Print the stored session and message rows via ``assert_db_state``.

    Args:
        prompt: Text sent as the ``prompt`` field of the run request.
        reuse_session: Optional session id sent as ``session_id``; omitted
            from the request when falsy.

    Raises:
        RuntimeError: If the service is unreachable, the run or event request
            returns an unexpected status, or a required event is missing.
    """
    _print_step("Prepare Auth")
    owner_id = await get_owner_id()
    token = create_jwt_token(owner_id)
    headers = {"Authorization": f"Bearer {token}"}
    print(f"owner_id={owner_id}")

    events: list[str] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        _print_step("Submit Run")
        payload: dict[str, object] = {"prompt": prompt}
        if reuse_session:
            payload["session_id"] = reuse_session
        try:
            run_resp = await client.post(
                f"{BASE_URL}/api/v1/agent/runs", headers=headers, json=payload
            )
        except (httpx.ConnectError, httpx.ConnectTimeout) as exc:
            raise RuntimeError(
                "web service unreachable; start runtime via infra/scripts/app.sh start"
            ) from exc
        print(f"run.status={run_resp.status_code}")
        if run_resp.status_code != 202:
            raise RuntimeError(f"run failed: {run_resp.text}")

        accepted = run_resp.json()
        session_id = str(accepted["session_id"])
        task_id = str(accepted["task_id"])
        created = bool(accepted.get("created", False))
        print(f"task_id={task_id}")
        print(f"session_id={session_id}")
        print(f"created={created}")

        _print_step("Subscribe SSE")
        events_url = f"{BASE_URL}/api/v1/agent/runs/{session_id}/events"
        async with client.stream(
            "GET", events_url, headers=headers, timeout=_SSE_TIMEOUT_SECONDS
        ) as sse_resp:
            print(f"events.status={sse_resp.status_code}")
            print(f"events.content_type={sse_resp.headers.get('content-type')}")
            if sse_resp.status_code != 200:
                raise RuntimeError(f"events failed: {await sse_resp.aread()}")

            terminal_seen = False
            try:
                async for line in sse_resp.aiter_lines():
                    if not line.strip():
                        # A blank line closes an SSE event block. Once the
                        # terminal event's block is complete, stop reading so
                        # a connection the server keeps open cannot stall the
                        # script until the read timeout.
                        if terminal_seen:
                            break
                        continue
                    print(line)
                    if line.startswith("event:"):
                        event_name = line.split(":", 1)[1].strip()
                        events.append(event_name)
                        if event_name in _TERMINAL_EVENTS:
                            terminal_seen = True
            except httpx.ReadTimeout:
                # Do not abort with a bare timeout: fall through so the event
                # checks below report which event is actually missing.
                print(f"events.timeout=no data within {_SSE_TIMEOUT_SECONDS}s")

    _print_step("Event Checks")
    print(f"events={events}")
    if "RUN_STARTED" not in events:
        raise RuntimeError("missing RUN_STARTED")
    if "RUN_FINISHED" not in events and "RUN_ERROR" not in events:
        raise RuntimeError("missing final event")

    await assert_db_state(session_id)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the diagnostic script.

    Returns:
        Namespace with ``prompt`` (defaults to a built-in greeting) and
        ``reuse_session`` (defaults to ``None``).
    """
    cli = argparse.ArgumentParser(description="Agent closed-loop live diagnostic")
    option_defaults = (
        ("--prompt", "你好,请介绍一下你自己"),
        ("--reuse-session", None),
    )
    for flag, fallback in option_defaults:
        cli.add_argument(flag, default=fallback)
    return cli.parse_args()
def main() -> None:
    """Run the diagnostic once; on any exception print it and exit with 1."""
    cli_args = parse_args()
    diagnostic = run_closed_loop(
        prompt=cli_args.prompt, reuse_session=cli_args.reuse_session
    )
    try:
        asyncio.run(diagnostic)
    except Exception as exc:  # noqa: BLE001
        # Top-level boundary: report the failure and signal it via exit status.
        print(f"\nERROR: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()