Files
social-app/backend/tests/quality/test_model_ab.py
T

441 lines
13 KiB
Python
Raw Normal View History

from __future__ import annotations
import json
import os
import time
from uuid import uuid4
import httpx
import jwt
import pytest
from backend.tests.quality.evaluators import ModelScorecard, ScoreDetail, ScenarioScore
from backend.tests.quality.scenarios import ALL_SCENARIOS
# Value written to ``system_agents.llm_id`` for each candidate model code.
MODEL_LLM_IDS = {
    "qwen3.5-flash": "c625bce4-970e-4a76-bebe-cb8840fed854",
    "deepseek-chat": "12bc1963-4b67-404b-b952-5948bea0f690",
}
# Models compared by the A/B test, in evaluation order. Derived from the
# mapping above so every candidate is guaranteed to have an llm id.
CANDIDATE_MODELS = list(MODEL_LLM_IDS)
# Root URL of the running agent backend; override with AGENT_LIVE_BASE_URL.
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", "http://localhost:5775")
def _load_env() -> None:
from pathlib import Path
env_path = Path(__file__).resolve().parents[3] / ".env"
if env_path.exists():
for line in env_path.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key = key.strip()
value = value.strip().strip('"').strip("'")
if key and key not in os.environ:
os.environ[key] = value
# Seed os.environ from the .env file at import time so the helpers below see
# its values. BASE_URL was evaluated earlier, before the file was loaded, so
# it is re-resolved here; a variable exported in the real environment still
# wins because _load_env never overrides existing keys.
_load_env()
BASE_URL = os.getenv("AGENT_LIVE_BASE_URL", BASE_URL)
def _get_jwt_secret() -> str:
secret = (
os.getenv("SOCIAL_SUPABASE__JWT_SECRET")
or os.getenv("SUPABASE_JWT_SECRET")
or os.getenv("JWT_SECRET")
)
if not secret:
raise RuntimeError("JWT_SECRET not found in environment")
return secret
def _get_supabase_url() -> str:
return (
os.getenv("SOCIAL_SUPABASE__PUBLIC_URL")
or os.getenv("SOCIAL_SUPABASE__URL")
or os.getenv("SUPABASE_URL")
or "http://localhost:54321"
)
def _get_supabase_key() -> str:
from core.config.settings import config
key = os.getenv("SOCIAL_SUPABASE__SERVICE_ROLE_KEY", "")
if key:
return key
return config.supabase.service_role_key
def _get_test_user_id() -> str:
user_id = os.getenv("TEST_USER_ID")
if user_id:
return user_id
raise RuntimeError("TEST_USER_ID not set")
def _create_jwt(user_id: str) -> str:
    """Mint an HS256-signed bearer token for ``user_id``.

    The claims carry the ``authenticated`` role and audience, the Supabase
    URL as issuer, and a lifetime of 3600 seconds starting now.
    """
    issued_at = int(time.time())
    expires_at = issued_at + 3600
    claims = {
        "sub": user_id,
        "role": "authenticated",
        "aud": "authenticated",
        "iss": _get_supabase_url(),
        "iat": issued_at,
        "exp": expires_at,
    }
    signing_key = _get_jwt_secret()
    return jwt.encode(claims, signing_key, algorithm="HS256")
# Usage counters copied from a TEXT_MESSAGE_END event (missing ones become 0).
_TOKEN_USAGE_KEYS = (
    "totalTokens",
    "inputTokens",
    "outputTokens",
    "promptCacheMissTokens",
    "promptCacheHitTokens",
)


def _tool_name(result: dict) -> str:
    """Return a tool result's name, accepting snake_case or camelCase keys."""
    return result.get("tool_name", "") or result.get("toolName", "")


def _decode_sse_payload(data_lines):
    """Join one event's ``data`` lines and decode them as a JSON object.

    Returns ``None`` for payloads that are not valid JSON or not an object.
    Such frames are deliberately skipped, because only typed JSON events
    are scored.
    """
    try:
        event = json.loads("\n".join(data_lines))
    except json.JSONDecodeError:
        return None
    return event if isinstance(event, dict) else None


