197 lines
5.3 KiB
Python
197 lines
5.3 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import os
|
||
|
|
import time
|
||
|
|
from pathlib import Path
|
||
|
|
from uuid import uuid4
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
import jwt
|
||
|
|
|
||
|
|
|
||
|
|
def _load_env() -> None:
|
||
|
|
env_path = Path(__file__).resolve().parents[3] / ".env"
|
||
|
|
if env_path.exists():
|
||
|
|
for line in env_path.read_text().splitlines():
|
||
|
|
line = line.strip()
|
||
|
|
if not line or line.startswith("#") or "=" not in line:
|
||
|
|
continue
|
||
|
|
key, _, value = line.partition("=")
|
||
|
|
key = key.strip()
|
||
|
|
value = value.strip().strip('"').strip("'")
|
||
|
|
if key and key not in os.environ:
|
||
|
|
os.environ[key] = value
|
||
|
|
|
||
|
|
|
||
|
|
# Load the .env file first so the configuration read below can see its values.
_load_env()

# Base URL that the agent API paths are appended to. Trailing slashes are
# trimmed so that f"{BASE_URL}/api/..." never contains a double slash when the
# configured value ends with "/".
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775").rstrip("/")
|
||
|
|
|
||
|
|
|
||
|
|
def get_jwt_secret() -> str:
    """Return the JWT secret configured in the environment.

    The variables ``SOCIAL_SUPABASE__JWT_SECRET``, ``SUPABASE_JWT_SECRET`` and
    ``JWT_SECRET`` are consulted in that order; the first non-empty value wins.

    Raises:
        RuntimeError: If none of the variables holds a non-empty value.
    """
    candidates = (
        "SOCIAL_SUPABASE__JWT_SECRET",
        "SUPABASE_JWT_SECRET",
        "JWT_SECRET",
    )
    for var_name in candidates:
        configured = os.getenv(var_name)
        if configured:
            return configured
    raise RuntimeError("JWT_SECRET not found in environment")
|
||
|
|
|
||
|
|
|
||
|
|
def get_supabase_url() -> str:
    """Return the Supabase URL from the environment, with a local default.

    ``SOCIAL_SUPABASE__URL`` takes precedence over ``SUPABASE_URL``; when
    neither holds a non-empty value, ``http://localhost:54321`` is returned.
    """
    preferred = os.getenv("SOCIAL_SUPABASE__URL")
    if preferred:
        return preferred

    fallback = os.getenv("SUPABASE_URL")
    if fallback:
        return fallback

    return "http://localhost:54321"
