Files
social-app/backend/tests/quality/conftest.py
T

197 lines
5.3 KiB
Python
Raw Normal View History

from __future__ import annotations

import json
import os
import time
from pathlib import Path
from uuid import uuid4

import httpx
import jwt
def _load_env() -> None:
env_path = Path(__file__).resolve().parents[3] / ".env"
if env_path.exists():
for line in env_path.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key = key.strip()
value = value.strip().strip('"').strip("'")
if key and key not in os.environ:
os.environ[key] = value
# Populate os.environ from the .env file located by _load_env() first, so the
# lookup below (and every later os.getenv in this module) can see its values.
_load_env()
# Root URL of the agent API server that run_agent_and_collect() talks to.
BASE_URL = os.environ.get("AGENT_LIVE_BASE_URL", "http://localhost:5775")
def get_jwt_secret() -> str:
    """Return the HMAC secret used to sign test JWTs.

    The first non-empty value among ``SOCIAL_SUPABASE__JWT_SECRET``,
    ``SUPABASE_JWT_SECRET`` and ``JWT_SECRET`` (checked in that order) wins.

    Raises:
        RuntimeError: if none of those variables holds a non-empty value.
    """
    for env_name in ("SOCIAL_SUPABASE__JWT_SECRET", "SUPABASE_JWT_SECRET", "JWT_SECRET"):
        candidate = os.getenv(env_name)
        if candidate:
            return candidate
    raise RuntimeError("JWT_SECRET not found in environment")
def get_supabase_url() -> str:
    """Return the Supabase base URL used as the ``iss`` claim of test tokens.

    Prefers ``SOCIAL_SUPABASE__URL``, then ``SUPABASE_URL`` (empty values are
    skipped), and falls back to ``http://localhost:54321``.
    """
    configured = (os.getenv("SOCIAL_SUPABASE__URL"), os.getenv("SUPABASE_URL"))
    return next((url for url in configured if url), "http://localhost:54321")
def get_test_user_id() -> str:
    """Return the user id the quality tests act as, from ``TEST_USER_ID``.

    Raises:
        RuntimeError: if ``TEST_USER_ID`` is unset or empty.
    """
    configured_id = os.getenv("TEST_USER_ID")
    if not configured_id:
        raise RuntimeError("TEST_USER_ID not set")
    return configured_id
def create_test_jwt(user_id: str) -> str:
    """Return an HS256-signed JWT for ``user_id``, valid for 3600 seconds.

    Claims: ``sub`` = ``user_id``; ``role`` and ``aud`` = ``"authenticated"``;
    ``iss`` = ``get_supabase_url()``; ``iat`` = now (whole seconds);
    ``exp`` = ``iat + 3600``. Signed with ``get_jwt_secret()``, which raises
    ``RuntimeError`` when no secret is configured.
    """
    issued_at = int(time.time())
    claims = dict(
        sub=user_id,
        role="authenticated",
        aud="authenticated",
        iss=get_supabase_url(),
        iat=issued_at,
        exp=issued_at + 3600,
    )
    return jwt.encode(claims, get_jwt_secret(), algorithm="HS256")
def _decode_sse_payload(data_lines: list[str]) -> dict | None:
    """Join one SSE event's ``data:`` lines and decode them as a JSON object.

    Per the SSE format, multiple ``data:`` lines of a single event are joined
    with ``"\\n"``. Returns ``None`` when the payload is not valid JSON or is
    not a JSON object (e.g. keep-alive text), so callers can skip it.
    """
    try:
        decoded = json.loads("\n".join(data_lines))
    except json.JSONDecodeError:
        return None
    return decoded if isinstance(decoded, dict) else None


