refactor: 移除 crewai agent 架构相关代码并更新 LLM 配置
This commit is contained in:
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -1,20 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.agent.event_bridge import map_internal_event
|
||||
|
||||
|
||||
class AguiAdapter:
|
||||
def to_command(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
message = payload.get("message")
|
||||
if not isinstance(message, str) or not message.strip():
|
||||
raise ValueError("message is required")
|
||||
|
||||
return {
|
||||
"message": message,
|
||||
"session_id": payload.get("session_id"),
|
||||
}
|
||||
|
||||
def to_protocol_event(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
return map_internal_event(event)
|
||||
@@ -1,67 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CrewAITemplate:
|
||||
agents: dict[str, Any]
|
||||
tasks: dict[str, Any]
|
||||
workflow: dict[str, Any]
|
||||
tools_whitelist: set[str]
|
||||
|
||||
|
||||
def _default_static_root() -> Path:
|
||||
return Path(__file__).resolve().parents[3] / "config" / "static" / "crewai"
|
||||
|
||||
|
||||
def _read_yaml(file_path: Path) -> dict[str, Any]:
|
||||
if not file_path.is_file():
|
||||
raise FileNotFoundError(f"Required config file not found: {file_path}")
|
||||
with file_path.open("r", encoding="utf-8") as file:
|
||||
loaded = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"YAML file must be a mapping: {file_path}")
|
||||
return loaded
|
||||
|
||||
|
||||
def validate_workflow_stages(stages: list[str]) -> None:
|
||||
expected = ["intent", "execution", "organization"]
|
||||
if stages != expected:
|
||||
raise ValueError(f"Invalid workflow stages: {stages}, expected: {expected}")
|
||||
|
||||
|
||||
def load_tools_whitelist(static_root: Path | None = None) -> set[str]:
|
||||
root = static_root or _default_static_root()
|
||||
tools = _read_yaml(root / "tools.yaml")
|
||||
raw_tools = tools.get("tools", [])
|
||||
if not isinstance(raw_tools, list):
|
||||
raise ValueError("tools.yaml field 'tools' must be a list")
|
||||
if not all(isinstance(item, str) and item.strip() for item in raw_tools):
|
||||
raise ValueError("tools.yaml list items must be non-empty strings")
|
||||
whitelist = {item.strip() for item in raw_tools}
|
||||
return whitelist
|
||||
|
||||
|
||||
def load_crewai_template(static_root: Path | None = None) -> CrewAITemplate:
|
||||
root = static_root or _default_static_root()
|
||||
|
||||
agents = _read_yaml(root / "agents.yaml")
|
||||
tasks = _read_yaml(root / "tasks.yaml")
|
||||
workflow = _read_yaml(root / "workflow.yaml")
|
||||
|
||||
stages = workflow.get("stages")
|
||||
if not isinstance(stages, list):
|
||||
raise ValueError("workflow.yaml field 'stages' must be a list")
|
||||
validate_workflow_stages([str(stage) for stage in stages])
|
||||
|
||||
return CrewAITemplate(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
workflow=workflow,
|
||||
tools_whitelist=load_tools_whitelist(root),
|
||||
)
|
||||
@@ -1,63 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _require_fields(event: dict[str, Any], *, kind: str, required: list[str]) -> None:
|
||||
missing = [field for field in required if field not in event]
|
||||
if missing:
|
||||
raise ValueError(f"Missing fields for {kind}: {missing}")
|
||||
|
||||
|
||||
def map_internal_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
kind = event.get("kind")
|
||||
|
||||
if kind == "run_started":
|
||||
_require_fields(event, kind=kind, required=["session_id"])
|
||||
return {
|
||||
"type": "run.started",
|
||||
"run_id": event["session_id"],
|
||||
}
|
||||
|
||||
if kind == "message_delta":
|
||||
_require_fields(event, kind=kind, required=["message_id", "delta"])
|
||||
return {
|
||||
"type": "message.delta",
|
||||
"message_id": event["message_id"],
|
||||
"delta": event["delta"],
|
||||
}
|
||||
|
||||
if kind == "tool_started":
|
||||
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
|
||||
return {
|
||||
"type": "tool.started",
|
||||
"message_id": event["message_id"],
|
||||
"tool_name": event["tool_name"],
|
||||
}
|
||||
|
||||
if kind == "tool_completed":
|
||||
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
|
||||
return {
|
||||
"type": "tool.completed",
|
||||
"message_id": event["message_id"],
|
||||
"tool_name": event["tool_name"],
|
||||
"result": event.get("result"),
|
||||
}
|
||||
|
||||
if kind == "run_completed":
|
||||
_require_fields(event, kind=kind, required=["session_id"])
|
||||
return {
|
||||
"type": "run.completed",
|
||||
"run_id": event["session_id"],
|
||||
"output": event.get("output", ""),
|
||||
}
|
||||
|
||||
if kind == "run_failed":
|
||||
_require_fields(event, kind=kind, required=["session_id"])
|
||||
return {
|
||||
"type": "run.failed",
|
||||
"run_id": event["session_id"],
|
||||
"error": event.get("error", ""),
|
||||
}
|
||||
|
||||
raise ValueError(f"Unsupported event kind: {kind}")
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def run_started(*, run_id: str) -> dict[str, Any]:
|
||||
return {"type": "run.started", "run_id": run_id}
|
||||
|
||||
|
||||
def stage_completed(
|
||||
*, run_id: str, stage: str, usage: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
event: dict[str, Any] = {
|
||||
"type": "stage.completed",
|
||||
"run_id": run_id,
|
||||
"stage": stage,
|
||||
}
|
||||
if usage is not None:
|
||||
event["usage"] = usage
|
||||
return event
|
||||
|
||||
|
||||
def run_completed(*, run_id: str, output: str, usage: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "run.completed",
|
||||
"run_id": run_id,
|
||||
"output": output,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
|
||||
def run_failed(*, run_id: str, error: str) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "run.failed",
|
||||
"run_id": run_id,
|
||||
"error": error,
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig:
|
||||
model_code: str
|
||||
factory_name: str
|
||||
litellm_model: str
|
||||
request_url: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMResponse:
|
||||
content: str
|
||||
usage: dict[str, Any]
|
||||
|
||||
|
||||
class LiteLLMClient:
|
||||
def __init__(self, config: LLMConfig, api_key: str | None = None) -> None:
|
||||
self._config = config
|
||||
self._api_key = api_key or self._get_api_key(config.factory_name)
|
||||
|
||||
@staticmethod
|
||||
def _get_api_key(factory_name: str) -> str:
|
||||
key_map = {
|
||||
"dashscope": "DASHSCOPE_API_KEY",
|
||||
"minimax": "MINIMAX_API_KEY",
|
||||
"moonshot": "MOONSHOT_API_KEY",
|
||||
"deepseek": "DEEPSEEK_API_KEY",
|
||||
"volcengine-ark": "ARK_API_KEY",
|
||||
"z-ai": "ZAI_API_KEY",
|
||||
}
|
||||
env_key = key_map.get(factory_name)
|
||||
if not env_key:
|
||||
raise ValueError(f"No API key mapping for factory: {factory_name}")
|
||||
key = os.environ.get(env_key)
|
||||
if not key:
|
||||
raise ValueError(f"Environment variable {env_key} is not set")
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def load_config(
|
||||
model_code: str,
|
||||
static_root: Path | None = None,
|
||||
) -> LLMConfig:
|
||||
root = static_root or (
|
||||
Path(__file__).resolve().parents[3] / "config" / "static" / "database"
|
||||
)
|
||||
yaml_path = root / "llm_catalog.yaml"
|
||||
with yaml_path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
factories = {f["name"]: f for f in data.get("factories", [])}
|
||||
llms = data.get("llms", [])
|
||||
|
||||
for llm in llms:
|
||||
if llm.get("model_code") == model_code:
|
||||
factory_name = llm["factory_name"]
|
||||
factory = factories.get(factory_name)
|
||||
if not factory:
|
||||
raise ValueError(f"Factory not found: {factory_name}")
|
||||
return LLMConfig(
|
||||
model_code=model_code,
|
||||
factory_name=factory_name,
|
||||
litellm_model=llm.get("litellm_model", model_code),
|
||||
request_url=factory["request_url"],
|
||||
)
|
||||
|
||||
raise ValueError(f"Model not found: {model_code}")
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
) -> LLMResponse:
|
||||
import litellm
|
||||
|
||||
response = litellm.completion( # type: ignore[attr-defined]
|
||||
model=self._config.litellm_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self._config.request_url,
|
||||
api_key=self._api_key,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
|
||||
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
|
||||
|
||||
return LLMResponse(content=content, usage=usage)
|
||||
|
||||
async def achat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
) -> LLMResponse:
|
||||
import litellm
|
||||
|
||||
response = await litellm.acompletion( # type: ignore[attr-defined]
|
||||
model=self._config.litellm_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_base=self._config.request_url,
|
||||
api_key=self._api_key,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or "" # type: ignore[attr-defined]
|
||||
usage = response.usage.model_dump() if response.usage else {} # type: ignore[attr-defined]
|
||||
|
||||
return LLMResponse(content=content, usage=usage)
|
||||
|
||||
|
||||
def get_model_cost(usage: dict[str, Any]) -> Decimal:
|
||||
cost = usage.get("cost")
|
||||
if cost is None:
|
||||
return Decimal("0")
|
||||
return Decimal(str(cost))
|
||||
@@ -1,117 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from core.agent import events
|
||||
from core.agent.litellm_client import get_model_cost
|
||||
|
||||
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OrchestratorResult:
|
||||
output: str
|
||||
usage: dict[str, Any]
|
||||
events: list[dict[str, Any]]
|
||||
context: dict[str, Any]
|
||||
failed: bool
|
||||
error: str | None
|
||||
|
||||
|
||||
class _UsageTracker:
|
||||
def __init__(self) -> None:
|
||||
self._input_tokens = 0
|
||||
self._output_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._cost = Decimal("0")
|
||||
|
||||
def add_usage(self, usage: dict[str, Any]) -> None:
|
||||
input_tokens = usage.get("prompt_tokens", 0) or usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0) or usage.get(
|
||||
"output_tokens", 0
|
||||
)
|
||||
total = usage.get("total_tokens")
|
||||
|
||||
self._input_tokens += input_tokens
|
||||
self._output_tokens += output_tokens
|
||||
self._total_tokens += total if total else (input_tokens + output_tokens)
|
||||
self._cost += get_model_cost(usage)
|
||||
|
||||
def snapshot(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_tokens": self._input_tokens,
|
||||
"output_tokens": self._output_tokens,
|
||||
"total_tokens": self._total_tokens,
|
||||
"cost": str(self._cost),
|
||||
}
|
||||
|
||||
|
||||
class AgentChatOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
intent_stage: StageCallable,
|
||||
execution_stage: StageCallable,
|
||||
organization_stage: StageCallable,
|
||||
) -> None:
|
||||
self._intent_stage = intent_stage
|
||||
self._execution_stage = execution_stage
|
||||
self._organization_stage = organization_stage
|
||||
|
||||
def run_sync(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
return asyncio.run(self.run(run_id=run_id, user_message=user_message))
|
||||
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
tracker = _UsageTracker()
|
||||
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
|
||||
context: dict[str, Any] = {}
|
||||
|
||||
stage_pipeline: list[tuple[str, StageCallable]] = [
|
||||
("intent", self._intent_stage),
|
||||
("execution", self._execution_stage),
|
||||
("organization", self._organization_stage),
|
||||
]
|
||||
|
||||
stage_output = user_message
|
||||
try:
|
||||
for stage_name, stage_callable in stage_pipeline:
|
||||
stage_result = await stage_callable(
|
||||
message=stage_output, context=context
|
||||
)
|
||||
stage_output = str(stage_result.get("content", stage_output))
|
||||
usage = stage_result.get("usage", {})
|
||||
if isinstance(usage, dict):
|
||||
tracker.add_usage(usage)
|
||||
emitted_events.append(
|
||||
events.stage_completed(
|
||||
run_id=run_id,
|
||||
stage=stage_name,
|
||||
usage=tracker.snapshot(),
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
emitted_events.append(events.run_failed(run_id=run_id, error=str(exc)))
|
||||
return OrchestratorResult(
|
||||
output="",
|
||||
usage=tracker.snapshot(),
|
||||
events=emitted_events,
|
||||
context=context,
|
||||
failed=True,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
summary = tracker.snapshot()
|
||||
emitted_events.append(
|
||||
events.run_completed(run_id=run_id, output=stage_output, usage=summary)
|
||||
)
|
||||
return OrchestratorResult(
|
||||
output=stage_output,
|
||||
usage=summary,
|
||||
events=emitted_events,
|
||||
context=context,
|
||||
failed=False,
|
||||
error=None,
|
||||
)
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
TranscribeCallable = Callable[..., dict[str, Any]]
|
||||
|
||||
|
||||
class FunASRTool:
|
||||
_transcribe_callable: TranscribeCallable
|
||||
_model: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transcribe_callable: TranscribeCallable | None = None,
|
||||
model: str = "fun-asr-realtime-2025-11-07",
|
||||
) -> None:
|
||||
self._transcribe_callable = transcribe_callable or self._dashscope_transcribe
|
||||
self._model = model
|
||||
|
||||
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, Any]:
|
||||
payload = self._transcribe_callable(audio_bytes=audio_bytes, filename=filename)
|
||||
return {
|
||||
"model": self._model,
|
||||
**payload,
|
||||
}
|
||||
|
||||
def _dashscope_transcribe(
|
||||
self, *, audio_bytes: bytes, filename: str
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
importlib.import_module("dashscope")
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("DashScope SDK is not installed") from exc
|
||||
|
||||
raise RuntimeError(
|
||||
"DashScope transcribe runtime integration is not configured yet"
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- attachment_extract
|
||||
@@ -1,9 +0,0 @@
|
||||
stages:
|
||||
- intent
|
||||
- execution
|
||||
- organization
|
||||
|
||||
timeouts:
|
||||
intent_seconds: 8
|
||||
execution_seconds: 30
|
||||
organization_seconds: 10
|
||||
@@ -1 +0,0 @@
|
||||
from __future__ import annotations
|
||||
@@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FlowState(BaseModel):
|
||||
stage_trace: list[str] = Field(default_factory=list)
|
||||
current_stage: str | None = None
|
||||
message: str = ""
|
||||
context: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentFlow:
|
||||
def __init__(self) -> None:
|
||||
self.state = FlowState()
|
||||
|
||||
async def run(self) -> dict[str, Any]:
|
||||
result = await self.intent_recognition()
|
||||
result = await self.task_execution(result)
|
||||
result = await self.result_reporting(result)
|
||||
return result
|
||||
|
||||
async def intent_recognition(self) -> dict[str, Any]:
|
||||
self.state.current_stage = "intent"
|
||||
self.state.stage_trace.append("intent")
|
||||
return {"stage": "intent", "result": "intent recognized"}
|
||||
|
||||
async def task_execution(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
self.state.current_stage = "execution"
|
||||
self.state.stage_trace.append("execution")
|
||||
return {"stage": "execution", "result": "task executed"}
|
||||
|
||||
async def result_reporting(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
|
||||
self.state.current_stage = "reporting"
|
||||
self.state.stage_trace.append("reporting")
|
||||
return {"stage": "reporting", "result": "reported"}
|
||||
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.agent.service import AgentChatService
|
||||
from v1.profile.dependencies import get_current_user
|
||||
|
||||
|
||||
def get_agent_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[CurrentUser, Depends(get_current_user)],
|
||||
) -> AgentChatService:
|
||||
return AgentChatService(session=session, current_user=user)
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
router = APIRouter(prefix="/agent", tags=["agent"])
|
||||
|
||||
|
||||
@router.post("/runs")
|
||||
async def create_run(
|
||||
input_data: RunAgentInput,
|
||||
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
||||
) -> StreamingResponse:
|
||||
return StreamingResponse(
|
||||
service.stream_run(input_data),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs/{run_id}/resume")
|
||||
async def resume_run(
|
||||
run_id: Annotated[str, Path(min_length=1, max_length=255)],
|
||||
input_data: RunAgentInput,
|
||||
service: Annotated[AgentChatService, Depends(get_agent_service)],
|
||||
) -> StreamingResponse:
|
||||
if input_data.runId != run_id:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"run_id mismatch: path={run_id}, body={input_data.runId}",
|
||||
)
|
||||
await service.prepare_resume(run_id, input_data)
|
||||
return StreamingResponse(
|
||||
service.stream_resume(run_id, input_data),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
@@ -1,88 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class RunAgentInput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
threadId: str = Field(min_length=1, max_length=255)
|
||||
runId: str = Field(min_length=1, max_length=255)
|
||||
parentRunId: str | None = Field(default=None, max_length=255)
|
||||
state: dict[str, Any] = Field(default_factory=dict)
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
tools: list[dict[str, Any]] = Field(default_factory=list)
|
||||
context: list[dict[str, Any]] = Field(default_factory=list)
|
||||
forwardedProps: dict[str, Any] = Field(default_factory=dict)
|
||||
resume: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentChatRunRequest(BaseModel):
|
||||
message: str = Field(min_length=1, max_length=8000)
|
||||
session_id: UUID | None = None
|
||||
|
||||
|
||||
class AgentChatEvent(BaseModel):
|
||||
type: str
|
||||
run_id: str | None = None
|
||||
message_id: str | None = None
|
||||
delta: str | None = None
|
||||
tool_name: str | None = None
|
||||
result: str | None = None
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class AgentChatRunResponse(BaseModel):
|
||||
session_id: UUID
|
||||
output: str
|
||||
events: list[AgentChatEvent]
|
||||
|
||||
|
||||
class PendingToolStatus(str, Enum):
|
||||
PENDING_APPROVAL = "PENDING_APPROVAL"
|
||||
APPROVED_EXECUTING = "APPROVED_EXECUTING"
|
||||
EXECUTED = "EXECUTED"
|
||||
REJECTED = "REJECTED"
|
||||
EXPIRED = "EXPIRED"
|
||||
|
||||
|
||||
class PendingToolCall(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
interrupt_id: str = Field(min_length=1, max_length=255)
|
||||
tool_name: str = Field(min_length=1, max_length=255)
|
||||
tool_args: dict[str, Any] = Field(default_factory=dict)
|
||||
status: PendingToolStatus
|
||||
expires_at: datetime
|
||||
decision: dict[str, Any] | None = None
|
||||
result: dict[str, Any] | None = None
|
||||
updated_at: datetime
|
||||
|
||||
@field_validator("expires_at", "updated_at")
|
||||
@classmethod
|
||||
def _validate_timezone_aware(cls, value: datetime) -> datetime:
|
||||
if value.tzinfo is None or value.utcoffset() is None:
|
||||
raise ValueError("datetime must be timezone-aware")
|
||||
return value
|
||||
|
||||
|
||||
class SnapshotRunContext(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
thread_id: str = Field(min_length=1, max_length=255)
|
||||
run_id: str = Field(min_length=1, max_length=255)
|
||||
|
||||
|
||||
class AgentSessionSnapshot(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
version: Literal[2]
|
||||
pending_tool_call: PendingToolCall | None
|
||||
run_context: SnapshotRunContext
|
||||
@@ -1,545 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.agent.agui_adapter import AguiAdapter
|
||||
from core.agent.orchestrator import AgentChatOrchestrator
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.agent.schemas import (
|
||||
AgentSessionSnapshot,
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
PendingToolCall,
|
||||
PendingToolStatus,
|
||||
RunAgentInput,
|
||||
SnapshotRunContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.agent.service")
|
||||
|
||||
DEFAULT_RATE_LIMIT = 60
|
||||
EMPTY_USAGE = {"input_tokens": 0, "output_tokens": 0, "cost": "0"}
|
||||
|
||||
|
||||
class ResumeDecisionResult(BaseModel):
|
||||
applied: bool
|
||||
|
||||
|
||||
def build_session_title(first_message: str, *, now: datetime) -> str:
|
||||
title = first_message.strip().replace("\n", " ")[:24]
|
||||
if not title:
|
||||
return now.strftime("新对话 %Y-%m-%d %H:%M")
|
||||
return title
|
||||
|
||||
|
||||
def aggregate_session_cost(costs: list[Decimal]) -> Decimal:
|
||||
total = Decimal("0")
|
||||
for cost in costs:
|
||||
if cost < 0:
|
||||
raise ValueError("cost must be non-negative")
|
||||
total += cost
|
||||
return total
|
||||
|
||||
|
||||
def select_recent_session(
|
||||
sessions: list[AgentChatSession],
|
||||
) -> AgentChatSession | None:
|
||||
if not sessions:
|
||||
return None
|
||||
return max(sessions, key=lambda item: item.last_activity_at)
|
||||
|
||||
|
||||
class AgentChatService(BaseService):
|
||||
_session: AsyncSession
|
||||
|
||||
def __init__(self, session: AsyncSession, current_user: CurrentUser | None) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._session = session
|
||||
self._adapter = AguiAdapter()
|
||||
self._orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=self._intent_stage,
|
||||
execution_stage=self._execution_stage,
|
||||
organization_stage=self._organization_stage,
|
||||
)
|
||||
|
||||
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
|
||||
try:
|
||||
command = self._adapter.to_command(payload.model_dump(mode="python"))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
user_id = self.require_user_id()
|
||||
await enforce_rate_limit(
|
||||
scope="agent_run",
|
||||
identifier=str(user_id),
|
||||
limit=DEFAULT_RATE_LIMIT,
|
||||
window_seconds=DEFAULT_RATE_LIMIT,
|
||||
)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
chat_session = await self._resolve_session(
|
||||
session_id=command["session_id"],
|
||||
user_id=user_id,
|
||||
first_message=command["message"],
|
||||
now=now,
|
||||
)
|
||||
|
||||
base_seq = await self._next_seq_base(chat_session.id)
|
||||
user_message = AgentChatMessage(
|
||||
session_id=chat_session.id,
|
||||
seq=base_seq + 1,
|
||||
role=AgentChatMessageRole.USER,
|
||||
content=command["message"],
|
||||
cost=Decimal("0"),
|
||||
)
|
||||
orchestrator_result = await self._orchestrator.run(
|
||||
run_id=str(chat_session.id),
|
||||
user_message=command["message"],
|
||||
)
|
||||
assistant_message = AgentChatMessage(
|
||||
session_id=chat_session.id,
|
||||
seq=base_seq + 2,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=orchestrator_result.output,
|
||||
input_tokens=int(orchestrator_result.usage["input_tokens"]),
|
||||
output_tokens=int(orchestrator_result.usage["output_tokens"]),
|
||||
cost=Decimal(orchestrator_result.usage["cost"]),
|
||||
)
|
||||
self._session.add(user_message)
|
||||
self._session.add(assistant_message)
|
||||
|
||||
chat_session.status = (
|
||||
AgentChatSessionStatus.FAILED
|
||||
if orchestrator_result.failed
|
||||
else AgentChatSessionStatus.COMPLETED
|
||||
)
|
||||
chat_session.last_activity_at = now
|
||||
chat_session.message_count = chat_session.message_count + 2
|
||||
chat_session.total_tokens = chat_session.total_tokens + int(
|
||||
orchestrator_result.usage["total_tokens"]
|
||||
)
|
||||
chat_session.total_cost = aggregate_session_cost(
|
||||
[
|
||||
Decimal(chat_session.total_cost),
|
||||
Decimal(orchestrator_result.usage["cost"]),
|
||||
]
|
||||
)
|
||||
|
||||
await self._session.commit()
|
||||
await self._session.refresh(chat_session)
|
||||
await self._session.refresh(user_message)
|
||||
|
||||
mapped_events = self._build_mapped_events(
|
||||
session_id=str(chat_session.id),
|
||||
message_id=str(user_message.id),
|
||||
user_message=command["message"],
|
||||
assistant_output=assistant_message.content,
|
||||
failed=orchestrator_result.failed,
|
||||
error=orchestrator_result.error,
|
||||
)
|
||||
events = [AgentChatEvent.model_validate(item) for item in mapped_events]
|
||||
if orchestrator_result.failed:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Agent orchestration failed"
|
||||
)
|
||||
return AgentChatRunResponse(
|
||||
session_id=chat_session.id,
|
||||
output=assistant_message.content,
|
||||
events=events,
|
||||
)
|
||||
except HTTPException:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Agent chat run failed")
|
||||
raise HTTPException(status_code=503, detail="Agent chat store unavailable")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await self._session.rollback()
|
||||
logger.exception(
|
||||
"Agent chat unexpected failure", error_type=type(exc).__name__
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Agent orchestration failed"
|
||||
) from exc
|
||||
|
||||
def _build_mapped_events(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
user_message: str,
|
||||
assistant_output: str,
|
||||
failed: bool,
|
||||
error: str | None,
|
||||
) -> list[dict[str, object]]:
|
||||
mapped_events = [
|
||||
self._adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_started",
|
||||
"session_id": session_id,
|
||||
}
|
||||
),
|
||||
self._adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "message_delta",
|
||||
"message_id": message_id,
|
||||
"delta": user_message,
|
||||
}
|
||||
),
|
||||
]
|
||||
if failed:
|
||||
mapped_events.append(
|
||||
self._adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_failed",
|
||||
"session_id": session_id,
|
||||
"error": error or "orchestration failed",
|
||||
}
|
||||
)
|
||||
)
|
||||
return mapped_events
|
||||
|
||||
mapped_events.append(
|
||||
self._adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_completed",
|
||||
"session_id": session_id,
|
||||
"output": assistant_output,
|
||||
}
|
||||
)
|
||||
)
|
||||
return mapped_events
|
||||
|
||||
async def _resolve_session(
|
||||
self,
|
||||
*,
|
||||
session_id: object | None,
|
||||
user_id: UUID,
|
||||
first_message: str,
|
||||
now: datetime,
|
||||
) -> AgentChatSession:
|
||||
if session_id is not None:
|
||||
stmt = (
|
||||
select(AgentChatSession)
|
||||
.where(AgentChatSession.id == session_id)
|
||||
.where(AgentChatSession.user_id == user_id)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.with_for_update()
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
existing.status = AgentChatSessionStatus.RUNNING
|
||||
return existing
|
||||
|
||||
title = build_session_title(first_message, now=now)
|
||||
|
||||
created = AgentChatSession(
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=now,
|
||||
)
|
||||
self._session.add(created)
|
||||
await self._session.flush()
|
||||
return created
|
||||
|
||||
async def _next_seq_base(self, session_id: object) -> int:
|
||||
stmt = select(func.max(AgentChatMessage.seq)).where(
|
||||
AgentChatMessage.session_id == session_id
|
||||
)
|
||||
result = await self._session.scalar(stmt)
|
||||
if result is None:
|
||||
return 0
|
||||
return int(result)
|
||||
|
||||
async def _intent_stage(
|
||||
self, *, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
context["intent"] = "default"
|
||||
return {"content": message, "usage": EMPTY_USAGE}
|
||||
|
||||
async def _execution_stage(
|
||||
self, *, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
return {"content": message, "usage": EMPTY_USAGE}
|
||||
|
||||
async def _organization_stage(
|
||||
self, *, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
return {"content": message, "usage": EMPTY_USAGE}
|
||||
|
||||
async def get_state_snapshot(self, session_id: UUID) -> dict | None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
return None
|
||||
return session.state_snapshot
|
||||
|
||||
@staticmethod
|
||||
def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot:
|
||||
try:
|
||||
return AgentSessionSnapshot.model_validate(raw_snapshot)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise ValueError("Invalid state_snapshot format") from exc
|
||||
|
||||
async def _get_session_for_update(
|
||||
self, session_id: UUID
|
||||
) -> AgentChatSession | None:
|
||||
stmt = (
|
||||
select(AgentChatSession)
|
||||
.where(AgentChatSession.id == session_id)
|
||||
.with_for_update()
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _assert_session_owner(self, session: AgentChatSession) -> None:
|
||||
if self._current_user is None:
|
||||
return
|
||||
|
||||
if session.user_id != self.require_user_id():
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
@staticmethod
|
||||
def _validate_no_newlines(value: str, *, field_name: str) -> None:
|
||||
if "\n" in value or "\r" in value:
|
||||
raise ValueError(f"{field_name} must not contain newlines")
|
||||
|
||||
@staticmethod
|
||||
def _sse_data(payload: dict[str, Any]) -> str:
|
||||
return f"data: {json.dumps(payload)}\\n\\n"
|
||||
|
||||
async def set_pending_tool_call(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
tool_name: str,
|
||||
tool_args: dict,
|
||||
expires_at: datetime,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
self._assert_session_owner(session)
|
||||
|
||||
snapshot = AgentSessionSnapshot(
|
||||
version=2,
|
||||
run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id),
|
||||
pending_tool_call=PendingToolCall(
|
||||
interrupt_id=interrupt_id,
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
status=PendingToolStatus.PENDING_APPROVAL,
|
||||
expires_at=expires_at,
|
||||
decision=None,
|
||||
result=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
),
|
||||
)
|
||||
session.state_snapshot = snapshot.model_dump(mode="json")
|
||||
|
||||
async def update_pending_tool_call_status(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
status: str,
|
||||
) -> None:
|
||||
stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
|
||||
session = await self._session.scalar(stmt)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
self._assert_session_owner(session)
|
||||
if session.state_snapshot is None:
|
||||
raise ValueError("No pending tool call found")
|
||||
|
||||
snapshot = self._load_snapshot_v2(session.state_snapshot)
|
||||
pending = snapshot.pending_tool_call
|
||||
if pending is None:
|
||||
raise ValueError("No pending tool call found")
|
||||
if pending.interrupt_id != interrupt_id:
|
||||
raise ValueError("Interrupt ID mismatch")
|
||||
|
||||
updated_pending = pending.model_copy(
|
||||
update={
|
||||
"status": PendingToolStatus(status),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
updated_snapshot = snapshot.model_copy(
|
||||
update={"pending_tool_call": updated_pending}
|
||||
)
|
||||
session.state_snapshot = updated_snapshot.model_dump(mode="json")
|
||||
|
||||
async def apply_resume_decision(
|
||||
self,
|
||||
*,
|
||||
session_id: UUID,
|
||||
interrupt_id: str,
|
||||
decision: dict[str, Any],
|
||||
) -> ResumeDecisionResult:
|
||||
session = await self._get_session_for_update(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
self._assert_session_owner(session)
|
||||
|
||||
if session.state_snapshot is None:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
snapshot = self._load_snapshot_v2(session.state_snapshot)
|
||||
pending = snapshot.pending_tool_call
|
||||
if pending is None:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
if pending.interrupt_id != interrupt_id:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
if pending.status != PendingToolStatus.PENDING_APPROVAL:
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if pending.expires_at <= now:
|
||||
expired_pending = pending.model_copy(
|
||||
update={
|
||||
"status": PendingToolStatus.EXPIRED,
|
||||
"updated_at": now,
|
||||
}
|
||||
)
|
||||
expired_snapshot = snapshot.model_copy(
|
||||
update={"pending_tool_call": expired_pending}
|
||||
)
|
||||
session.state_snapshot = expired_snapshot.model_dump(mode="json")
|
||||
return ResumeDecisionResult(applied=False)
|
||||
|
||||
decision_value = decision.get("decision", "approved")
|
||||
next_status = (
|
||||
PendingToolStatus.APPROVED_EXECUTING
|
||||
if decision_value == "approved"
|
||||
else PendingToolStatus.REJECTED
|
||||
)
|
||||
|
||||
updated_pending = pending.model_copy(
|
||||
update={
|
||||
"status": next_status,
|
||||
"decision": decision,
|
||||
"updated_at": now,
|
||||
}
|
||||
)
|
||||
updated_snapshot = snapshot.model_copy(
|
||||
update={"pending_tool_call": updated_pending}
|
||||
)
|
||||
session.state_snapshot = updated_snapshot.model_dump(mode="json")
|
||||
|
||||
return ResumeDecisionResult(applied=True)
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
|
||||
self._validate_no_newlines(input_data.runId, field_name="runId")
|
||||
|
||||
yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"})
|
||||
yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId})
|
||||
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None:
|
||||
self._validate_no_newlines(run_id, field_name="runId")
|
||||
|
||||
user_id = self.require_user_id()
|
||||
await enforce_rate_limit(
|
||||
scope="agent_resume",
|
||||
identifier=str(user_id),
|
||||
limit=DEFAULT_RATE_LIMIT,
|
||||
window_seconds=DEFAULT_RATE_LIMIT,
|
||||
)
|
||||
|
||||
try:
|
||||
session_id = UUID(run_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="run_id must be a valid UUID"
|
||||
) from exc
|
||||
|
||||
session = await self._get_session_for_update(session_id)
|
||||
if session is None or session.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if input_data.resume is None:
|
||||
raise HTTPException(status_code=422, detail="resume payload is required")
|
||||
|
||||
interrupt_id = input_data.resume.get("interruptId")
|
||||
if not isinstance(interrupt_id, str) or not interrupt_id:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="resume.interruptId is required"
|
||||
)
|
||||
|
||||
decision_payload = input_data.resume.get("payload", {})
|
||||
if not isinstance(decision_payload, dict):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="resume.payload must be an object",
|
||||
)
|
||||
|
||||
try:
|
||||
decision_result = await self.apply_resume_decision(
|
||||
session_id=session_id,
|
||||
interrupt_id=interrupt_id,
|
||||
decision=decision_payload,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
if not decision_result.applied:
|
||||
if session.state_snapshot is not None:
|
||||
snapshot = self._load_snapshot_v2(session.state_snapshot)
|
||||
pending = snapshot.pending_tool_call
|
||||
if pending is not None and pending.status == PendingToolStatus.EXPIRED:
|
||||
await self._session.commit()
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail="Pending tool call expired",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Resume decision not applicable",
|
||||
)
|
||||
|
||||
await self._session.commit()
|
||||
|
||||
async def stream_resume(
|
||||
self, run_id: str, input_data: RunAgentInput
|
||||
) -> AsyncGenerator[str, None]:
|
||||
self._validate_no_newlines(run_id, field_name="runId")
|
||||
|
||||
yield self._sse_data({"type": "RUN_STARTED", "runId": run_id})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"})
|
||||
yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"})
|
||||
yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id})
|
||||
@@ -1,80 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from v1.agent.tool_registry import validate_tool_spec
|
||||
|
||||
|
||||
ALLOWED_BACKEND_TOOLS = frozenset(
|
||||
{
|
||||
"srv.search_docs",
|
||||
"srv.get_user_info",
|
||||
"srv.send_message",
|
||||
"srv.transfer_funds",
|
||||
"srv.delete_file",
|
||||
}
|
||||
)
|
||||
|
||||
ALLOWED_FRONTEND_TOOLS = frozenset(
|
||||
{
|
||||
"ui.navigate_to",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class InterruptResult(BaseModel):
|
||||
interrupt_type: str
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class BackendExecutionResult(BaseModel):
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
result: Any | None = None
|
||||
|
||||
|
||||
class ToolDispatcher:
|
||||
def dispatch(
|
||||
self, tool: dict[str, Any]
|
||||
) -> InterruptResult | BackendExecutionResult:
|
||||
return dispatch_tool_call(tool)
|
||||
|
||||
|
||||
def dispatch_tool_call(
|
||||
tool: dict[str, Any],
|
||||
) -> InterruptResult | BackendExecutionResult:
|
||||
validate_tool_spec(tool)
|
||||
|
||||
name = tool["name"]
|
||||
target = tool["execution_target"]
|
||||
args = tool.get("args", {})
|
||||
|
||||
if target == "frontend":
|
||||
if name not in ALLOWED_FRONTEND_TOOLS:
|
||||
raise ValueError(f"Frontend tool '{name}' not in allowlist")
|
||||
return InterruptResult(
|
||||
interrupt_type="tool_execution",
|
||||
tool_name=name,
|
||||
tool_args=args,
|
||||
)
|
||||
|
||||
if target == "backend":
|
||||
if name not in ALLOWED_BACKEND_TOOLS:
|
||||
raise ValueError(f"Backend tool '{name}' not in allowlist")
|
||||
|
||||
requires_approval = tool.get("requires_approval", False)
|
||||
if requires_approval:
|
||||
return InterruptResult(
|
||||
interrupt_type="approval_required",
|
||||
tool_name=name,
|
||||
tool_args=args,
|
||||
)
|
||||
return BackendExecutionResult(
|
||||
tool_name=name,
|
||||
tool_args=args,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown execution_target: {target}")
|
||||
@@ -1,24 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def validate_tool_spec(spec: dict[str, Any]) -> None:
|
||||
try:
|
||||
name = spec["name"]
|
||||
target = spec["execution_target"]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required field: {e.args[0]}") from e
|
||||
|
||||
if not name or not isinstance(name, str):
|
||||
raise ValueError("Tool name must be a non-empty string")
|
||||
if not target or not isinstance(target, str):
|
||||
raise ValueError("execution_target must be a non-empty string")
|
||||
|
||||
if not (name.startswith("ui.") or name.startswith("srv.")):
|
||||
raise ValueError("Tool name must be in ui.* or srv.* namespace")
|
||||
|
||||
if name.startswith("ui.") and target != "frontend":
|
||||
raise ValueError("ui.* must use frontend target")
|
||||
if name.startswith("srv.") and target != "backend":
|
||||
raise ValueError("srv.* must use backend target")
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
from fastapi import APIRouter
|
||||
|
||||
from core.http.models import HealthResponse
|
||||
from v1.agent.router import router as agent_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.friendships.router import router as friendships_router
|
||||
from v1.inbox_messages.router import router as inbox_messages_router
|
||||
@@ -17,7 +16,6 @@ router.include_router(auth_router)
|
||||
router.include_router(friendships_router)
|
||||
router.include_router(infra_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(agent_router)
|
||||
router.include_router(schedule_items_router)
|
||||
router.include_router(inbox_messages_router)
|
||||
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeE2EAgentChatService(AgentChatService):
|
||||
def __init__(self) -> None:
|
||||
return None
|
||||
|
||||
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
|
||||
session_id = payload.session_id or UUID("00000000-0000-0000-0000-000000000001")
|
||||
return AgentChatRunResponse(
|
||||
session_id=session_id,
|
||||
output=payload.message,
|
||||
events=[
|
||||
AgentChatEvent(type="run.started", run_id=str(session_id)),
|
||||
AgentChatEvent(
|
||||
type="message.delta", message_id="m1", delta=payload.message
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="run.completed", run_id=str(session_id), output=payload.message
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
if sock.connect_ex((host, port)) == 0:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
_wait_for_port(host, port)
|
||||
return server, thread
|
||||
|
||||
|
||||
def test_agent_chat_flow_e2e() -> None:
|
||||
app.dependency_overrides[get_agent_service] = lambda: FakeE2EAgentChatService()
|
||||
host = "127.0.0.1"
|
||||
port = _find_free_port()
|
||||
server, thread = _start_server(host, port)
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
request_context = playwright.request.new_context(
|
||||
base_url=f"http://{host}:{port}"
|
||||
)
|
||||
try:
|
||||
response = request_context.post(
|
||||
"/api/v1/agent-chat",
|
||||
data=json.dumps({"message": "hello"}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert response.status == 200
|
||||
body = response.json()
|
||||
assert body["output"] == "hello"
|
||||
assert [event["type"] for event in body["events"]] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
finally:
|
||||
request_context.dispose()
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
@@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent.service import select_recent_session
|
||||
|
||||
|
||||
def test_recent_session_home_default_selection() -> None:
|
||||
sessions = [
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
|
||||
title="older",
|
||||
status=AgentChatSessionStatus.COMPLETED,
|
||||
last_activity_at=datetime(2026, 2, 25, 8, 0, tzinfo=timezone.utc),
|
||||
message_count=2,
|
||||
total_tokens=100,
|
||||
total_cost=Decimal("0.010000"),
|
||||
),
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-0000000000a2"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
|
||||
title="newer",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
|
||||
message_count=3,
|
||||
total_tokens=120,
|
||||
total_cost=Decimal("0.020000"),
|
||||
),
|
||||
]
|
||||
|
||||
selected = select_recent_session(sessions)
|
||||
|
||||
assert selected is not None
|
||||
assert selected.id == UUID("00000000-0000-0000-0000-0000000000a2")
|
||||
@@ -1,97 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from types import MethodType
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent.schemas import AgentChatRunRequest
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class _FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
if isinstance(obj, AgentChatSession) and obj.id is None:
|
||||
obj.id = uuid4()
|
||||
if isinstance(obj, AgentChatMessage) and obj.id is None:
|
||||
obj.id = uuid4()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_persists_messages_and_emits_ordered_events() -> None:
|
||||
fake_db = _FakeAsyncSession()
|
||||
service = AgentChatService(
|
||||
session=fake_db, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
async def _resolve_session(
|
||||
self: AgentChatService,
|
||||
*,
|
||||
session_id: object | None,
|
||||
user_id: UUID,
|
||||
first_message: str,
|
||||
now: datetime,
|
||||
) -> AgentChatSession:
|
||||
assert session_id is None
|
||||
assert first_message == "hello"
|
||||
return AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000111"),
|
||||
user_id=user_id,
|
||||
title="hello",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=now,
|
||||
message_count=0,
|
||||
total_tokens=0,
|
||||
total_cost=Decimal("0"),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
|
||||
assert session_id == UUID("00000000-0000-0000-0000-000000000111")
|
||||
return 2
|
||||
|
||||
service._resolve_session = MethodType(_resolve_session, service) # type: ignore[method-assign]
|
||||
service._next_seq_base = MethodType(_next_seq_base, service) # type: ignore[method-assign]
|
||||
|
||||
response = await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert fake_db.committed is True
|
||||
inserted_messages = [
|
||||
item for item in fake_db.added if isinstance(item, AgentChatMessage)
|
||||
]
|
||||
assert len(inserted_messages) == 2
|
||||
assert [msg.seq for msg in inserted_messages] == [3, 4]
|
||||
assert [msg.role for msg in inserted_messages] == [
|
||||
AgentChatMessageRole.USER,
|
||||
AgentChatMessageRole.ASSISTANT,
|
||||
]
|
||||
assert [event.type for event in response.events] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
@@ -1,78 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAgentChatService:
|
||||
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
|
||||
return AgentChatRunResponse(
|
||||
session_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
output=payload.message,
|
||||
events=[
|
||||
AgentChatEvent(
|
||||
type="run.started", run_id="00000000-0000-0000-0000-000000000001"
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="message.delta", message_id="m1", delta=payload.message
|
||||
),
|
||||
AgentChatEvent(
|
||||
type="run.completed",
|
||||
run_id="00000000-0000-0000-0000-000000000001",
|
||||
output=payload.message,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _override_agent_chat_service(
|
||||
service: FakeAgentChatService,
|
||||
) -> Callable[[], AgentChatService]:
|
||||
def _get_service() -> AgentChatService:
|
||||
return service # type: ignore[return-value]
|
||||
|
||||
return _get_service
|
||||
|
||||
|
||||
def test_run_route_returns_response() -> None:
|
||||
app.dependency_overrides[get_agent_service] = _override_agent_chat_service(
|
||||
FakeAgentChatService()
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/agent-chat", json={"message": "hello"})
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["output"] == "hello"
|
||||
assert [event["type"] for event in body["events"]] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_route_validates_payload() -> None:
|
||||
app.dependency_overrides[get_agent_service] = _override_agent_chat_service(
|
||||
FakeAgentChatService()
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.post("/api/v1/agent-chat", json={"message": ""})
|
||||
assert response.status_code == 422
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
@@ -1,20 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from v1.agent.service import aggregate_session_cost
|
||||
|
||||
|
||||
def test_aggregate_session_cost_sums_non_negative_values() -> None:
|
||||
total = aggregate_session_cost([Decimal("0.010000"), Decimal("0.002500")])
|
||||
assert total == Decimal("0.012500")
|
||||
|
||||
|
||||
def test_aggregate_session_cost_rejects_negative_value() -> None:
|
||||
try:
|
||||
aggregate_session_cost([Decimal("-0.010000")])
|
||||
raised = False
|
||||
except ValueError:
|
||||
raised = True
|
||||
|
||||
assert raised is True
|
||||
@@ -1,42 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent.service import select_recent_session
|
||||
|
||||
|
||||
def test_select_recent_session_uses_last_activity_desc() -> None:
|
||||
sessions = [
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
title="older",
|
||||
status=AgentChatSessionStatus.COMPLETED,
|
||||
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
|
||||
message_count=1,
|
||||
total_tokens=1,
|
||||
total_cost=Decimal("0"),
|
||||
),
|
||||
AgentChatSession(
|
||||
id=UUID("00000000-0000-0000-0000-000000000002"),
|
||||
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
|
||||
title="newer",
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
|
||||
message_count=2,
|
||||
total_tokens=2,
|
||||
total_cost=Decimal("0"),
|
||||
),
|
||||
]
|
||||
|
||||
selected = select_recent_session(sessions)
|
||||
|
||||
assert selected is not None
|
||||
assert selected.id == UUID("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
|
||||
def test_select_recent_session_returns_none_for_empty_collection() -> None:
|
||||
assert select_recent_session([]) is None
|
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class FakeAgentService:
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
return None
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
|
||||
|
||||
async def stream_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "r1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_END", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "r1"}\n\n'
|
||||
|
||||
|
||||
def _get_test_user() -> CurrentUser:
|
||||
return CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> TestClient:
|
||||
app.dependency_overrides[get_current_user] = _get_test_user
|
||||
app.dependency_overrides[get_agent_service] = lambda: FakeAgentService()
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestChatRoutes:
|
||||
def test_run_route_streams_sse_events(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs", json=payload)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
assert 'data: {"type": "RUN_STARTED"' in events[0]
|
||||
assert 'data: {"type": "TEXT_MESSAGE_START"' in events[1]
|
||||
|
||||
def test_resume_route_streams_sse_events(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs/r1/resume", json=payload)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
assert 'data: {"type": "RUN_STARTED"' in events[0]
|
||||
assert 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"' in events[2]
|
||||
@@ -1,144 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from core.auth.models import CurrentUser
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.users.dependencies import get_current_user
|
||||
|
||||
|
||||
class FakeAgentServiceWithInterrupt:
|
||||
async def prepare_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
return None
|
||||
|
||||
async def stream_run(self, input_data: RunAgentInput):
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + input_data.runId + '"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m1"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Let me navigate"}\n\n'
|
||||
yield 'data: {"type": "TOOL_CALL", "toolName": "ui.navigate_to", "args": {"path": "/home"}}\n\n'
|
||||
yield (
|
||||
'data: {"type": "RUN_FINISHED", "runId": "'
|
||||
+ input_data.runId
|
||||
+ '", "outcome": "interrupt", "interrupt": {"id": "int-1", "reason": "frontend_tool", "payload": {"toolName": "ui.navigate_to", "args": {"path": "/home"}}}}\n\n'
|
||||
)
|
||||
|
||||
async def stream_resume(self, run_id: str, input_data: RunAgentInput):
|
||||
if input_data.resume and input_data.resume.get("interruptId") == "int-1":
|
||||
payload = input_data.resume.get("payload", {})
|
||||
yield 'data: {"type": "RUN_STARTED", "runId": "' + run_id + '"}\n\n'
|
||||
yield (
|
||||
'data: {"type": "TOOL_RESULT", "toolName": "ui.navigate_to", "result": '
|
||||
+ json.dumps(payload.get("result", {}))
|
||||
+ "}\n\n"
|
||||
)
|
||||
yield 'data: {"type": "TEXT_MESSAGE_START", "messageId": "m2"}\n\n'
|
||||
yield 'data: {"type": "TEXT_MESSAGE_CONTENT", "delta": "Navigation completed"}\n\n'
|
||||
yield 'data: {"type": "RUN_FINISHED", "runId": "' + run_id + '"}\n\n'
|
||||
else:
|
||||
yield (
|
||||
'data: {"type": "RUN_FINISHED", "runId": "'
|
||||
+ run_id
|
||||
+ '", "outcome": "error"}\n\n'
|
||||
)
|
||||
|
||||
|
||||
def _get_test_user() -> CurrentUser:
|
||||
return CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> TestClient:
|
||||
app.dependency_overrides[get_current_user] = _get_test_user
|
||||
app.dependency_overrides[get_agent_service] = (
|
||||
lambda: FakeAgentServiceWithInterrupt()
|
||||
)
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestInterruptResumeFlow:
|
||||
def test_frontend_tool_interrupt_then_resume_with_result(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [{"role": "user", "content": "Navigate to home"}],
|
||||
"tools": [{"name": "ui.navigate_to", "execution_target": "frontend"}],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs", json=payload)
|
||||
assert response.status_code == 200
|
||||
|
||||
events = response.text.split("\n\n")
|
||||
interrupt_event = [e for e in events if '"outcome": "interrupt"' in e][0]
|
||||
assert '"id": "int-1"' in interrupt_event
|
||||
assert '"reason": "frontend_tool"' in interrupt_event
|
||||
|
||||
resume_payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {
|
||||
"interruptId": "int-1",
|
||||
"payload": {"result": {"success": True}},
|
||||
},
|
||||
}
|
||||
resume_response = client.post(
|
||||
"/api/v1/agent/runs/r1/resume", json=resume_payload
|
||||
)
|
||||
assert resume_response.status_code == 200
|
||||
|
||||
resume_events = resume_response.text.split("\n\n")
|
||||
tool_result_event = [e for e in resume_events if '"type": "TOOL_RESULT"' in e][
|
||||
0
|
||||
]
|
||||
assert '"toolName": "ui.navigate_to"' in tool_result_event
|
||||
assert '"success": true' in tool_result_event.lower()
|
||||
|
||||
def test_backend_tool_approval_rejected(self, client: TestClient):
|
||||
payload = {
|
||||
"threadId": "t2",
|
||||
"runId": "r2",
|
||||
"state": {},
|
||||
"messages": [{"role": "user", "content": "Transfer funds"}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "srv.transfer_funds",
|
||||
"execution_target": "backend",
|
||||
"requires_approval": True,
|
||||
}
|
||||
],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
response = client.post("/api/v1/agent/runs", json=payload)
|
||||
assert response.status_code == 200
|
||||
|
||||
resume_payload = {
|
||||
"threadId": "t2",
|
||||
"runId": "r2",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {
|
||||
"interruptId": "int-1",
|
||||
"payload": {"decision": "rejected", "reason": "User denied"},
|
||||
},
|
||||
}
|
||||
resume_response = client.post(
|
||||
"/api/v1/agent/runs/r2/resume", json=resume_payload
|
||||
)
|
||||
assert resume_response.status_code == 200
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.agui_adapter import AguiAdapter
|
||||
|
||||
|
||||
def test_to_command_maps_payload_fields() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
command = adapter.to_command(
|
||||
{
|
||||
"message": "hello",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
)
|
||||
|
||||
assert command["message"] == "hello"
|
||||
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_to_protocol_event_maps_internal_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
mapped = adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_completed",
|
||||
"session_id": "run-1",
|
||||
"output": "done",
|
||||
}
|
||||
)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
|
||||
|
||||
|
||||
def test_to_protocol_event_raises_for_invalid_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
adapter.to_protocol_event({"kind": "unknown"})
|
||||
@@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_transcribe_uses_injected_dashscope_callable() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
assert filename == "voice.wav"
|
||||
assert audio_bytes == b"audio"
|
||||
return {"text": "你好", "request_id": "req-1"}
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
|
||||
assert result["text"] == "你好"
|
||||
assert result["request_id"] == "req-1"
|
||||
assert result["model"] == "fun-asr-realtime-2025-11-07"
|
||||
|
||||
|
||||
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
raise RuntimeError("upstream timeout")
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
@@ -1,48 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from core.agent.litellm_client import get_model_cost
|
||||
|
||||
|
||||
def test_get_model_cost_returns_decimal() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
"cost": "0.002500",
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0.002500")
|
||||
|
||||
|
||||
def test_get_model_cost_with_no_cost() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0")
|
||||
|
||||
|
||||
def test_get_model_cost_with_zero_cost() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
"cost": "0",
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0")
|
||||
|
||||
|
||||
def test_get_model_cost_with_numeric_cost() -> None:
|
||||
usage = {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 12,
|
||||
"cost": 0.0025,
|
||||
}
|
||||
cost = get_model_cost(usage)
|
||||
assert cost == Decimal("0.0025")
|
||||
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.event_bridge import map_internal_event
|
||||
|
||||
|
||||
def test_map_run_started_event() -> None:
|
||||
event = {"kind": "run_started", "session_id": "s1"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.started", "run_id": "s1"}
|
||||
|
||||
|
||||
def test_map_message_delta_event() -> None:
|
||||
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
|
||||
def test_map_tool_events() -> None:
|
||||
started = {
|
||||
"kind": "tool_started",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
}
|
||||
completed = {
|
||||
"kind": "tool_completed",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
"result": "ok",
|
||||
}
|
||||
|
||||
mapped_started = map_internal_event(started)
|
||||
mapped_completed = map_internal_event(completed)
|
||||
|
||||
assert mapped_started["type"] == "tool.started"
|
||||
assert mapped_started["tool_name"] == "asr_fun_asr"
|
||||
assert mapped_completed["type"] == "tool.completed"
|
||||
assert mapped_completed["result"] == "ok"
|
||||
|
||||
|
||||
def test_map_run_completed_event() -> None:
|
||||
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
|
||||
|
||||
|
||||
def test_map_unknown_event_raises() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "unknown"})
|
||||
|
||||
|
||||
def test_map_event_missing_required_field_raises_value_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "message_delta", "message_id": "m1"})
|
||||
@@ -1,104 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.orchestrator import AgentChatOrchestrator
|
||||
|
||||
|
||||
async def _intent_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("intent")
|
||||
return {
|
||||
"content": f"intent:{message}",
|
||||
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
|
||||
}
|
||||
|
||||
|
||||
async def _execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
return {
|
||||
"content": f"execution:{message}",
|
||||
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
|
||||
}
|
||||
|
||||
|
||||
async def _organization_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("organization")
|
||||
return {
|
||||
"content": "final answer",
|
||||
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
|
||||
}
|
||||
|
||||
|
||||
def test_orchestrator_runs_three_stages_in_order() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution", "organization"]
|
||||
assert result.output == "final answer"
|
||||
assert result.usage["total_tokens"] == 13
|
||||
assert result.events[0]["type"] == "run.started"
|
||||
assert result.events[-1]["type"] == "run.completed"
|
||||
|
||||
|
||||
async def _failing_execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_failing_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution"]
|
||||
assert result.events[-1]["type"] == "run.failed"
|
||||
assert result.events[-1]["run_id"] == "run-2"
|
||||
assert "boom" in (result.events[-1].get("error") or "")
|
||||
assert result.failed is True
|
||||
assert "boom" in (result.error or "")
|
||||
|
||||
|
||||
def test_orchestrator_emits_stage_event_payload_shape() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
|
||||
|
||||
for event in result.events:
|
||||
assert "type" in event
|
||||
assert event.get("run_id") == "run-3"
|
||||
|
||||
stage_events = [
|
||||
event for event in result.events if event["type"] == "stage.completed"
|
||||
]
|
||||
assert [event["stage"] for event in stage_events] == [
|
||||
"intent",
|
||||
"execution",
|
||||
"organization",
|
||||
]
|
||||
@@ -1,23 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from v1.agent.service import build_session_title
|
||||
|
||||
|
||||
def test_build_session_title_truncates_first_message() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title(
|
||||
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
|
||||
)
|
||||
|
||||
assert len(title) == 24
|
||||
|
||||
|
||||
def test_build_session_title_falls_back_when_message_empty() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title("\n ", now=now)
|
||||
|
||||
assert title == "新对话 2026-02-25 10:30"
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.agui_adapter import AguiAdapter
|
||||
|
||||
|
||||
def test_to_command_maps_payload_fields() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
command = adapter.to_command(
|
||||
{
|
||||
"message": "hello",
|
||||
"session_id": "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
)
|
||||
|
||||
assert command["message"] == "hello"
|
||||
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_to_protocol_event_maps_internal_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
mapped = adapter.to_protocol_event(
|
||||
{
|
||||
"kind": "run_completed",
|
||||
"session_id": "run-1",
|
||||
"output": "done",
|
||||
}
|
||||
)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
|
||||
|
||||
|
||||
def test_to_protocol_event_raises_for_invalid_event() -> None:
|
||||
adapter = AguiAdapter()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
adapter.to_protocol_event({"kind": "unknown"})
|
||||
@@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_transcribe_uses_injected_dashscope_callable() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
assert filename == "voice.wav"
|
||||
assert audio_bytes == b"audio"
|
||||
return {"text": "你好", "request_id": "req-1"}
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
|
||||
assert result["text"] == "你好"
|
||||
assert result["request_id"] == "req-1"
|
||||
assert result["model"] == "fun-asr-realtime-2025-11-07"
|
||||
|
||||
|
||||
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
|
||||
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
|
||||
raise RuntimeError("upstream timeout")
|
||||
|
||||
tool = FunASRTool(transcribe_callable=fake_transcribe)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cost_tracker import CostTracker
|
||||
|
||||
|
||||
def test_normalize_usage_and_cost_aggregation() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
tracker.add_usage(
|
||||
{
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 5,
|
||||
"cost": "0.002500",
|
||||
}
|
||||
)
|
||||
tracker.add_usage(
|
||||
{
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 3,
|
||||
"cost": "0.003000",
|
||||
"currency": "USD",
|
||||
}
|
||||
)
|
||||
|
||||
snapshot = tracker.snapshot()
|
||||
|
||||
assert snapshot["input_tokens"] == 12
|
||||
assert snapshot["output_tokens"] == 8
|
||||
assert snapshot["total_tokens"] == 20
|
||||
assert snapshot["cost"] == Decimal("0.005500")
|
||||
assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_negative_values() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"input_tokens": -1})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"cost": "-0.010000"})
|
||||
|
||||
|
||||
def test_snapshot_is_zero_before_any_usage() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
snapshot = tracker.snapshot()
|
||||
|
||||
assert snapshot["input_tokens"] == 0
|
||||
assert snapshot["output_tokens"] == 0
|
||||
assert snapshot["total_tokens"] == 0
|
||||
assert snapshot["cost"] == Decimal("0")
|
||||
assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_currency_mismatch() -> None:
|
||||
tracker = CostTracker(currency="USD")
|
||||
tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage(
|
||||
{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"cost": "0.001000",
|
||||
"currency": "CNY",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_add_usage_rejects_non_integral_token_values() -> None:
|
||||
tracker = CostTracker()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"input_tokens": 1.5})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tracker.add_usage({"output_tokens": True})
|
||||
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.event_bridge import map_internal_event
|
||||
|
||||
|
||||
def test_map_run_started_event() -> None:
|
||||
event = {"kind": "run_started", "session_id": "s1"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.started", "run_id": "s1"}
|
||||
|
||||
|
||||
def test_map_message_delta_event() -> None:
|
||||
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
|
||||
|
||||
|
||||
def test_map_tool_events() -> None:
|
||||
started = {
|
||||
"kind": "tool_started",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
}
|
||||
completed = {
|
||||
"kind": "tool_completed",
|
||||
"message_id": "m2",
|
||||
"tool_name": "asr_fun_asr",
|
||||
"result": "ok",
|
||||
}
|
||||
|
||||
mapped_started = map_internal_event(started)
|
||||
mapped_completed = map_internal_event(completed)
|
||||
|
||||
assert mapped_started["type"] == "tool.started"
|
||||
assert mapped_started["tool_name"] == "asr_fun_asr"
|
||||
assert mapped_completed["type"] == "tool.completed"
|
||||
assert mapped_completed["result"] == "ok"
|
||||
|
||||
|
||||
def test_map_run_completed_event() -> None:
|
||||
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
|
||||
|
||||
mapped = map_internal_event(event)
|
||||
|
||||
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
|
||||
|
||||
|
||||
def test_map_unknown_event_raises() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "unknown"})
|
||||
|
||||
|
||||
def test_map_event_missing_required_field_raises_value_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
map_internal_event({"kind": "message_delta", "message_id": "m1"})
|
||||
@@ -1,104 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent.orchestrator import AgentChatOrchestrator
|
||||
|
||||
|
||||
async def _intent_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("intent")
|
||||
return {
|
||||
"content": f"intent:{message}",
|
||||
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
|
||||
}
|
||||
|
||||
|
||||
async def _execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
return {
|
||||
"content": f"execution:{message}",
|
||||
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
|
||||
}
|
||||
|
||||
|
||||
async def _organization_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("organization")
|
||||
return {
|
||||
"content": "final answer",
|
||||
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
|
||||
}
|
||||
|
||||
|
||||
def test_orchestrator_runs_three_stages_in_order() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution", "organization"]
|
||||
assert result.output == "final answer"
|
||||
assert result.usage["total_tokens"] == 13
|
||||
assert result.events[0]["type"] == "run.started"
|
||||
assert result.events[-1]["type"] == "run.completed"
|
||||
|
||||
|
||||
async def _failing_execution_stage(
|
||||
*, message: str, context: dict[str, object]
|
||||
) -> dict[str, object]:
|
||||
sequence = context.setdefault("sequence", [])
|
||||
if isinstance(sequence, list):
|
||||
sequence.append("execution")
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_failing_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
|
||||
|
||||
assert result.context["sequence"] == ["intent", "execution"]
|
||||
assert result.events[-1]["type"] == "run.failed"
|
||||
assert result.events[-1]["run_id"] == "run-2"
|
||||
assert "boom" in (result.events[-1].get("error") or "")
|
||||
assert result.failed is True
|
||||
assert "boom" in (result.error or "")
|
||||
|
||||
|
||||
def test_orchestrator_emits_stage_event_payload_shape() -> None:
|
||||
orchestrator = AgentChatOrchestrator(
|
||||
intent_stage=_intent_stage,
|
||||
execution_stage=_execution_stage,
|
||||
organization_stage=_organization_stage,
|
||||
)
|
||||
|
||||
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
|
||||
|
||||
for event in result.events:
|
||||
assert "type" in event
|
||||
assert event.get("run_id") == "run-3"
|
||||
|
||||
stage_events = [
|
||||
event for event in result.events if event["type"] == "stage.completed"
|
||||
]
|
||||
assert [event["stage"] for event in stage_events] == [
|
||||
"intent",
|
||||
"execution",
|
||||
"organization",
|
||||
]
|
||||
@@ -1,23 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from v1.agent.service import build_session_title
|
||||
|
||||
|
||||
def test_build_session_title_truncates_first_message() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title(
|
||||
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
|
||||
)
|
||||
|
||||
assert len(title) == 24
|
||||
|
||||
|
||||
def test_build_session_title_falls_back_when_message_empty() -> None:
|
||||
now = datetime(2026, 2, 25, 10, 30)
|
||||
|
||||
title = build_session_title("\n ", now=now)
|
||||
|
||||
assert title == "新对话 2026-02-25 10:30"
|
||||
@@ -1,132 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.crewai.template_loader import (
|
||||
load_crewai_template,
|
||||
load_tools_whitelist,
|
||||
validate_workflow_stages,
|
||||
)
|
||||
|
||||
|
||||
def _write(path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def _prepare_static_root(root: Path) -> Path:
|
||||
_write(
|
||||
root / "agents.yaml",
|
||||
"""
|
||||
intent:
|
||||
role: Intent Agent
|
||||
execution:
|
||||
role: Execution Agent
|
||||
organization:
|
||||
role: Organization Agent
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "tasks.yaml",
|
||||
"""
|
||||
intent:
|
||||
description: classify
|
||||
execution:
|
||||
description: run task
|
||||
organization:
|
||||
description: summarize
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "workflow.yaml",
|
||||
"""
|
||||
stages:
|
||||
- intent
|
||||
- execution
|
||||
- organization
|
||||
""".strip(),
|
||||
)
|
||||
_write(
|
||||
root / "tools.yaml",
|
||||
"""
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- doc_extract
|
||||
""".strip(),
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def test_load_crewai_template_success_when_all_files_valid(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
|
||||
template = load_crewai_template(static_root)
|
||||
|
||||
assert set(template.agents.keys()) == {"intent", "execution", "organization"}
|
||||
assert set(template.tasks.keys()) == {"intent", "execution", "organization"}
|
||||
assert template.workflow["stages"] == ["intent", "execution", "organization"]
|
||||
assert template.tools_whitelist == {"asr_fun_asr", "doc_extract"}
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_file_not_found_when_required_file_missing(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
(static_root / "tasks.yaml").unlink()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_value_error_when_workflow_stages_invalid(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
_write(
|
||||
static_root / "workflow.yaml",
|
||||
"""
|
||||
stages:
|
||||
- execution
|
||||
- intent
|
||||
- organization
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_tools_whitelist_from_tools_yaml(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
|
||||
whitelist = load_tools_whitelist(static_root)
|
||||
|
||||
assert whitelist == {"asr_fun_asr", "doc_extract"}
|
||||
|
||||
|
||||
def test_validate_workflow_stages_accepts_exact_intent_execution_organization() -> None:
|
||||
validate_workflow_stages(["intent", "execution", "organization"])
|
||||
|
||||
|
||||
def test_validate_workflow_stages_rejects_extra_or_missing_stage() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
validate_workflow_stages(["intent", "execution"])
|
||||
with pytest.raises(ValueError):
|
||||
validate_workflow_stages(["intent", "execution", "organization", "extra"])
|
||||
|
||||
|
||||
def test_load_tools_whitelist_rejects_non_string_item(tmp_path: Path) -> None:
|
||||
static_root = _prepare_static_root(tmp_path)
|
||||
_write(
|
||||
static_root / "tools.yaml",
|
||||
"""
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- 123
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_tools_whitelist(static_root)
|
||||
@@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from core.config.initial import init_data
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
|
||||
def test_llm_catalog_file_exists_and_has_required_fields() -> None:
|
||||
catalog_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "src"
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "database"
|
||||
/ "llm_catalog.yaml"
|
||||
)
|
||||
|
||||
catalog = init_data.load_llm_catalog(catalog_path)
|
||||
|
||||
assert len(catalog["factories"]) == 6
|
||||
assert len(catalog["llms"]) == 2
|
||||
assert set(catalog["factories"][0].keys()) == {"name", "request_url", "avatar"}
|
||||
assert set(catalog["llms"][0].keys()) == {"model_code", "factory_name"}
|
||||
|
||||
|
||||
def test_load_llm_catalog_raises_on_invalid_structure(tmp_path: Path) -> None:
|
||||
catalog_path = tmp_path / "llm_catalog.yaml"
|
||||
catalog_path.write_text(
|
||||
"""
|
||||
factories:
|
||||
- name: qwen
|
||||
llms:
|
||||
- model_code: qwen3.5-flash
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
init_data.load_llm_catalog(catalog_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_data_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
session_maker = async_sessionmaker(
|
||||
bind=engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
|
||||
|
||||
first = await init_data.initialize_data()
|
||||
second = await init_data.initialize_data()
|
||||
|
||||
assert first is True
|
||||
assert second is True
|
||||
|
||||
async with session_maker() as session:
|
||||
factory_count = await session.scalar(
|
||||
select(func.count()).select_from(LlmFactory)
|
||||
)
|
||||
llm_count = await session.scalar(select(func.count()).select_from(Llm))
|
||||
|
||||
assert factory_count == 6
|
||||
assert llm_count == 2
|
||||
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_data_rolls_back_on_invalid_factory_mapping(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
session_maker = async_sessionmaker(
|
||||
bind=engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
|
||||
monkeypatch.setattr(
|
||||
init_data,
|
||||
"load_llm_catalog",
|
||||
lambda *_: {
|
||||
"factories": [
|
||||
{
|
||||
"name": "qwen",
|
||||
"request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"avatar": "https://cdn.example.com/qwen.png",
|
||||
}
|
||||
],
|
||||
"llms": [
|
||||
{
|
||||
"model_code": "qwen3.5-flash",
|
||||
"factory_id": "missing_factory",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await init_data.initialize_data()
|
||||
|
||||
async with session_maker() as session:
|
||||
factory_count = await session.scalar(
|
||||
select(func.count()).select_from(LlmFactory)
|
||||
)
|
||||
llm_count = await session.scalar(select(func.count()).select_from(Llm))
|
||||
|
||||
assert factory_count == 0
|
||||
assert llm_count == 0
|
||||
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def test_user_agent_catalog_file_exists_and_has_required_fields() -> None:
|
||||
catalog_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "src"
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "database"
|
||||
/ "user_agent_catalog.yaml"
|
||||
)
|
||||
|
||||
assert catalog_path.exists(), f"Catalog file not found: {catalog_path}"
|
||||
|
||||
catalog = init_data.load_user_agent_catalog(catalog_path)
|
||||
|
||||
assert "agents" in catalog
|
||||
assert isinstance(catalog["agents"], list)
|
||||
assert len(catalog["agents"]) == 3
|
||||
|
||||
for agent in catalog["agents"]:
|
||||
assert "agent_type" in agent
|
||||
assert "llm_model_code" in agent
|
||||
assert "status" in agent
|
||||
assert "config" in agent
|
||||
assert isinstance(agent["config"], dict)
|
||||
|
||||
|
||||
def test_load_user_agent_catalog_raises_on_invalid_structure(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
catalog_path = tmp_path / "user_agent_catalog.yaml"
|
||||
catalog_path.write_text(
|
||||
"""
|
||||
agents:
|
||||
- agent_type: TEST
|
||||
llm_model_code: test-model
|
||||
status: ACTIVE
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user agent catalog"):
|
||||
init_data.load_user_agent_catalog(catalog_path)
|
||||
@@ -1,27 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_initial_migration_exists_and_creates_expected_tables() -> None:
|
||||
versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
|
||||
migration_files = sorted(versions_dir.glob("20260226_*.py"))
|
||||
assert len(migration_files) == 5, "split initial migrations should exist"
|
||||
|
||||
content = "\n".join(m.read_text(encoding="utf-8") for m in migration_files)
|
||||
|
||||
# New tables from social data model redesign
|
||||
assert "create_table(" in content and "automation_jobs" in content
|
||||
assert "create_table(" in content and "user_agents" in content
|
||||
assert "create_table(" in content and "memories" in content
|
||||
assert "create_table(" in content and "friendships" in content
|
||||
assert "create_table(" in content and "groups" in content
|
||||
assert "create_table(" in content and "group_members" in content
|
||||
assert "create_table(" in content and "schedule_items" in content
|
||||
assert "create_table(" in content and "schedule_subscriptions" in content
|
||||
assert "create_table(" in content and "inbox_messages" in content
|
||||
assert "create_table(" in content and "todos" in content
|
||||
assert "create_table(" in content and "todo_sources" in content
|
||||
assert "create_table(" in content and "profiles" in content
|
||||
assert "create_table(" in content and "sessions" in content
|
||||
assert "create_table(" in content and "messages" in content
|
||||
@@ -1,119 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_factory_and_llm_relationship(db_session: AsyncSession) -> None:
|
||||
factory = LlmFactory(
|
||||
name="qwen",
|
||||
request_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
avatar="https://cdn.example.com/qwen.png",
|
||||
)
|
||||
db_session.add(factory)
|
||||
await db_session.flush()
|
||||
|
||||
llm = Llm(
|
||||
factory_id=factory.id,
|
||||
model_code="qwen3.5-flash",
|
||||
)
|
||||
db_session.add(llm)
|
||||
await db_session.commit()
|
||||
|
||||
found_llm = await db_session.get(Llm, llm.id)
|
||||
assert found_llm is not None
|
||||
assert found_llm.factory_id == factory.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_status_supports_required_values(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
session = AgentChatSession(
|
||||
user_id=user_id,
|
||||
title="test",
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(session)
|
||||
await db_session.commit()
|
||||
|
||||
statuses = [
|
||||
AgentChatSessionStatus.PENDING,
|
||||
AgentChatSessionStatus.RUNNING,
|
||||
AgentChatSessionStatus.COMPLETED,
|
||||
AgentChatSessionStatus.FAILED,
|
||||
]
|
||||
for status in statuses:
|
||||
session.status = status
|
||||
await db_session.commit()
|
||||
await db_session.refresh(session)
|
||||
assert session.status == status
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_role_supports_tool(db_session: AsyncSession) -> None:
|
||||
user_id = uuid4()
|
||||
session = AgentChatSession(
|
||||
user_id=user_id,
|
||||
title="tool test",
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(session)
|
||||
await db_session.flush()
|
||||
|
||||
message = AgentChatMessage(
|
||||
session_id=session.id,
|
||||
seq=1,
|
||||
role="tool",
|
||||
content="tool output",
|
||||
cost=0,
|
||||
)
|
||||
db_session.add(message)
|
||||
await db_session.commit()
|
||||
|
||||
result = await db_session.execute(
|
||||
select(AgentChatMessage).where(AgentChatMessage.session_id == session.id)
|
||||
)
|
||||
found = result.scalar_one()
|
||||
assert found.role == "tool"
|
||||
@@ -1,80 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
from v1.agent.tool_registry import validate_tool_spec
|
||||
|
||||
|
||||
class TestAgentSecurityRules:
|
||||
def test_tool_name_must_be_allowlisted(self):
|
||||
validate_tool_spec({"name": "ui.navigate_to", "execution_target": "frontend"})
|
||||
validate_tool_spec({"name": "srv.search_docs", "execution_target": "backend"})
|
||||
|
||||
def test_tool_name_rejected_if_not_in_namespace(self):
|
||||
try:
|
||||
validate_tool_spec(
|
||||
{"name": "malicious.tool", "execution_target": "frontend"}
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Should have raised ValueError for unknown namespace")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_frontend_result_fails_when_interrupt_mismatch(self):
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, session_obj: AgentChatSession) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(self._session_obj)
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
service = AgentChatService(
|
||||
session=FakeAsyncSession(session), # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-other",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is False
|
||||
@@ -1,25 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.crewai_flow import AgentFlow
|
||||
|
||||
|
||||
class TestCrewAIFlow:
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_stages_run_in_order(self):
|
||||
flow = AgentFlow()
|
||||
await flow.run()
|
||||
assert flow.state.stage_trace == ["intent", "execution", "reporting"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_state_initialized(self):
|
||||
flow = AgentFlow()
|
||||
assert flow.state.stage_trace == []
|
||||
assert flow.state.current_stage is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_updates_current_stage(self):
|
||||
flow = AgentFlow()
|
||||
await flow.run()
|
||||
assert flow.state.current_stage == "reporting"
|
||||
@@ -1,187 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self._sessions: dict[UUID, AgentChatSession] = {}
|
||||
self.last_fetch_with_lock = False
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
if isinstance(obj, AgentChatSession):
|
||||
self._sessions[obj.id] = obj
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
pass
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
self.last_fetch_with_lock = "FOR UPDATE" in str(stmt)
|
||||
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(next(iter(self._sessions.values()), None))
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db() -> FakeAsyncSession:
|
||||
return FakeAsyncSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
|
||||
sess = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
fake_db.add(sess)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(fake_db: FakeAsyncSession) -> AgentChatService:
|
||||
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestResumeIdempotency:
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_is_idempotent(
|
||||
self,
|
||||
service: AgentChatService,
|
||||
session: AgentChatSession,
|
||||
fake_db: FakeAsyncSession,
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
first = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
second = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert first.applied is True
|
||||
assert second.applied is False
|
||||
assert fake_db.last_fetch_with_lock is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_updates_status_to_approved(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
tool_name="srv.delete_file",
|
||||
tool_args={"file_id": "f1"},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is True
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
|
||||
assert snapshot["pending_tool_call"]["decision"] == {"decision": "approved"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_updates_status_to_rejected(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-3",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-3",
|
||||
decision={"decision": "rejected"},
|
||||
)
|
||||
|
||||
assert result.applied is True
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "REJECTED"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_expired_pending_marks_expired_and_not_applied(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) - timedelta(seconds=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-expired",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
result = await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-expired",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
assert result.applied is False
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "EXPIRED"
|
||||
@@ -1,127 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.schemas import AgentSessionSnapshot, RunAgentInput
|
||||
|
||||
|
||||
class TestRunAgentInput:
|
||||
def test_requires_full_fields(self):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
}
|
||||
model = RunAgentInput.model_validate(payload)
|
||||
assert model.threadId == "t1"
|
||||
assert model.runId == "r1"
|
||||
assert model.parentRunId is None
|
||||
assert model.state == {}
|
||||
assert model.messages == []
|
||||
assert model.tools == []
|
||||
assert model.context == []
|
||||
assert model.forwardedProps == {}
|
||||
assert model.resume is None
|
||||
|
||||
def test_resume_optional(self):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r2",
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
model = RunAgentInput.model_validate(payload)
|
||||
assert model.resume is not None
|
||||
assert model.resume["interruptId"] == "int-1"
|
||||
assert model.resume["payload"]["decision"] == "approved"
|
||||
|
||||
def test_parent_run_id_optional(self):
|
||||
payload = {
|
||||
"threadId": "t1",
|
||||
"runId": "r3",
|
||||
"parentRunId": "p1",
|
||||
"state": {"key": "value"},
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"tools": [{"name": "ui.navigate_to"}],
|
||||
"context": [{"type": "user", "id": "u1"}],
|
||||
"forwardedProps": {"theme": "dark"},
|
||||
}
|
||||
model = RunAgentInput.model_validate(payload)
|
||||
assert model.parentRunId == "p1"
|
||||
assert model.state == {"key": "value"}
|
||||
assert len(model.messages) == 1
|
||||
assert model.messages[0]["role"] == "user"
|
||||
|
||||
|
||||
class TestAgentSessionSnapshot:
|
||||
def test_state_snapshot_v2_model_accepts_valid_payload(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00Z",
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
model = AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
assert model.version == 2
|
||||
assert model.pending_tool_call is not None
|
||||
assert model.pending_tool_call.interrupt_id == "int-1"
|
||||
assert model.pending_tool_call.updated_at == datetime(
|
||||
2026, 3, 3, 11, 59, tzinfo=timezone.utc
|
||||
)
|
||||
|
||||
def test_state_snapshot_v2_rejects_wrong_version(self):
|
||||
payload = {
|
||||
"version": 1,
|
||||
"pending_tool_call": None,
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
def test_state_snapshot_v2_requires_pending_tool_call_key(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
|
||||
def test_state_snapshot_v2_rejects_extra_fields(self):
|
||||
payload = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00Z",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00Z",
|
||||
"unexpected": True,
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1", "foo": "bar"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AgentSessionSnapshot.model_validate(payload)
|
||||
@@ -1,168 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
self._sessions: dict[UUID, AgentChatSession] = {}
|
||||
|
||||
def add(self, obj: object) -> None:
|
||||
self.added.append(obj)
|
||||
if isinstance(obj, AgentChatSession):
|
||||
self._sessions[obj.id] = obj
|
||||
|
||||
async def flush(self) -> None:
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def refresh(self, obj: object) -> None:
|
||||
pass
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
return _Result(next(iter(self._sessions.values()), None))
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db() -> FakeAsyncSession:
|
||||
return FakeAsyncSession()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(fake_db: FakeAsyncSession) -> AgentChatSession:
|
||||
sess = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
)
|
||||
fake_db.add(sess)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(fake_db: FakeAsyncSession) -> AgentChatService:
|
||||
return AgentChatService(fake_db, current_user=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestPendingToolCall:
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_pending_tool_call_to_state_snapshot(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-1",
|
||||
tool_name="srv.transfer_funds",
|
||||
tool_args={"to": "u2", "amount": 100},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot is not None
|
||||
assert snapshot["version"] == 2
|
||||
assert snapshot["run_context"]["thread_id"] == "t1"
|
||||
assert snapshot["run_context"]["run_id"] == "r1"
|
||||
assert snapshot["pending_tool_call"]["status"] == "PENDING_APPROVAL"
|
||||
assert snapshot["pending_tool_call"]["interrupt_id"] == "int-1"
|
||||
assert snapshot["pending_tool_call"]["tool_name"] == "srv.transfer_funds"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_state_snapshot_returns_none_when_empty(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pending_tool_call_status(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
await service.set_pending_tool_call(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
tool_name="srv.delete_file",
|
||||
tool_args={"file_id": "f1"},
|
||||
expires_at=expires_at,
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
)
|
||||
|
||||
await service.update_pending_tool_call_status(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-2",
|
||||
status="APPROVED_EXECUTING",
|
||||
)
|
||||
|
||||
snapshot = await service.get_state_snapshot(session.id)
|
||||
assert snapshot["pending_tool_call"]["status"] == "APPROVED_EXECUTING"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_legacy_snapshot_is_rejected(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
session.state_snapshot = {"pending_tool_call": {"status": "PENDING_APPROVAL"}}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-legacy",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_snapshot_rejects_naive_datetime(
|
||||
self, service: AgentChatService, session: AgentChatSession
|
||||
):
|
||||
session.state_snapshot = {
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-naive",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2026-03-03T12:00:00",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": "2026-03-03T11:59:00",
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": "r1"},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await service.apply_resume_decision(
|
||||
session_id=session.id,
|
||||
interrupt_id="int-naive",
|
||||
decision={"decision": "approved"},
|
||||
)
|
||||
@@ -1,126 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_session import (
|
||||
AgentChatSession,
|
||||
AgentChatSessionStatus,
|
||||
SessionType,
|
||||
)
|
||||
from v1.agent.schemas import RunAgentInput
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, sessions: list[AgentChatSession]) -> None:
|
||||
self._sessions = {session.id: session for session in sessions}
|
||||
self.commit_called = False
|
||||
|
||||
async def execute(self, stmt: object):
|
||||
class _Result:
|
||||
def __init__(self, session_obj: AgentChatSession | None) -> None:
|
||||
self._session_obj = session_obj
|
||||
|
||||
def scalar_one_or_none(self) -> AgentChatSession | None:
|
||||
return self._session_obj
|
||||
|
||||
for session in self._sessions.values():
|
||||
return _Result(session)
|
||||
return _Result(None)
|
||||
|
||||
async def scalar(self, stmt: object) -> AgentChatSession | None:
|
||||
for session in self._sessions.values():
|
||||
return session
|
||||
return None
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.commit_called = True
|
||||
|
||||
|
||||
def _build_input(run_id: str) -> RunAgentInput:
|
||||
return RunAgentInput.model_validate(
|
||||
{
|
||||
"threadId": "t1",
|
||||
"runId": run_id,
|
||||
"state": {},
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {},
|
||||
"resume": {"interruptId": "int-1", "payload": {"decision": "approved"}},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_resume_rejects_non_owner_session() -> None:
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": datetime.now(timezone.utc).isoformat(),
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
service = AgentChatService(
|
||||
session=FakeAsyncSession([session]), # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_resume_commits_expired_state_before_410() -> None:
|
||||
owner_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
status=AgentChatSessionStatus.RUNNING,
|
||||
state_snapshot={
|
||||
"version": 2,
|
||||
"pending_tool_call": {
|
||||
"interrupt_id": "int-1",
|
||||
"tool_name": "srv.transfer_funds",
|
||||
"tool_args": {"to": "u2", "amount": 100},
|
||||
"status": "PENDING_APPROVAL",
|
||||
"expires_at": "2000-01-01T00:00:00+00:00",
|
||||
"decision": None,
|
||||
"result": None,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"run_context": {"thread_id": "t1", "run_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
fake_db = FakeAsyncSession([session])
|
||||
service = AgentChatService(
|
||||
session=fake_db, # type: ignore[arg-type]
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.prepare_resume(str(session.id), _build_input(str(session.id)))
|
||||
|
||||
assert exc_info.value.status_code == 410
|
||||
assert fake_db.commit_called is True
|
||||
@@ -1,65 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from v1.agent.tool_dispatcher import (
|
||||
BackendExecutionResult,
|
||||
InterruptResult,
|
||||
ToolDispatcher,
|
||||
dispatch_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestToolDispatcher:
|
||||
def test_frontend_tool_returns_interrupt(self):
|
||||
tool = {
|
||||
"name": "ui.navigate_to",
|
||||
"execution_target": "frontend",
|
||||
"args": {"path": "/home"},
|
||||
}
|
||||
result = dispatch_tool_call(tool)
|
||||
assert isinstance(result, InterruptResult)
|
||||
assert result.interrupt_type == "tool_execution"
|
||||
assert result.tool_name == "ui.navigate_to"
|
||||
|
||||
def test_backend_tool_executes_directly(self):
|
||||
tool = {
|
||||
"name": "srv.get_user_info",
|
||||
"execution_target": "backend",
|
||||
"args": {"user_id": "u1"},
|
||||
"requires_approval": False,
|
||||
}
|
||||
result = dispatch_tool_call(tool)
|
||||
assert isinstance(result, BackendExecutionResult)
|
||||
assert result.tool_name == "srv.get_user_info"
|
||||
|
||||
def test_backend_tool_with_approval_returns_interrupt(self):
|
||||
tool = {
|
||||
"name": "srv.transfer_funds",
|
||||
"execution_target": "backend",
|
||||
"args": {"to": "u2", "amount": 100},
|
||||
"requires_approval": True,
|
||||
}
|
||||
result = dispatch_tool_call(tool)
|
||||
assert isinstance(result, InterruptResult)
|
||||
assert result.interrupt_type == "approval_required"
|
||||
assert result.tool_name == "srv.transfer_funds"
|
||||
|
||||
def test_dispatcher_class_can_dispatch(self):
|
||||
dispatcher = ToolDispatcher()
|
||||
tool = {
|
||||
"name": "ui.navigate_to",
|
||||
"execution_target": "frontend",
|
||||
"args": {"message": "Hello"},
|
||||
}
|
||||
result = dispatcher.dispatch(tool)
|
||||
assert isinstance(result, InterruptResult)
|
||||
|
||||
def test_unknown_frontend_tool_is_rejected(self):
|
||||
tool = {
|
||||
"name": "ui.unknown_action",
|
||||
"execution_target": "frontend",
|
||||
"args": {},
|
||||
}
|
||||
with pytest.raises(ValueError, match="not in allowlist"):
|
||||
dispatch_tool_call(tool)
|
||||
@@ -1,27 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from v1.agent.tool_registry import validate_tool_spec
|
||||
|
||||
|
||||
class TestValidateToolSpec:
|
||||
def test_ui_namespace_must_be_frontend(self):
|
||||
with pytest.raises(ValueError, match="ui.* must use frontend target"):
|
||||
validate_tool_spec(
|
||||
{"name": "ui.navigate_to", "execution_target": "backend"}
|
||||
)
|
||||
|
||||
def test_srv_namespace_must_be_backend(self):
|
||||
with pytest.raises(ValueError, match="srv.* must use backend target"):
|
||||
validate_tool_spec(
|
||||
{"name": "srv.search_docs", "execution_target": "frontend"}
|
||||
)
|
||||
|
||||
def test_ui_namespace_with_frontend_is_valid(self):
|
||||
validate_tool_spec({"name": "ui.navigate_to", "execution_target": "frontend"})
|
||||
|
||||
def test_srv_namespace_with_backend_is_valid(self):
|
||||
validate_tool_spec({"name": "srv.search_docs", "execution_target": "backend"})
|
||||
|
||||
def test_other_namespace_is_rejected(self):
|
||||
with pytest.raises(ValueError, match="must be in ui.* or srv.* namespace"):
|
||||
validate_tool_spec({"name": "other.tool", "execution_target": "frontend"})
|
||||
@@ -1,196 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agent.orchestrator import OrchestratorResult
|
||||
from core.db.base import Base
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from v1.agent.schemas import AgentChatRunRequest
|
||||
from v1.agent.service import AgentChatService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
users_table = Table(
|
||||
"users",
|
||||
Base.metadata,
|
||||
Column("id", String, primary_key=True),
|
||||
schema="auth",
|
||||
extend_existing=True,
|
||||
)
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
Base.metadata.remove(users_table)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
async_session = async_sessionmaker(
|
||||
bind=db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_creates_session_and_persists_messages(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
result = await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert result.session_id is not None
|
||||
assert result.output == "hello"
|
||||
assert [event.type for event in result.events] == [
|
||||
"run.started",
|
||||
"message.delta",
|
||||
"run.completed",
|
||||
]
|
||||
|
||||
session_obj = await db_session.get(AgentChatSession, result.session_id)
|
||||
assert session_obj is not None
|
||||
assert session_obj.message_count == 2
|
||||
assert session_obj.status.value == "completed"
|
||||
|
||||
rows = await db_session.execute(
|
||||
select(AgentChatMessage)
|
||||
.where(AgentChatMessage.session_id == result.session_id)
|
||||
.order_by(AgentChatMessage.seq.asc())
|
||||
)
|
||||
messages = rows.scalars().all()
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role.value == "user"
|
||||
assert messages[1].role.value == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
first = await service.run(AgentChatRunRequest(message="first"))
|
||||
second = await service.run(
|
||||
AgentChatRunRequest(message="second", session_id=first.session_id)
|
||||
)
|
||||
|
||||
assert second.session_id == first.session_id
|
||||
|
||||
session_obj = await db_session.get(AgentChatSession, first.session_id)
|
||||
assert session_obj is not None
|
||||
assert session_obj.message_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
class _FailingOrchestrator:
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
return OrchestratorResult(
|
||||
output="",
|
||||
usage={
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"cost": Decimal("0"),
|
||||
"currency": "USD",
|
||||
},
|
||||
events=[],
|
||||
context={},
|
||||
failed=True,
|
||||
error="stage failed",
|
||||
)
|
||||
|
||||
service._orchestrator = _FailingOrchestrator() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
|
||||
rows = await db_session.execute(
|
||||
select(AgentChatSession).where(AgentChatSession.user_id == user.id)
|
||||
)
|
||||
stored_session = rows.scalars().one()
|
||||
assert stored_session.status.value == "failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message=" "))
|
||||
|
||||
assert exc_info.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello", session_id=uuid4()))
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
|
||||
db_session: AsyncSession,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
async def _fail_commit() -> None:
|
||||
raise SQLAlchemyError("db down")
|
||||
|
||||
monkeypatch.setattr(db_session, "commit", _fail_commit)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_502_for_unexpected_exception(
|
||||
db_session: AsyncSession,
|
||||
) -> None:
|
||||
user = CurrentUser(id=uuid4())
|
||||
service = AgentChatService(session=db_session, current_user=user)
|
||||
|
||||
class _CrashingOrchestrator:
|
||||
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
|
||||
raise RuntimeError("unexpected")
|
||||
|
||||
service._orchestrator = _CrashingOrchestrator() # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await service.run(AgentChatRunRequest(message="hello"))
|
||||
|
||||
assert exc_info.value.status_code == 502
|
||||
Reference in New Issue
Block a user