|
||
|
|
|
||
|
|
|
||
|
|
def get_test_user_id() -> str:
    """Return the value of the ``TEST_USER_ID`` environment variable.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    configured = os.getenv("TEST_USER_ID")
    if not configured:
        raise RuntimeError("TEST_USER_ID not set")
    return configured
|
||
|
|
|
||
|
|
|
||
|
|
def create_test_jwt(user_id: str, *, ttl_seconds: int = 3600) -> str:
    """Build an HS256-signed JWT whose subject is ``user_id``.

    The token carries ``role`` and ``aud`` set to ``"authenticated"``, the
    issuer returned by ``get_supabase_url()`` and is signed with the secret
    returned by ``get_jwt_secret()``.

    Args:
        user_id: Value stored in the ``sub`` claim.
        ttl_seconds: Seconds between the ``iat`` and ``exp`` claims. Defaults
            to one hour; a zero or negative value produces an already-expired
            token.

    Returns:
        The encoded token as produced by ``jwt.encode``.

    Raises:
        RuntimeError: Propagated from ``get_jwt_secret()`` when no secret is
            configured.
    """
    issued_at = int(time.time())
    claims = {
        "sub": user_id,
        "role": "authenticated",
        "aud": "authenticated",
        "iss": get_supabase_url(),
        "iat": issued_at,
        "exp": issued_at + ttl_seconds,
    }
    return jwt.encode(claims, get_jwt_secret(), algorithm="HS256")
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_sse_payload(raw: str) -> dict | None:
    """Decode one SSE data payload; return ``None`` unless it is a JSON object."""
    import json

    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        # Best-effort: payloads that are not valid JSON are skipped rather
        # than aborting the whole run collection.
        return None
    # Events are accessed with ``.get``; anything but an object is unusable.
    return decoded if isinstance(decoded, dict) else None


async def run_agent_and_collect(
    *,
    user_message: str,
    client: httpx.AsyncClient,
    headers: dict,
    run_id: str | None = None,
    thread_id: str | None = None,
    timeout: float = 120.0,
) -> AgentRunResult:
    """Start an agent run, consume its SSE event stream and summarise it.

    A run is created with ``POST {BASE_URL}/api/v1/agent/runs`` and its events
    are then streamed from ``/api/v1/agent/runs/{threadId}/events``. Streaming
    continues until the server closes the stream.

    Args:
        user_message: Content of the single user message sent to the agent.
        client: HTTP client used for both requests.
        headers: Headers sent with both requests.
        run_id: Run identifier; defaults to ``quality-`` plus the first eight
            characters of the thread id.
        thread_id: Thread identifier; defaults to a fresh UUID4 string.
        timeout: Timeout passed to the event-stream request.

    Returns:
        An ``AgentRunResult`` holding every decoded event, the
        ``TOOL_CALL_RESULT`` events, the text taken from the last
        ``TEXT_MESSAGE_END`` event, whether a ``RUN_FINISHED``/``RUN_ERROR``
        event was seen, and the elapsed wall time in milliseconds.

    Raises:
        httpx.HTTPStatusError: If either request returns a 4xx/5xx status.
    """
    if thread_id is None:
        thread_id = str(uuid4())
    if run_id is None:
        run_id = f"quality-{thread_id[:8]}"

    t_start = time.monotonic()

    run_resp = await client.post(
        f"{BASE_URL}/api/v1/agent/runs",
        headers=headers,
        json={
            "threadId": thread_id,
            "runId": run_id,
            "state": {},
            "messages": [
                {"id": "u1", "role": "user", "content": user_message}
            ],
            "tools": [],
            "context": [],
            "forwardedProps": {"runtime_mode": "chat"},
        },
    )
    # Surface auth/validation/server failures here instead of going on to
    # stream events for a run that was never created.
    run_resp.raise_for_status()

    run_data = run_resp.json()
    # Fall back to the requested ids when the response omits them or sends
    # null, rather than ending up with the literal string "None".
    effective_thread_id = str(run_data.get("threadId") or thread_id)
    effective_run_id = str(run_data.get("runId") or run_id)

    events_url = f"{BASE_URL}/api/v1/agent/runs/{effective_thread_id}/events"

    tool_results: list[dict] = []
    all_events: list[dict] = []
    run_finished = False
    final_answer = ""

    async with client.stream(
        "GET",
        events_url,
        # Passed as params so the run id is URL-encoded.
        params={"runId": effective_run_id},
        headers=headers,
        timeout=timeout,
    ) as sse_resp:
        sse_resp.raise_for_status()

        # An SSE event may spread its payload over several ``data:`` lines;
        # they are joined with newlines when the terminating blank line comes.
        data_lines: list[str] = []
        async for line in sse_resp.aiter_lines():
            if line.startswith("data:"):
                data_str = line[len("data:"):].strip()
                if data_str:
                    data_lines.append(data_str)
                continue
            if line != "" or not data_lines:
                # Other SSE fields, comments, and blank lines with nothing
                # buffered carry no payload for us.
                continue

            event_data = _parse_sse_payload("\n".join(data_lines))
            data_lines = []
            if event_data is None:
                continue

            all_events.append(event_data)
            event_type = event_data.get("type")
            if event_type == "TOOL_CALL_RESULT":
                tool_results.append(event_data)
            elif event_type == "TEXT_MESSAGE_END":
                final_answer = (
                    event_data.get("answer", "") or event_data.get("text", "")
                )
            elif event_type in {"RUN_FINISHED", "RUN_ERROR"}:
                run_finished = True

    t_end = time.monotonic()

    return AgentRunResult(
        thread_id=effective_thread_id,
        run_id=effective_run_id,
        user_message=user_message,
        final_answer=final_answer,
        tool_results=tool_results,
        all_events=all_events,
        run_finished=run_finished,
        latency_ms=round((t_end - t_start) * 1000),
    )
|
||
|
|
|
||
|
|
|
||
|
|
class AgentRunResult:
    """Plain record of one agent run plus helpers for inspecting tool results.

    Attributes are stored exactly as given to the constructor; the properties
    derive tool names from the dictionaries in ``tool_results``.
    """

    def __init__(
        self,
        *,
        thread_id: str,
        run_id: str,
        user_message: str,
        final_answer: str,
        tool_results: list[dict],
        all_events: list[dict],
        run_finished: bool,
        latency_ms: int,
    ) -> None:
        # Identifiers of the run.
        self.thread_id = thread_id
        self.run_id = run_id
        # What was asked and what came back.
        self.user_message = user_message
        self.final_answer = final_answer
        # Collected events.
        self.tool_results = tool_results
        self.all_events = all_events
        # Completion flag and timing.
        self.run_finished = run_finished
        self.latency_ms = latency_ms

    @staticmethod
    def _tool_name(result: dict) -> str:
        """Return the result's ``tool_name``, else its ``toolName``, else ``""``."""
        return result.get("tool_name", "") or result.get("toolName", "")

    @property
    def tool_names_called(self) -> list[str]:
        """Tool name of every entry in ``tool_results``, in order."""
        return [self._tool_name(result) for result in self.tool_results]

    @property
    def successful_tool_names(self) -> list[str]:
        """Tool names of entries whose ``status`` is ``success`` or ``partial``."""
        names = []
        for result in self.tool_results:
            if result.get("status") in ("success", "partial"):
                names.append(self._tool_name(result))
        return names

    @property
    def has_tool_success(self) -> bool:
        """``True`` when at least one tool result succeeded, fully or partially."""
        return bool(self.successful_tool_names)
|