async def _iter_sse_json(lines):
    """Yield the JSON object carried by each server-sent event.

    ``lines`` is an async iterator of text lines without line endings (as
    yielded by ``httpx.Response.aiter_lines``). All ``data:`` lines of one
    event are joined with ``"\\n"`` (per the SSE format) and dispatched at
    the next blank line. A trailing event not followed by a blank line is
    dispatched when the stream ends. Comment lines and other fields
    (``event:``, ``id:``) are ignored.
    """
    data_lines: list[str] = []
    async for line in lines:
        if line.startswith("data:"):
            value = line[len("data:"):]
            # The SSE format strips exactly one optional space after ':'.
            data_lines.append(value[1:] if value.startswith(" ") else value)
        elif line == "" and data_lines:
            event = _decode_sse_payload(data_lines)
            data_lines = []
            if event is not None:
                yield event
    if data_lines:
        event = _decode_sse_payload(data_lines)
        if event is not None:
            yield event


async def _run_via_http(
    *,
    user_message: str,
    token: str,
    timeout: float = 120.0,
) -> dict:
    """Start one agent run over HTTP and collect its event stream.

    First POSTs a new thread containing a single user message to
    ``/api/v1/agent/runs``. Then reads the server-sent events from
    ``/api/v1/agent/runs/{threadId}/events?runId=...`` until the server
    closes the stream.

    Args:
        user_message: Text of the single user message.
        token: Bearer token sent in the ``Authorization`` header.
        timeout: httpx timeout in seconds, applied to both requests.

    Returns:
        A dict with these keys:

        * ``final_answer``: text of the last ``TEXT_MESSAGE_END``.
        * ``tool_results``: the raw ``TOOL_CALL_RESULT`` events.
        * ``tool_names``: names of all tool results.
        * ``successful_tool_names``: names of results whose status is
          ``success`` or ``partial``.
        * ``run_finished``: a ``RUN_FINISHED``/``RUN_ERROR`` event was seen.
        * ``latency_ms``: wall time of the event stream, excluding the POST.
        * ``token_usage``: counters from the last ``TEXT_MESSAGE_END``.

    Raises:
        httpx.HTTPStatusError: if either request returns an error status.
    """
    thread_id = str(uuid4())
    run_id = f"q-{uuid4().hex[:12]}"
    headers = {"Authorization": f"Bearer {token}"}

    tool_results: list[dict] = []
    final_answer = ""
    run_finished = False
    token_usage: dict = {}

    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
        run_resp = await client.post(
            f"{BASE_URL}/api/v1/agent/runs",
            headers=headers,
            json={
                "threadId": thread_id,
                "runId": run_id,
                "state": {},
                "messages": [
                    {"id": "u1", "role": "user", "content": user_message}
                ],
                "tools": [],
                "context": [],
                "forwardedProps": {"runtime_mode": "chat"},
            },
        )
        # Surface auth/server errors instead of scoring an error body as an
        # empty answer, which would silently skew the comparison.
        run_resp.raise_for_status()
        run_data = run_resp.json()
        # Prefer the ids returned by the server, falling back to ours.
        eff_thread = str(run_data.get("threadId", thread_id))
        eff_run = run_data.get("runId", run_id)
        events_url = (
            f"{BASE_URL}/api/v1/agent/runs/{eff_thread}/events"
            f"?runId={eff_run}"
        )

        t_start = time.monotonic()
        async with client.stream(
            "GET", events_url, headers=headers, timeout=timeout
        ) as sse:
            sse.raise_for_status()
            async for ev in _iter_sse_json(sse.aiter_lines()):
                etype = ev.get("type")
                if etype == "TOOL_CALL_RESULT":
                    tool_results.append(ev)
                elif etype == "TEXT_MESSAGE_END":
                    # Later messages overwrite earlier ones: the last wins.
                    final_answer = ev.get("answer", "") or ev.get("text", "")
                    token_usage = {
                        key: ev.get(key, 0) for key in _TOKEN_USAGE_KEYS
                    }
                elif etype in {"RUN_FINISHED", "RUN_ERROR"}:
                    run_finished = True
        t_end = time.monotonic()

    return {
        "final_answer": final_answer,
        "tool_results": tool_results,
        "tool_names": [_tool_name(tr) for tr in tool_results],
        "successful_tool_names": [
            _tool_name(tr)
            for tr in tool_results
            if tr.get("status") in ("success", "partial")
        ],
        "run_finished": run_finished,
        "latency_ms": round((t_end - t_start) * 1000),
        "token_usage": token_usage,
    }