async def run_agent_and_collect(
    *,
    user_message: str,
    client: httpx.AsyncClient,
    headers: dict,
    run_id: str | None = None,
    thread_id: str | None = None,
    timeout: float = 120.0,
) -> AgentRunResult:
    """Start an agent run, consume its SSE event stream and summarize it.

    POSTs a single user message to ``{BASE_URL}/api/v1/agent/runs``, then
    streams ``GET {BASE_URL}/api/v1/agent/runs/{threadId}/events?runId=...``
    until a ``RUN_FINISHED``/``RUN_ERROR`` event arrives or the server closes
    the stream.

    Args:
        user_message: Content of the single user message sent to the agent.
        client: HTTP client used for both requests.
        headers: Headers (e.g. auth) sent with both requests.
        run_id: Run id to request; defaults to ``"quality-" + thread_id[:8]``.
        thread_id: Thread id to request; defaults to a fresh UUID4 string.
        timeout: httpx timeout in seconds for the event-stream request. httpx
            applies a float timeout to each network operation individually,
            so this bounds waits between chunks, not the whole run.

    Returns:
        AgentRunResult holding the ids echoed by the server (falling back to
        the requested ones), every decoded event, the ``TOOL_CALL_RESULT``
        events, the answer from the last ``TEXT_MESSAGE_END`` event, whether a
        terminal event was seen, and the wall-clock latency in milliseconds.

    Raises:
        httpx.HTTPStatusError: if either request returns a 4xx/5xx status.
    """
    if thread_id is None:
        thread_id = str(uuid4())
    if run_id is None:
        run_id = f"quality-{thread_id[:8]}"

    t_start = time.monotonic()
    run_resp = await client.post(
        f"{BASE_URL}/api/v1/agent/runs",
        headers=headers,
        json={
            "threadId": thread_id,
            "runId": run_id,
            "state": {},
            "messages": [
                {"id": "u1", "role": "user", "content": user_message}
            ],
            "tools": [],
            "context": [],
            "forwardedProps": {"runtime_mode": "chat"},
        },
    )
    # An error body is not run metadata: fail with the HTTP status instead of
    # streaming events for a run the server rejected.
    run_resp.raise_for_status()
    run_data = run_resp.json()
    if not isinstance(run_data, dict):
        run_data = {}
    # Prefer the ids echoed by the server. `or` also covers explicit nulls,
    # which would otherwise become the literal string "None" in the URL.
    effective_thread_id = str(run_data.get("threadId") or thread_id)
    effective_run_id = str(run_data.get("runId") or run_id)

    tool_results: list[dict] = []
    all_events: list[dict] = []
    run_finished = False
    final_answer = ""
    async with client.stream(
        "GET",
        f"{BASE_URL}/api/v1/agent/runs/{effective_thread_id}/events",
        params={"runId": effective_run_id},
        headers=headers,
        timeout=timeout,
    ) as sse_resp:
        sse_resp.raise_for_status()
        data_lines: list[str] = []
        async for line in sse_resp.aiter_lines():
            if line.startswith("data:"):
                data_lines.append(line[len("data:"):].strip())
                continue
            if line or not data_lines:
                # Non-data fields (event:, id:, comments) are ignored, and a
                # blank line with nothing buffered dispatches nothing.
                continue
            event = _decode_sse_payload(data_lines)
            data_lines = []
            if event is None:
                continue
            all_events.append(event)
            event_type = event.get("type")
            if event_type == "TOOL_CALL_RESULT":
                tool_results.append(event)
            elif event_type == "TEXT_MESSAGE_END":
                final_answer = event.get("answer", "") or event.get("text", "")
            elif event_type in {"RUN_FINISHED", "RUN_ERROR"}:
                # Both count as "finished". Stop reading so a server that keeps
                # the connection open cannot stall the test until the timeout.
                run_finished = True
                break
    t_end = time.monotonic()

    return AgentRunResult(
        thread_id=effective_thread_id,
        run_id=effective_run_id,
        user_message=user_message,
        final_answer=final_answer,
        tool_results=tool_results,
        all_events=all_events,
        run_finished=run_finished,
        latency_ms=round((t_end - t_start) * 1000),
    )
class AgentRunResult:
    """Outcome of one agent run, as assembled by ``run_agent_and_collect``.

    Attributes are stored exactly as passed to the constructor. The
    properties derive tool-call summaries from ``tool_results``. Each entry
    is a decoded event dict whose tool name is read from ``tool_name`` or
    ``toolName`` and whose outcome is read from ``status``.
    """

    def __init__(
        self,
        *,
        thread_id: str,
        run_id: str,
        user_message: str,
        final_answer: str,
        tool_results: list[dict],
        all_events: list[dict],
        run_finished: bool,
        latency_ms: int,
    ) -> None:
        self.thread_id = thread_id
        self.run_id = run_id
        self.user_message = user_message
        self.final_answer = final_answer
        # Tool-result event dicts, in the order they were received.
        self.tool_results = tool_results
        # Every decoded event of the run, in the order they were received.
        self.all_events = all_events
        # True once a terminal event (RUN_FINISHED / RUN_ERROR) was observed.
        self.run_finished = run_finished
        self.latency_ms = latency_ms

    def __repr__(self) -> str:
        # Readable summary for assertion-failure output. The full event list is
        # reduced to a count to keep it short.
        return (
            f"{type(self).__name__}("
            f"thread_id={self.thread_id!r}, run_id={self.run_id!r}, "
            f"run_finished={self.run_finished!r}, latency_ms={self.latency_ms!r}, "
            f"tool_names_called={self.tool_names_called!r}, "
            f"final_answer={self.final_answer!r}, "
            f"events={len(self.all_events)})"
        )

    @staticmethod
    def _tool_name(tool_result: dict) -> str:
        """Return a result's tool name (snake_case or camelCase key), else ""."""
        return tool_result.get("tool_name", "") or tool_result.get("toolName", "")

    @property
    def tool_names_called(self) -> list[str]:
        """Tool name of every collected tool result, in order ("" if absent)."""
        return [self._tool_name(tr) for tr in self.tool_results]

    @property
    def successful_tool_names(self) -> list[str]:
        """Tool names of results whose ``status`` is "success" or "partial"."""
        return [
            self._tool_name(tr)
            for tr in self.tool_results
            if tr.get("status") in ("success", "partial")
        ]

    @property
    def has_tool_success(self) -> bool:
        """Whether at least one tool result reported success or partial success."""
        return bool(self.successful_tool_names)