def _switch_model(model_code: str) -> None:
    """Point the ``router`` and ``worker`` system agents at ``model_code``.

    Sets ``llm_id`` on every ``system_agents`` row whose ``agent_type`` is
    ``router`` or ``worker``.

    Raises:
        KeyError: if ``model_code`` has no entry in ``MODEL_LLM_IDS``. This
            is checked before any database access.
        RuntimeError: if an update matched no rows. Otherwise the previous
            model would stay active and the comparison would silently
            measure the wrong model.
    """
    from supabase import create_client

    # Resolve first so an unknown model code fails before touching the DB.
    llm_id = MODEL_LLM_IDS[model_code]
    sb = create_client(_get_supabase_url(), _get_supabase_key())
    for agent_type in ("router", "worker"):
        resp = (
            sb.table("system_agents")
            .update({"llm_id": llm_id})
            .eq("agent_type", agent_type)
            .execute()
        )
        # NOTE(review): relies on postgrest's default returning=representation
        # so that resp.data lists the updated rows; confirm for the pinned
        # supabase-py version.
        if not resp.data:
            raise RuntimeError(
                f"no system_agents row with agent_type={agent_type!r}; "
                f"cannot switch to {model_code}"
            )
def _save_original_models() -> list[dict]:
    """Snapshot the current model assignment of every system agent.

    Returns:
        The ``agent_type`` and ``llm_id`` columns of all ``system_agents``
        rows, in the shape ``_restore_models`` accepts.
    """
    from supabase import create_client

    client = create_client(_get_supabase_url(), _get_supabase_key())
    query = client.table("system_agents").select("agent_type, llm_id")
    response = query.execute()
    return response.data
def _restore_models(original_rows: list[dict]) -> None:
    """Write back ``llm_id`` values captured by ``_save_original_models``.

    Each row's ``llm_id`` is written to every ``system_agents`` row with the
    same ``agent_type``. All rows are attempted even if some fail, because
    stopping at the first error would leave the remaining agents on the
    last candidate model. Failures are reported together at the end.

    Raises:
        RuntimeError: if any row could not be restored. It is chained from
            the first underlying error.
    """
    from supabase import create_client

    sb = create_client(_get_supabase_url(), _get_supabase_key())
    failures: list[tuple[str, Exception]] = []
    for row in original_rows:
        try:
            (
                sb.table("system_agents")
                .update({"llm_id": row["llm_id"]})
                .eq("agent_type", row["agent_type"])
                .execute()
            )
        except Exception as exc:  # collected and re-raised below
            failures.append((str(row.get("agent_type", "?")), exc))
    if failures:
        summary = ", ".join(f"{agent}: {exc!r}" for agent, exc in failures)
        raise RuntimeError(
            f"failed to restore system_agents llm_id for {summary}"
        ) from failures[0][1]
def _evaluate_answer_quality(
*,
answer: str,
run_finished: bool,
expect_tool_use: bool,
has_tool_success: bool,
tool_names: list[str],
) -> float:
if not run_finished:
return 0.0
if not answer or not answer.strip():
return 0.0
score = 0.6
if expect_tool_use:
if has_tool_success:
score += 0.2
elif tool_names:
score += 0.1
else:
score -= 0.3
else:
if not tool_names:
score += 0.2
else:
score -= 0.1
if len(answer) > 10:
score += 0.1
if "无法" in answer or "失败" in answer or "错误" in answer:
if expect_tool_use:
score -= 0.1
return max(0.0, min(1.0, score))
def _evaluate_criteria(
    *,
    answer: str,
    run_finished: bool,
    tool_names: list[str],
    has_tool_success: bool,
    tool_results: list[dict],
    scenario: object,
) -> list[ScoreDetail]:
    """Judge each of ``scenario.quality_criteria`` with a keyword heuristic.

    Criteria are free-text (Chinese) descriptions. The first matching rule
    decides how a criterion is judged:

    * 调用 ("call") or ``project_cli``: a ``project_cli`` tool appeared.
    * ``mode`` and ``day``: some tool result's call arguments contain
      ``input.mode == "day"``.
    * 具体 / 时间戳 / 基于工具 / 返回 ("specific / timestamp / based on
      tool / returned"): a tool call succeeded.
    * 无日程 ("no schedule"): the answer contains a no-schedule phrase.
    * 简短 / 简洁 ("short / concise"): ``0 < len(answer) < 200``.
    * 自我介绍 ("self-introduction"): the answer mentions Linksy or 助手.
    * 礼貌 ("polite"): any non-empty answer.
    * Anything else: the run finished with a non-empty answer.

    A scenario without (or with empty) ``quality_criteria`` yields ``[]``.
    """

    def _as_mapping(value: object) -> dict:
        # Accept a dict or, defensively, a JSON-encoded object string;
        # anything else (None, lists, bad JSON) becomes an empty mapping.
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                return {}
        return value if isinstance(value, dict) else {}

    # Phrases signalling an empty schedule. The original check `"" in answer`
    # was always True, so this criterion could never fail.
    no_schedule_markers = ("没有", "暂无", "无日程", "无安排")

    details: list[ScoreDetail] = []
    for criterion in getattr(scenario, "quality_criteria", None) or []:
        note = ""
        if "调用" in criterion or "project_cli" in criterion:
            passed = any("project_cli" in tn for tn in tool_names)
            if not passed:
                note = f"tools: {tool_names}"
        elif "mode" in criterion and "day" in criterion:
            passed = False
            for tr in tool_results:
                args = _as_mapping(
                    tr.get("tool_call_args") or tr.get("toolCallArgs")
                )
                if _as_mapping(args.get("input")).get("mode") == "day":
                    passed = True
                    break
        elif any(
            kw in criterion for kw in ("具体", "时间戳", "基于工具", "返回")
        ):
            passed = has_tool_success
        elif "无日程" in criterion:
            passed = any(marker in answer for marker in no_schedule_markers)
        elif "简短" in criterion or "简洁" in criterion:
            passed = 0 < len(answer) < 200
        elif "自我介绍" in criterion:
            passed = "Linksy" in answer or "助手" in answer
        elif "礼貌" in criterion:
            passed = len(answer) > 0
        else:
            passed = run_finished and len(answer) > 0
        details.append(ScoreDetail(criterion=criterion, passed=passed, note=note))
    return details
async def _run_model_scenarios(model_code: str, user_id: str) -> ModelScorecard:
    """Run every scenario in ``ALL_SCENARIOS`` and score the answers.

    ``model_code`` labels the results and selects the pricing entry. The
    agents themselves must already be configured for that model; the A/B
    test calls ``_switch_model`` first. One progress line is printed per
    scenario.

    Args:
        model_code: Key of the model under test, e.g. ``"deepseek-chat"``.
        user_id: User id placed in the ``sub`` claim of the request JWTs.

    Returns:
        A ``ModelScorecard`` with one ``ScenarioScore`` per scenario, in
        ``ALL_SCENARIOS`` order.
    """
    from services.llm_pricing.service import LlmPricingService

    pricing = LlmPricingService()
    scores: list[ScenarioScore] = []
    for scenario in ALL_SCENARIOS:
        # Fresh token per scenario: tokens expire after one hour, and a long
        # scenario list against a slow model could outlive a single token.
        token = _create_jwt(user_id)
        started = time.monotonic()
        try:
            result = await _run_via_http(
                user_message=scenario.prompt,
                token=token,
            )
        except httpx.ReadTimeout:
            # A model that stops responding fails this scenario instead of
            # aborting the whole comparison. Keys mirror _run_via_http's result.
            result = {
                "final_answer": "",
                "tool_results": [],
                "tool_names": [],
                "successful_tool_names": [],
                "run_finished": False,
                "latency_ms": round((time.monotonic() - started) * 1000),
                "token_usage": {},
            }

        answer = result["final_answer"]
        tool_names = result["tool_names"]
        has_tool_success = len(result["successful_tool_names"]) > 0

        tu = result["token_usage"]
        total_tokens = tu.get("totalTokens", 0)
        # Fall back to cache-miss prompt tokens, and to total minus input,
        # when explicit counts are missing or zero.
        input_tokens = tu.get("inputTokens", 0) or tu.get("promptCacheMissTokens", 0)
        output_tokens = tu.get("outputTokens", 0) or max(total_tokens - input_tokens, 0)
        try:
            cost_usd = pricing.calculate_cost(
                model=model_code,
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                cached_prompt_tokens=tu.get("promptCacheHitTokens", 0),
            )
        except ValueError:
            # NOTE(review): presumably raised for models without pricing
            # data; the cost is reported as zero rather than failing the run.
            cost_usd = 0.0
        cost_usd = round(cost_usd, 8)

        tool_called = any("project_cli" in tn for tn in tool_names)
        # Scenarios that do not expect a tool are never marked as failures.
        tool_succeeded = has_tool_success if scenario.expect_tool_use else True
        answer_quality = _evaluate_answer_quality(
            answer=answer,
            run_finished=result["run_finished"],
            expect_tool_use=scenario.expect_tool_use,
            has_tool_success=has_tool_success,
            tool_names=tool_names,
        )
        details = _evaluate_criteria(
            answer=answer,
            run_finished=result["run_finished"],
            tool_names=tool_names,
            has_tool_success=has_tool_success,
            tool_results=result["tool_results"],
            scenario=scenario,
        )
        # Keep each progress line on one line even for multi-line answers.
        preview = answer[:60].replace("\n", " ")
        print(
            f" [{model_code}] {scenario.id:<25} "
            f"lat={result['latency_ms']}ms "
            f"tokens={total_tokens} "
            f"cost=${cost_usd:.6f} "
            f"tool={'OK' if has_tool_success else 'FAIL'} "
            f"answer={preview}"
        )
        scores.append(
            ScenarioScore(
                scenario_id=scenario.id,
                model_code=model_code,
                latency_ms=result["latency_ms"],
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cost_usd=cost_usd,
                tool_called=tool_called,
                tool_succeeded=tool_succeeded,
                answer_quality=answer_quality,
                details=details,
                raw_answer=answer[:500],
                run_finished=result["run_finished"],
            )
        )
    return ModelScorecard(model_code=model_code, scenario_scores=scores)
@pytest.fixture(autouse=True)
def _check_env():
    """Skip every test in this module unless ``QUALITY_TEST=1`` is set.

    The tests call a live backend and rewrite ``system_agents`` rows, so
    they are opt-in.
    """
    quality_enabled = os.getenv("QUALITY_TEST") == "1"
    if not quality_enabled:
        pytest.skip("set QUALITY_TEST=1 to run quality tests")
@pytest.fixture(autouse=True)
def _require_test_user_id(_check_env):
    """Fail fast when ``TEST_USER_ID`` is missing.

    Requesting ``_check_env`` makes the opt-in skip run first. Relying on
    the implicit relative order of two independent autouse fixtures is
    fragile: an environment without ``TEST_USER_ID`` could report an error
    instead of a skip.
    """
    _get_test_user_id()
@pytest.mark.asyncio
@pytest.mark.quality
@pytest.mark.live
async def test_model_ab_comparison():
    """Score every candidate model on the same scenarios and compare them.

    The router/worker agents are switched to each model in turn. The
    original ``llm_id`` assignments are restored afterwards, even when a
    run raises. Per-model summaries, a comparison table and the winner (by
    average overall score) are printed. No assertion is made on the scores.
    """
    user_id = _get_test_user_id()
    original_rows = _save_original_models()
    scorecards: list[ModelScorecard] = []
    try:
        for model_code in CANDIDATE_MODELS:
            _switch_model(model_code)
            card = await _run_model_scenarios(model_code, user_id)
            scorecards.append(card)
            print(card.summary_table())
    finally:
        _restore_models(original_rows)

    print("\n" + "=" * 60)
    print("COMPARISON")
    print("=" * 60)
    for card in scorecards:
        print(
            f" {card.model_code:<20} "
            f"overall={card.avg_overall:.2f} "
            f"latency={card.avg_latency_ms:.0f}ms "
            f"cost=${card.avg_cost_usd:.6f} "
            f"tool_success={card.tool_success_rate:.0%}"
        )
    if len(scorecards) >= 2:
        # max() keeps the first of equal maxima, so ties go to the model
        # evaluated first (same tie-break as a pairwise ">=").
        winner = max(scorecards, key=lambda sc: sc.avg_overall).model_code
        print(f"\n Winner: {winner} (by overall score)")