feat(agent-chat): complete core workflow and strengthen auth rate limiting

This commit is contained in:
qzl
2026-02-25 16:51:12 +08:00
parent 53c72e48e6
commit cd40b2b4f4
62 changed files with 3441 additions and 3 deletions
+7 -1
View File
@@ -17,7 +17,13 @@ if str(src_path) not in sys.path:
from core.config.settings import config # noqa: E402
from core.db.base import Base # noqa: E402
from models import Profile # noqa: F401,E402
from models import ( # noqa: F401,E402
AgentChatMessage,
AgentChatSession,
Llm,
LlmFactory,
Profile,
)
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
@@ -0,0 +1,195 @@
"""create_agent_chat_core_tables
Revision ID: 20260226_agent_chat_core
Revises: 20260224_drop_profile
Create Date: 2026-02-26 10:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# Alembic revision identifier of this migration.
revision: str = "20260226_agent_chat_core"
# Revision this migration is applied on top of.
down_revision: Union[str, Sequence[str], None] = "20260224_drop_profile"
# No branch labels and no cross-revision dependencies are declared.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the LLM catalog and agent-chat tables together with their indexes.

    Tables are created parents-first (``llm_factory`` before ``llms``,
    ``sessions`` before ``messages``) so each foreign-key target already
    exists when it is referenced.
    """

    def uuid_column(name):
        # Non-nullable UUID column, used for primary and foreign keys alike.
        return sa.Column(name, postgresql.UUID(as_uuid=True), nullable=False)

    def now_column(name):
        # Timezone-aware timestamp populated by the database through now().
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    def zero_column(name, column_type):
        # Non-nullable numeric column whose server-side default is 0.
        return sa.Column(name, column_type, nullable=False, server_default="0")

    def audit_columns():
        # Every table ends with the same created/updated/soft-delete trio.
        # Fresh Column objects are built on each call so no Column instance
        # is shared between two tables.
        return (
            now_column("created_at"),
            now_column("updated_at"),
            sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
        )

    def create_indexes(table_name, index_specs):
        for index_name, column_names in index_specs:
            op.create_index(index_name, table_name, column_names)

    op.create_table(
        "llm_factory",
        uuid_column("id"),
        sa.Column("name", sa.String(length=50), nullable=False),
        sa.Column("request_url", sa.String(length=255), nullable=False),
        sa.Column("avatar", sa.Text(), nullable=True),
        *audit_columns(),
        sa.PrimaryKeyConstraint("id", name="pk_llm_factory"),
        sa.UniqueConstraint("name", name="uq_llm_factory_name"),
    )
    create_indexes(
        "llm_factory",
        (
            ("ix_llm_factory_name", ["name"]),
            ("ix_llm_factory_deleted_at", ["deleted_at"]),
        ),
    )

    op.create_table(
        "llms",
        uuid_column("id"),
        uuid_column("factory_id"),
        sa.Column("model_code", sa.String(length=50), nullable=False),
        *audit_columns(),
        sa.ForeignKeyConstraint(
            ["factory_id"], ["llm_factory.id"], ondelete="RESTRICT"
        ),
        sa.PrimaryKeyConstraint("id", name="pk_llms"),
        sa.UniqueConstraint("model_code", name="uq_llms_model_code"),
    )
    create_indexes(
        "llms",
        (
            ("ix_llms_factory_id", ["factory_id"]),
            ("ix_llms_model_code", ["model_code"]),
            ("ix_llms_deleted_at", ["deleted_at"]),
        ),
    )

    op.create_table(
        "sessions",
        uuid_column("id"),
        uuid_column("user_id"),
        sa.Column("title", sa.String(length=255), nullable=True),
        sa.Column(
            "status",
            sa.Enum(
                "pending",
                "running",
                "completed",
                "failed",
                name="agent_chat_session_status",
                native_enum=False,
            ),
            nullable=False,
            server_default="pending",
        ),
        now_column("last_activity_at"),
        zero_column("message_count", sa.Integer()),
        zero_column("total_tokens", sa.Integer()),
        zero_column("total_cost", sa.Numeric(12, 6)),
        *audit_columns(),
        sa.ForeignKeyConstraint(["user_id"], ["auth.users.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id", name="pk_sessions"),
    )
    create_indexes(
        "sessions",
        (
            ("ix_sessions_user_id", ["user_id"]),
            ("ix_sessions_user_last_activity", ["user_id", "last_activity_at"]),
            ("ix_sessions_deleted_at", ["deleted_at"]),
        ),
    )

    op.create_table(
        "messages",
        uuid_column("id"),
        uuid_column("session_id"),
        sa.Column("seq", sa.Integer(), nullable=False),
        sa.Column(
            "role",
            sa.Enum(
                "user",
                "assistant",
                "system",
                "tool",
                name="agent_chat_message_role",
                native_enum=False,
            ),
            nullable=False,
        ),
        sa.Column("content", sa.Text(), nullable=False),
        sa.Column("model_code", sa.String(length=50), nullable=True),
        sa.Column("tool_name", sa.String(length=100), nullable=True),
        zero_column("input_tokens", sa.Integer()),
        zero_column("output_tokens", sa.Integer()),
        zero_column("cost", sa.Numeric(12, 6)),
        sa.Column(
            "currency", sa.String(length=3), nullable=False, server_default="USD"
        ),
        sa.Column("latency_ms", sa.Integer(), nullable=True),
        sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        *audit_columns(),
        sa.ForeignKeyConstraint(["session_id"], ["sessions.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id", name="pk_messages"),
        sa.UniqueConstraint("session_id", "seq", name="uq_messages_session_seq"),
    )
    create_indexes(
        "messages",
        (
            ("ix_messages_session_id", ["session_id"]),
            ("ix_messages_session_role", ["session_id", "role"]),
            ("ix_messages_deleted_at", ["deleted_at"]),
        ),
    )
def downgrade() -> None:
    """Drop the four tables created by ``upgrade`` along with their indexes.

    Children go first (``messages`` before ``sessions``, ``llms`` before
    ``llm_factory``); for each table its indexes are dropped before the table.
    """
    teardown_plan = (
        (
            "messages",
            (
                "ix_messages_deleted_at",
                "ix_messages_session_role",
                "ix_messages_session_id",
            ),
        ),
        (
            "sessions",
            (
                "ix_sessions_deleted_at",
                "ix_sessions_user_last_activity",
                "ix_sessions_user_id",
            ),
        ),
        (
            "llms",
            (
                "ix_llms_deleted_at",
                "ix_llms_model_code",
                "ix_llms_factory_id",
            ),
        ),
        (
            "llm_factory",
            (
                "ix_llm_factory_deleted_at",
                "ix_llm_factory_name",
            ),
        ),
    )
    for table_name, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(index_name, table_name=table_name)
        op.drop_table(table_name)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,20 @@
from __future__ import annotations
from typing import Any
from core.agent_chat.event_bridge import map_internal_event
class AguiAdapter:
    """Convert inbound payloads to commands and internal events to protocol events."""

    def to_command(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Build a command dict from *payload*.

        Returns:
            A dict with the original (unstripped) ``message`` and the
            payload's ``session_id`` (``None`` when absent).

        Raises:
            ValueError: If ``message`` is missing, not a string, or blank.
        """
        text = payload.get("message")
        has_text = isinstance(text, str) and bool(text.strip())
        if not has_text:
            raise ValueError("message is required")
        command: dict[str, Any] = {"message": text}
        command["session_id"] = payload.get("session_id")
        return command

    def to_protocol_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Delegate to ``map_internal_event`` to translate an internal event."""
        protocol_event = map_internal_event(event)
        return protocol_event
@@ -0,0 +1,69 @@
from __future__ import annotations
from decimal import Decimal
from typing import Any, Mapping
def _to_non_negative_int(value: Any, *, field: str) -> int:
    """Coerce *value* to a non-negative ``int``.

    Accepts real integers and strings made up solely of decimal digits.
    Booleans are rejected explicitly because ``bool`` is a subclass of ``int``.

    Raises:
        ValueError: If *value* is not an integer-like input or is negative.
    """
    if isinstance(value, bool):
        raise ValueError(f"{field} must be an integer")
    if isinstance(value, int):
        converted = value
    elif isinstance(value, str) and value.isdecimal():
        # isdecimal() matches only characters int() can parse; isdigit() also
        # accepts e.g. superscript digits, which int() rejects with a raw
        # "invalid literal" error instead of the field-specific message.
        converted = int(value)
    else:
        raise ValueError(f"{field} must be an integer")
    if converted < 0:
        raise ValueError(f"{field} cannot be negative")
    return converted


def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
    """Coerce *value* to a finite, non-negative ``Decimal``.

    The value is converted through ``str`` so floats keep their short repr.

    Raises:
        ValueError: If *value* cannot be parsed, is NaN/Infinity, or is negative.
    """
    from decimal import InvalidOperation

    try:
        converted = Decimal(str(value))
    except InvalidOperation as exc:
        # Surface unparsable input as ValueError, like the integer helper does.
        raise ValueError(f"{field} must be a number") from exc
    if not converted.is_finite():
        # NaN cannot be ordered and Infinity would poison the running total.
        raise ValueError(f"{field} must be a finite number")
    if converted < 0:
        raise ValueError(f"{field} cannot be negative")
    return converted


class CostTracker:
    """Accumulate token usage and cost over a sequence of usage records."""

    def __init__(self, *, currency: str = "USD") -> None:
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")
        self._currency = currency

    def add_usage(self, usage: Mapping[str, Any]) -> None:
        """Add one usage record to the running totals.

        ``input_tokens``/``output_tokens`` fall back to ``prompt_tokens``/
        ``completion_tokens``; ``total_tokens`` defaults to their sum and
        ``cost`` to zero. Every field, including ``currency``, is validated
        before any total is touched, so a rejected record leaves the tracker
        unchanged.

        Raises:
            ValueError: If a field is invalid or the record's currency differs
                from the tracker's currency.
        """
        input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
        output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
        total_tokens = usage.get("total_tokens")
        cost = usage.get("cost", "0")
        currency = usage.get("currency")

        normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
        normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
        normalized_total = (
            _to_non_negative_int(total_tokens, field="total_tokens")
            if total_tokens is not None
            else normalized_input + normalized_output
        )
        normalized_cost = _to_non_negative_decimal(cost, field="cost")
        if currency is not None and str(currency) != self._currency:
            raise ValueError("currency mismatch")

        # All checks passed; only now mutate the accumulated state.
        self._input_tokens += normalized_input
        self._output_tokens += normalized_output
        self._total_tokens += normalized_total
        self._cost += normalized_cost

    def snapshot(self) -> dict[str, Any]:
        """Return the current totals as a new dict (``cost`` is a ``Decimal``)."""
        return {
            "input_tokens": self._input_tokens,
            "output_tokens": self._output_tokens,
            "total_tokens": self._total_tokens,
            "cost": self._cost,
            "currency": self._currency,
        }
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,82 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class CrewAITemplate:
    """Immutable bundle of the CrewAI configuration pieces."""

    # Mapping parsed from the agents YAML file.
    agents: dict[str, Any]
    # Mapping parsed from the tasks YAML file.
    tasks: dict[str, Any]
    # Mapping parsed from the workflow YAML file (includes the "stages" list).
    workflow: dict[str, Any]
    # Prompt text keyed by stage name.
    prompts: dict[str, str]
    # Names of the tools that are allowed to be used.
    tools_whitelist: set[str]
def _default_static_root() -> Path:
return Path(__file__).resolve().parents[3] / "config" / "static" / "agent_chat"
def _read_yaml(file_path: Path) -> dict[str, Any]:
if not file_path.is_file():
raise FileNotFoundError(f"Required config file not found: {file_path}")
with file_path.open("r", encoding="utf-8") as file:
loaded = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"YAML file must be a mapping: {file_path}")
return loaded
def _read_prompt(file_path: Path) -> str:
if not file_path.is_file():
raise FileNotFoundError(f"Required prompt file not found: {file_path}")
return file_path.read_text(encoding="utf-8").strip()
def validate_workflow_stages(stages: list[str]) -> None:
    """Ensure *stages* is exactly ``intent``, ``execution``, ``organization`` in order.

    Raises:
        ValueError: If *stages* differs from the expected list.
    """
    expected = ["intent", "execution", "organization"]
    if stages == expected:
        return
    raise ValueError(f"Invalid workflow stages: {stages}, expected: {expected}")
def load_tools_whitelist(static_root: Path | None = None) -> set[str]:
    """Read ``tools.yaml`` under *static_root* and return the set of tool names.

    Falls back to the default static root when *static_root* is not given.
    A missing ``tools`` key yields an empty set; names are stripped.

    Raises:
        ValueError: If ``tools`` is not a list or contains an entry that is
            not a non-blank string.
    """
    config_root = static_root or _default_static_root()
    entries = _read_yaml(config_root / "tools.yaml").get("tools", [])
    if not isinstance(entries, list):
        raise ValueError("tools.yaml field 'tools' must be a list")

    names: set[str] = set()
    for entry in entries:
        if not isinstance(entry, str) or not entry.strip():
            raise ValueError("tools.yaml list items must be non-empty strings")
        names.add(entry.strip())
    return names
def load_crewai_template(static_root: Path | None = None) -> CrewAITemplate:
    """Load agents, tasks, workflow, prompts and the tools whitelist into a template.

    YAML files and prompts are read from the ``crewai`` directory under
    *static_root* (default static root when omitted).

    Raises:
        ValueError: If ``workflow.yaml`` has no list under ``stages`` or the
            stages are not the expected sequence.
    """
    root = static_root or _default_static_root()
    crewai_root = root / "crewai"

    # The generator is consumed left to right, so files are read in this order.
    agents, tasks, workflow = (
        _read_yaml(crewai_root / file_name)
        for file_name in ("agents.yaml", "tasks.yaml", "workflow.yaml")
    )

    stages = workflow.get("stages")
    if not isinstance(stages, list):
        raise ValueError("workflow.yaml field 'stages' must be a list")
    validate_workflow_stages([str(stage) for stage in stages])

    prompt_dir = crewai_root / "prompts"
    prompt_files = (
        ("intent", "intent.md"),
        ("execution", "execution.md"),
        ("organization", "organization.md"),
    )
    prompts = {
        stage: _read_prompt(prompt_dir / file_name)
        for stage, file_name in prompt_files
    }

    return CrewAITemplate(
        agents=agents,
        tasks=tasks,
        workflow=workflow,
        prompts=prompts,
        tools_whitelist=load_tools_whitelist(root),
    )
@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Any
def _require_fields(event: dict[str, Any], *, kind: str, required: list[str]) -> None:
    """Raise ``ValueError`` naming every key in *required* that *event* lacks."""
    absent = []
    for name in required:
        if name not in event:
            absent.append(name)
    if absent:
        raise ValueError(f"Missing fields for {kind}: {absent}")


def map_internal_event(event: dict[str, Any]) -> dict[str, Any]:
    """Translate an internal event dict into its protocol event dict.

    The event's ``kind`` selects a route. Required keys are copied (possibly
    under a new name) and optional keys fall back to a default.

    Raises:
        ValueError: If a required key is missing or ``kind`` is not supported.
    """
    # Each route: (kind, protocol type, ((output key, source key), ...),
    #              ((optional key, default), ...)).
    routes = (
        ("run_started", "run.started", (("run_id", "session_id"),), ()),
        (
            "message_delta",
            "message.delta",
            (("message_id", "message_id"), ("delta", "delta")),
            (),
        ),
        (
            "tool_started",
            "tool.started",
            (("message_id", "message_id"), ("tool_name", "tool_name")),
            (),
        ),
        (
            "tool_completed",
            "tool.completed",
            (("message_id", "message_id"), ("tool_name", "tool_name")),
            (("result", None),),
        ),
        ("run_completed", "run.completed", (("run_id", "session_id"),), (("output", ""),)),
        ("run_failed", "run.failed", (("run_id", "session_id"),), (("error", ""),)),
    )

    kind = event.get("kind")
    for route_kind, protocol_type, required, optional in routes:
        # Compare with == rather than a dict lookup so unhashable kinds still
        # end up in the "unsupported" error below.
        if kind == route_kind:
            _require_fields(
                event, kind=kind, required=[source for _, source in required]
            )
            mapped: dict[str, Any] = {"type": protocol_type}
            for target, source in required:
                mapped[target] = event[source]
            for key, default in optional:
                mapped[key] = event.get(key, default)
            return mapped

    raise ValueError(f"Unsupported event kind: {kind}")
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Any
def run_started(*, run_id: str) -> dict[str, Any]:
    """Build the ``run.started`` event for *run_id*."""
    event: dict[str, Any] = {"type": "run.started"}
    event["run_id"] = run_id
    return event
def stage_completed(
    *, run_id: str, stage: str, usage: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Build a ``stage.completed`` event; ``usage`` is attached only when not ``None``."""
    base: dict[str, Any] = {
        "type": "stage.completed",
        "run_id": run_id,
        "stage": stage,
    }
    return base if usage is None else {**base, "usage": usage}
def run_completed(*, run_id: str, output: str, usage: dict[str, Any]) -> dict[str, Any]:
    """Build the ``run.completed`` event carrying the final output and usage."""
    event: dict[str, Any] = {"type": "run.completed", "run_id": run_id}
    event["output"] = output
    event["usage"] = usage
    return event
def run_failed(*, run_id: str, error: str) -> dict[str, Any]:
    """Build the ``run.failed`` event carrying the error text."""
    event: dict[str, Any] = {"type": "run.failed", "run_id": run_id}
    event["error"] = error
    return event
+112
View File
@@ -0,0 +1,112 @@
from __future__ import annotations
from dataclasses import dataclass
import hashlib
from pathlib import Path
from typing import Protocol
from core.agent_chat.storage_adapter import StorageAdapter
# MIME types accepted for attachments; validation rejects anything not listed.
_ALLOWED_MIME_TYPES = {
    "audio/mpeg",
    "audio/wav",
    "audio/x-wav",
    "image/jpeg",
    "image/png",
    "image/webp",
    "application/pdf",
    "text/plain",
}
class _AsrTool(Protocol):
    """Structural interface for a speech-to-text tool exposing ``transcribe``."""

    # Keyword-only call that returns a result mapping for the given audio.
    def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, object]: ...
@dataclass(frozen=True)
class AttachmentInput:
    """Immutable description of one attachment to be processed."""

    # File name as supplied with the attachment.
    filename: str
    # Declared MIME type of the content.
    mime_type: str
    # Raw attachment bytes.
    content: bytes
    # Where the attachment came from; defaults to "user_upload".
    origin: str = "user_upload"
@dataclass(frozen=True)
class ProcessedAttachmentContext:
    """Immutable result of processing a batch of attachments."""

    # One metadata dict per processed attachment.
    attachments: list[dict[str, object]]
    # Preview texts collected from the attachments that produced one.
    preview_texts: list[str]
class MultimodalProcessor:
    """Validate attachments and build their storage metadata and preview text."""

    def __init__(
        self,
        *,
        storage: StorageAdapter,
        asr_tool: _AsrTool,
        max_file_size_mb: int = 20,
    ) -> None:
        """Store the collaborators and convert the size limit from MiB to bytes."""
        self._storage = storage
        self._asr_tool = asr_tool
        self._max_size_bytes = max_file_size_mb * 1024 * 1024

    def process(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        attachments: list[AttachmentInput],
    ) -> ProcessedAttachmentContext:
        """Validate each attachment and collect its metadata and preview text.

        Raises:
            ValueError: If an attachment has an unsupported MIME type or
                exceeds the size limit.
        """
        records: list[dict[str, object]] = []
        previews: list[str] = []

        for item in attachments:
            self._validate_attachment(item)

            digest = hashlib.sha256(item.content).hexdigest()
            suffix = Path(item.filename).suffix.strip(".").lower()
            location = self._storage.build_object_path(
                user_id=user_id,
                session_id=session_id,
                message_seq=message_seq,
                checksum_sha256=digest,
                # Files without a usable suffix are stored as ".bin".
                extension=suffix or "bin",
            )

            preview = self._build_preview_text(item)
            if preview:
                previews.append(preview)

            records.append(
                self._storage.build_attachment_metadata(
                    object_path=location,
                    mime_type=item.mime_type,
                    size=len(item.content),
                    checksum_sha256=digest,
                    origin=item.origin,
                    preview_text=preview,
                )
            )

        return ProcessedAttachmentContext(attachments=records, preview_texts=previews)

    def _validate_attachment(self, attachment: AttachmentInput) -> None:
        """Reject attachments with a disallowed MIME type or excessive size."""
        allowed = attachment.mime_type in _ALLOWED_MIME_TYPES
        if not allowed:
            raise ValueError("Unsupported MIME type")
        too_large = len(attachment.content) > self._max_size_bytes
        if too_large:
            raise ValueError("Attachment exceeds max file size")

    def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
        """Return preview text: a transcript for audio, the first 200 chars for text/plain."""
        mime = attachment.mime_type
        if mime.startswith("audio/"):
            transcript = self._asr_tool.transcribe(
                audio_bytes=attachment.content,
                filename=attachment.filename,
            ).get("text")
            return transcript if isinstance(transcript, str) else None
        if mime == "text/plain":
            # Undecodable bytes are dropped rather than raising.
            decoded = attachment.content.decode("utf-8", errors="ignore")
            return decoded[:200]
        return None
@@ -0,0 +1,88 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
from core.agent_chat.cost_tracker import CostTracker
from core.agent_chat import events
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
@dataclass(frozen=True)
class OrchestratorResult:
    """Immutable outcome of one orchestrator run."""

    # Final output text (empty string when the run failed).
    output: str
    # Usage totals accumulated during the run.
    usage: dict[str, Any]
    # Events emitted during the run, in order.
    events: list[dict[str, Any]]
    # Context dict that was passed to every stage.
    context: dict[str, Any]
    # True when a stage raised an exception.
    failed: bool
    # Error text for a failed run, otherwise None.
    error: str | None
class AgentChatOrchestrator:
    """Run the intent, execution and organization stages in sequence."""

    def __init__(
        self,
        *,
        intent_stage: StageCallable,
        execution_stage: StageCallable,
        organization_stage: StageCallable,
    ) -> None:
        """Store the three awaitable stage callables."""
        self._intent_stage = intent_stage
        self._execution_stage = execution_stage
        self._organization_stage = organization_stage

    def run_sync(self, *, run_id: str, user_message: str) -> OrchestratorResult:
        """Run the pipeline to completion via ``asyncio.run``."""
        pipeline_coro = self.run(run_id=run_id, user_message=user_message)
        return asyncio.run(pipeline_coro)

    async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
        """Execute the stages, feeding each stage's content to the next.

        Every stage is awaited with ``message`` and the shared ``context``.
        Dict-typed usage is accumulated and a ``stage.completed`` event with
        the cumulative usage is recorded per stage. Any exception ends the
        run with a ``run.failed`` event and a failed result rather than
        propagating.
        """
        usage_tracker = CostTracker()
        timeline: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
        shared_context: dict[str, Any] = {}
        current_text = user_message
        pipeline: tuple[tuple[str, StageCallable], ...] = (
            ("intent", self._intent_stage),
            ("execution", self._execution_stage),
            ("organization", self._organization_stage),
        )

        try:
            for name, stage in pipeline:
                outcome = await stage(message=current_text, context=shared_context)
                # A stage without "content" passes its input through unchanged.
                current_text = str(outcome.get("content", current_text))
                stage_usage = outcome.get("usage", {})
                if isinstance(stage_usage, dict):
                    usage_tracker.add_usage(stage_usage)
                timeline.append(
                    events.stage_completed(
                        run_id=run_id,
                        stage=name,
                        usage=usage_tracker.snapshot(),
                    )
                )
        except Exception as exc:  # noqa: BLE001
            reason = str(exc)
            timeline.append(events.run_failed(run_id=run_id, error=reason))
            return OrchestratorResult(
                output="",
                usage=usage_tracker.snapshot(),
                events=timeline,
                context=shared_context,
                failed=True,
                error=reason,
            )

        totals = usage_tracker.snapshot()
        timeline.append(
            events.run_completed(run_id=run_id, output=current_text, usage=totals)
        )
        return OrchestratorResult(
            output=current_text,
            usage=totals,
            events=timeline,
            context=shared_context,
            failed=False,
            error=None,
        )
@@ -0,0 +1,46 @@
from __future__ import annotations
class StorageAdapter:
    """Build object paths and metadata records for attachments in one bucket."""

    _bucket: str

    def __init__(self, bucket: str) -> None:
        self._bucket = bucket

    @property
    def bucket(self) -> str:
        """Name of the bucket this adapter was created with."""
        return self._bucket

    def build_object_path(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        checksum_sha256: str,
        extension: str,
    ) -> str:
        """Return ``agent-chat/<user>/<session>/<seq>/<checksum>[.<ext>]``.

        The extension is lower-cased and stripped of leading/trailing dots.
        When nothing remains (``""`` or ``"."``) the file name is the bare
        checksum, not a key that ends in a dangling ``.``.
        """
        normalized_ext = extension.strip(".").lower()
        file_name = (
            f"{checksum_sha256}.{normalized_ext}" if normalized_ext else checksum_sha256
        )
        return f"agent-chat/{user_id}/{session_id}/{message_seq}/{file_name}"

    def build_attachment_metadata(
        self,
        *,
        object_path: str,
        mime_type: str,
        size: int,
        checksum_sha256: str,
        origin: str,
        preview_text: str | None = None,
    ) -> dict[str, object]:
        """Return the metadata dict describing one stored attachment."""
        return {
            "object_path": object_path,
            "mime_type": mime_type,
            "size": size,
            "checksum_sha256": checksum_sha256,
            "origin": origin,
            "preview_text": preview_text,
        }
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,40 @@
from __future__ import annotations
import importlib
from collections.abc import Callable
from typing import Any
TranscribeCallable = Callable[..., dict[str, Any]]
class FunASRTool:
    """Speech-to-text tool that delegates to a pluggable transcription callable."""

    _transcribe_callable: TranscribeCallable
    _model: str

    def __init__(
        self,
        transcribe_callable: TranscribeCallable | None = None,
        model: str = "fun-asr-realtime-2025-11-07",
    ) -> None:
        """Use *transcribe_callable* when provided, else the DashScope fallback."""
        if not transcribe_callable:
            transcribe_callable = self._dashscope_transcribe
        self._transcribe_callable = transcribe_callable
        self._model = model

    def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, Any]:
        """Return the callable's payload with ``model`` set to the default model.

        A ``model`` key in the payload overrides the configured model name.
        """
        payload = self._transcribe_callable(audio_bytes=audio_bytes, filename=filename)
        result: dict[str, Any] = {"model": self._model}
        result.update(payload)
        return result

    def _dashscope_transcribe(
        self, *, audio_bytes: bytes, filename: str
    ) -> dict[str, Any]:
        """Placeholder backend: always raises ``RuntimeError``.

        The message distinguishes a missing ``dashscope`` package from the
        integration not being configured yet.
        """
        try:
            importlib.import_module("dashscope")
        except ImportError as exc:
            raise RuntimeError("DashScope SDK is not installed") from exc
        message = "DashScope transcribe runtime integration is not configured yet"
        raise RuntimeError(message)
+9
View File
@@ -132,6 +132,14 @@ class SupabaseSettings(BaseModel):
return self.public_url
class StorageSettings(BaseModel):
    """Object-storage settings (the default bucket name suggests chat attachments)."""

    # Storage backend; "supabase" is the only accepted value.
    provider: Literal["supabase"] = "supabase"
    # Bucket name, constrained to 3-63 characters.
    bucket: str = Field(default="agent-chat-attachments", min_length=3, max_length=63)
    # Signed-URL lifetime in seconds, constrained to 60-3600.
    signed_url_ttl_seconds: int = Field(default=600, ge=60, le=3600)
    # Maximum file size in megabytes, constrained to 1-200.
    max_file_size_mb: int = Field(default=20, ge=1, le=200)
    # Retention period in days, constrained to 1-3650.
    retention_days: int = Field(default=30, ge=1, le=3650)
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
@@ -163,6 +171,7 @@ class Settings(BaseSettings):
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
supabase: SupabaseSettings = SupabaseSettings()
storage: StorageSettings = StorageSettings()
celery: CelerySettings = CelerySettings()
database: DatabaseSettings = DatabaseSettings()
@@ -0,0 +1,9 @@
intent:
role: Intent Agent
goal: Classify user intent and decide execution strategy
execution:
role: Execution Agent
goal: Execute tasks with available tools
organization:
role: Organization Agent
goal: Organize output for user-friendly response
@@ -0,0 +1,2 @@
你是任务执行代理。
基于输入意图与上下文调用可用工具,并生成可验证执行结果。
@@ -0,0 +1,2 @@
你是意图识别代理。
你的任务是识别用户输入的意图类型,并返回结构化意图标签。
@@ -0,0 +1,2 @@
你是结果整理代理。
将执行结果组织为面向用户的清晰回复,保留关键信息与必要引用。
@@ -0,0 +1,6 @@
intent:
description: Identify user intent and required capabilities
execution:
description: Execute intent with tools and model calls
organization:
description: Format final response and references
@@ -0,0 +1,9 @@
stages:
- intent
- execution
- organization
timeouts:
intent_seconds: 8
execution_seconds: 30
organization_seconds: 10
@@ -0,0 +1,25 @@
factories:
- name: qwen
request_url: https://dashscope.aliyuncs.com/compatible-mode/v1
avatar: https://cdn.simpleicons.org/alibabacloud/FF6A00
- name: minimax
request_url: https://api.minimax.chat/v1
avatar: https://cdn.simpleicons.org/minimax/1A1A1A
- name: kimi
request_url: https://api.moonshot.cn/v1
avatar: https://cdn.simpleicons.org/moonrepo/3B82F6
- name: deepseek
request_url: https://api.deepseek.com/v1
avatar: https://cdn.simpleicons.org/deepseek/4D6BFE
- name: doubao
request_url: https://ark.cn-beijing.volces.com/api/v3
avatar: https://cdn.simpleicons.org/volkswagen/001E50
- name: zai
request_url: https://api.z.ai/v1
avatar: https://cdn.simpleicons.org/zotero/CC2936
llms:
- model_code: qwen3.5-flash
factory_id: qwen
- model_code: deepseek-v3.2
factory_id: deepseek
@@ -0,0 +1,3 @@
tools:
- asr_fun_asr
- attachment_extract
+126 -1
View File
@@ -1,9 +1,134 @@
from __future__ import annotations
import uuid
from pathlib import Path
from typing import Any
import yaml
from pydantic import BaseModel, ValidationError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from models.llm import Llm
from models.llm_factory import LlmFactory
logger = get_logger("core.initialization.init_data")
class LlmFactorySeed(BaseModel):
    """Validated shape of one entry in the catalog's ``factories`` list."""

    # Factory name.
    name: str
    # URL stored as the factory's request URL.
    request_url: str
    # Optional avatar value; None when omitted.
    avatar: str | None = None
class LlmSeed(BaseModel):
    """Validated shape of one entry in the catalog's ``llms`` list."""

    # Model code of the LLM.
    model_code: str
    # NOTE(review): despite the name, the seeding code looks this value up by
    # factory *name*, not by a database id.
    factory_id: str
class LlmCatalogSeed(BaseModel):
    """Validated shape of the whole LLM catalog: factories plus models."""

    factories: list[LlmFactorySeed]
    llms: list[LlmSeed]
def _default_catalog_path() -> Path:
return (
Path(__file__).resolve().parents[1]
/ "config"
/ "static"
/ "agent_chat"
/ "llm_catalog.yaml"
)
def load_llm_catalog(catalog_path: Path | None = None) -> dict[str, Any]:
    """Load and validate the LLM catalog YAML file.

    Args:
        catalog_path: File to read; defaults to the bundled catalog path.

    Returns:
        A dict with ``factories`` and ``llms`` lists of plain dicts.

    Raises:
        ValueError: If the document is not a mapping, a section is not a list,
            or an entry fails validation.
    """
    path = catalog_path or _default_catalog_path()
    with path.open("r", encoding="utf-8") as file:
        loaded = yaml.safe_load(file)

    # Only an empty document (None) counts as an empty catalog. Falsy
    # non-mapping documents such as [] or 0 must fail the format check
    # instead of being coerced to {}.
    if loaded is None:
        loaded = {}
    if not isinstance(loaded, dict):
        raise ValueError(f"Invalid LLM catalog format: {path}")

    raw_factories = loaded.get("factories", [])
    raw_llms = loaded.get("llms", [])
    if not isinstance(raw_factories, list) or not isinstance(raw_llms, list):
        raise ValueError(f"Invalid LLM catalog sections: {path}")

    try:
        parsed = LlmCatalogSeed.model_validate(
            {"factories": raw_factories, "llms": raw_llms}
        )
    except ValidationError as exc:
        raise ValueError(f"Invalid LLM catalog data: {path}") from exc

    return parsed.model_dump()
async def _upsert_factory(
    session: AsyncSession,
    *,
    name: str,
    request_url: str,
    avatar: str | None,
) -> uuid.UUID:
    """Create or update the factory row named *name* and return its id.

    An existing row has ``request_url`` and ``avatar`` overwritten. A new row
    is added and flushed before its id is returned.
    """
    lookup = select(LlmFactory).where(LlmFactory.name == name)
    existing = (await session.execute(lookup)).scalar_one_or_none()
    if existing is not None:
        existing.request_url = request_url
        existing.avatar = avatar
        return existing.id

    created = LlmFactory(name=name, request_url=request_url, avatar=avatar)
    session.add(created)
    await session.flush()
    return created.id
async def _upsert_llm(
    session: AsyncSession,
    *,
    model_code: str,
    factory_id: uuid.UUID,
) -> None:
    """Create the LLM row for *model_code*, or repoint an existing row to *factory_id*."""
    lookup = select(Llm).where(Llm.model_code == model_code)
    existing = (await session.execute(lookup)).scalar_one_or_none()
    if existing is not None:
        existing.factory_id = factory_id
    else:
        session.add(Llm(model_code=model_code, factory_id=factory_id))
async def initialize_data() -> bool:
    """Seed LLM factories and models from the catalog inside one transaction.

    Factories are upserted first so each model's factory reference (a factory
    name in the catalog) can be resolved to the factory's id.

    Returns:
        ``True`` once the seed data has been written.

    Raises:
        RuntimeError: If a model references a factory name that is not in the
            catalog; the exception leaves the ``session.begin()`` block, so
            nothing is committed.
    """
    # The old message claimed "(no-op)" even though this function writes data.
    logger.info("Initializing data")
    catalog = load_llm_catalog()

    async with AsyncSessionLocal() as session:
        async with session.begin():
            factory_id_by_name: dict[str, uuid.UUID] = {}
            for factory in catalog["factories"]:
                factory_id = await _upsert_factory(
                    session,
                    name=factory["name"],
                    request_url=factory["request_url"],
                    avatar=factory.get("avatar"),
                )
                factory_id_by_name[factory["name"]] = factory_id

            for llm in catalog["llms"]:
                # The catalog's "factory_id" field holds the factory *name*.
                factory_name = llm["factory_id"]
                resolved_factory_id = factory_id_by_name.get(factory_name)
                if resolved_factory_id is None:
                    raise RuntimeError(
                        f"Factory '{factory_name}' not found for model '{llm['model_code']}'"
                    )
                await _upsert_llm(
                    session,
                    model_code=llm["model_code"],
                    factory_id=resolved_factory_id,
                )

    logger.info("Initialized LLM factory/model seed data")
    return True
+11 -1
View File
@@ -1,5 +1,15 @@
from __future__ import annotations
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.profile import Profile
__all__ = ["Profile"]
__all__ = [
"AgentChatMessage",
"AgentChatSession",
"Llm",
"LlmFactory",
"Profile",
]
+62
View File
@@ -0,0 +1,62 @@
from __future__ import annotations
from decimal import Decimal
import uuid
from enum import Enum
from sqlalchemy import (
JSON,
Enum as SqlEnum,
ForeignKey,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class AgentChatMessageRole(str, Enum):
    """Allowed values for a chat message's ``role`` column."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"
class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
    """ORM model for one chat message, stored in the ``messages`` table."""

    __tablename__: str = "messages"
    # A sequence number may be used only once within a session.
    __table_args__: tuple[UniqueConstraint] = (
        UniqueConstraint("session_id", "seq", name="uq_messages_session_seq"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Owning session; the foreign key is declared with ON DELETE CASCADE.
    session_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sessions.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Sequence number of the message within its session.
    seq: Mapped[int] = mapped_column(Integer, nullable=False)
    # Stored as a non-native enum (native_enum=False).
    role: Mapped[AgentChatMessageRole] = mapped_column(
        SqlEnum(
            AgentChatMessageRole, name="agent_chat_message_role", native_enum=False
        ),
        nullable=False,
    )
    content: Mapped[str] = mapped_column(Text, nullable=False)
    model_code: Mapped[str | None] = mapped_column(String(50), nullable=True)
    tool_name: Mapped[str | None] = mapped_column(String(100), nullable=True)
    # Usage and cost figures; each has a Python-side default applied on insert.
    input_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    cost: Mapped[Decimal] = mapped_column(Numeric(12, 6), nullable=False, default=0)
    currency: Mapped[str] = mapped_column(String(3), nullable=False, default="USD")
    latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # The database column is named "metadata"; the attribute is metadata_json,
    # presumably because "metadata" is reserved on declarative classes.
    # JSONB is used on PostgreSQL, generic JSON on other dialects.
    metadata_json: Mapped[dict[str, object] | None] = mapped_column(
        "metadata", JSON().with_variant(JSONB, "postgresql"), nullable=True
    )
+64
View File
@@ -0,0 +1,64 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
import uuid
from enum import Enum
from sqlalchemy import (
DateTime,
Enum as SqlEnum,
ForeignKey,
Integer,
Numeric,
String,
func,
text,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class AgentChatSessionStatus(str, Enum):
    """Allowed values for a chat session's ``status`` column."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
    """ORM model for one chat session, stored in the ``sessions`` table."""

    __tablename__: str = "sessions"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Owning user; the foreign key to auth.users is declared with ON DELETE CASCADE.
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("auth.users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    title: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # Stored as a non-native enum; defaults to PENDING on the Python side.
    status: Mapped[AgentChatSessionStatus] = mapped_column(
        SqlEnum(
            AgentChatSessionStatus, name="agent_chat_session_status", native_enum=False
        ),
        nullable=False,
        default=AgentChatSessionStatus.PENDING,
    )
    # Defaults to the database's current time when the row is inserted.
    last_activity_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Running aggregates; each has a server-side default of 0.
    message_count: Mapped[int] = mapped_column(
        Integer, nullable=False, server_default=text("0")
    )
    total_tokens: Mapped[int] = mapped_column(
        Integer, nullable=False, server_default=text("0")
    )
    total_cost: Mapped[Decimal] = mapped_column(
        Numeric(12, 6), nullable=False, server_default=text("0")
    )
+26
View File
@@ -0,0 +1,26 @@
from __future__ import annotations
import uuid
from sqlalchemy import ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class Llm(TimestampMixin, SoftDeleteMixin, Base):
    """ORM model for one LLM, stored in the ``llms`` table."""

    __tablename__: str = "llms"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Factory this model belongs to; the foreign key is declared with
    # ON DELETE RESTRICT.
    factory_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("llm_factory.id", ondelete="RESTRICT"),
        nullable=False,
        index=True,
    )
    # Unique model code.
    model_code: Mapped[str] = mapped_column(
        String(50), nullable=False, unique=True, index=True
    )
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
import uuid
from sqlalchemy import String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class LlmFactory(TimestampMixin, SoftDeleteMixin, Base):
    """ORM model for one LLM factory, stored in the ``llm_factory`` table."""

    __tablename__: str = "llm_factory"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Unique factory name.
    name: Mapped[str] = mapped_column(
        String(50), nullable=False, unique=True, index=True
    )
    # Request URL stored for the factory.
    request_url: Mapped[str] = mapped_column(String(255), nullable=False)
    # Optional avatar value, stored as unbounded text.
    avatar: Mapped[str | None] = mapped_column(Text, nullable=True)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+18
View File
@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from v1.agent_chat.service import AgentChatService
from v1.profile.dependencies import get_current_user
def get_agent_chat_service(
    session: Annotated[AsyncSession, Depends(get_db)],
    user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AgentChatService:
    """FastAPI dependency that builds an ``AgentChatService`` per request.

    Args:
        session: Database session resolved through ``get_db``.
        user: Caller identity resolved through ``get_current_user``.

    Returns:
        A service bound to the given session and user.
    """
    service = AgentChatService(session=session, current_user=user)
    return service
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends
from v1.agent_chat.dependencies import get_agent_chat_service
from v1.agent_chat.schemas import AgentChatRunRequest, AgentChatRunResponse
from v1.agent_chat.service import AgentChatService
# Router for the agent-chat endpoints; every route below lives under /agent-chat.
router = APIRouter(prefix="/agent-chat", tags=["agent-chat"])
@router.post("/run", response_model=AgentChatRunResponse)
async def run_agent_chat(
    payload: AgentChatRunRequest,
    service: Annotated[AgentChatService, Depends(get_agent_chat_service)],
) -> AgentChatRunResponse:
    """Handle ``POST /agent-chat/run`` by delegating the turn to the service.

    Args:
        payload: Validated request body.
        service: Request-scoped service from ``get_agent_chat_service``.

    Returns:
        Whatever ``AgentChatService.run`` produces for this payload.
    """
    result = await service.run(payload)
    return result
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, Field
class AgentChatRunRequest(BaseModel):
    """Request body for running one agent-chat turn."""

    # User message text; must contain between 1 and 8000 characters.
    message: str = Field(min_length=1, max_length=8000)
    # Existing session to continue; when omitted (None) the service creates a
    # new session for the turn.
    session_id: UUID | None = None
class AgentChatEvent(BaseModel):
    """One protocol event returned to the client.

    Only ``type`` is required; every other field is optional and defaults to
    ``None`` so a single model can carry all event kinds.
    """

    # Event name, e.g. "run.started", "message.delta" or "run.completed".
    type: str
    # Run identifier; the service fills it with the chat session id.
    run_id: str | None = None
    # Identifier of the message the event refers to.
    message_id: str | None = None
    # Text fragment carried by "message.delta" events.
    delta: str | None = None
    # Tool name and result carried by tool events.
    tool_name: str | None = None
    result: str | None = None
    # Final output carried by "run.completed" events.
    output: str | None = None
    # Error description carried by failure events.
    error: str | None = None
class AgentChatRunResponse(BaseModel):
    """Response body for one completed agent-chat turn."""

    # Session the turn was stored in (newly created or the one requested).
    session_id: UUID
    # Assistant output text for this turn.
    output: str
    # Protocol events describing the turn, in emission order.
    events: list[AgentChatEvent]
+286
View File
@@ -0,0 +1,286 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import func, select
from sqlalchemy.exc import SQLAlchemyError
from core.agent_chat.agui_adapter import AguiAdapter
from core.agent_chat.orchestrator import AgentChatOrchestrator
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent_chat.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
# Module-level logger used by the failure handlers in AgentChatService.run.
logger = get_logger("v1.agent_chat.service")
def build_session_title(
    first_message: str, *, now: datetime, max_length: int = 24
) -> str:
    """Derive a short, single-line session title from the first user message.

    Every run of whitespace (newlines, carriage returns, tabs, repeated
    spaces) is collapsed to one space so control characters never end up in
    the title, then the text is cut to ``max_length`` characters.

    Args:
        first_message: Raw text of the first message in the session.
        now: Timestamp used for the fallback title.
        max_length: Maximum number of characters kept from the message.

    Returns:
        The truncated title, or ``"新对话 YYYY-MM-DD HH:MM"`` when the message
        contains nothing but whitespace.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    # str.split() without arguments splits on any whitespace and drops empty
    # pieces, which also strips both ends.
    collapsed = " ".join(first_message.split())
    # The cut can land right after a word boundary; drop the dangling space.
    title = collapsed[:max_length].rstrip()
    if title:
        return title
    # Keep the non-ASCII prefix out of the strftime format string; only the
    # ASCII date pattern is handed to the platform formatter.
    return "新对话 " + now.strftime("%Y-%m-%d %H:%M")
def aggregate_session_cost(costs: list[Decimal]) -> Decimal:
    """Sum a list of costs, refusing any negative entry.

    Args:
        costs: Individual cost amounts.

    Returns:
        The total, ``Decimal("0")`` for an empty list.

    Raises:
        ValueError: If an entry is below zero.
    """

    def _checked():
        # Validate lazily so each amount is checked right before it is added.
        for amount in costs:
            if amount < 0:
                raise ValueError("cost must be non-negative")
            yield amount

    return sum(_checked(), Decimal("0"))
def select_recent_session(
    sessions: list[AgentChatSession],
) -> AgentChatSession | None:
    """Pick the session with the greatest ``last_activity_at``.

    Args:
        sessions: Candidate sessions.

    Returns:
        The most recently active session, or ``None`` when there are none.
    """

    def _activity(candidate: AgentChatSession) -> datetime:
        return candidate.last_activity_at

    if sessions:
        return max(sessions, key=_activity)
    return None
class AgentChatService(BaseService):
    """Run agent-chat turns for the calling user and persist their outcome.

    One call to :meth:`run` resolves (or creates) the chat session, executes
    the three-stage orchestrator, stores the user/assistant message pair,
    updates the session counters and returns the protocol events of the turn.
    """

    _session: AsyncSession

    def __init__(self, session: AsyncSession, current_user: CurrentUser | None) -> None:
        """Bind the service to a database session and the calling user.

        Args:
            session: Async SQLAlchemy session used for every read and write.
            current_user: Caller identity, forwarded to ``BaseService``.
        """
        super().__init__(current_user=current_user)
        self._session = session
        self._adapter = AguiAdapter()
        self._orchestrator = AgentChatOrchestrator(
            intent_stage=self._intent_stage,
            execution_stage=self._execution_stage,
            organization_stage=self._organization_stage,
        )

    async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
        """Execute one chat turn and persist it.

        Args:
            payload: Validated request carrying the message and, optionally,
                the session to continue.

        Returns:
            The session id, assistant output and mapped protocol events.

        Raises:
            HTTPException: 422 when the adapter rejects the payload, 404 when
                the requested session is not found for this user, 502 when
                orchestration fails or an unexpected error occurs, 503 when
                the database layer raises. ``enforce_rate_limit`` may raise
                as well (presumably a 429 -- defined outside this module).
        """
        try:
            command = self._adapter.to_command(payload.model_dump(mode="python"))
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

        user_id = self.require_user_id()
        await enforce_rate_limit(
            scope="agent_chat_run",
            identifier=str(user_id),
            limit=60,
            window_seconds=60,
        )

        now = datetime.now(timezone.utc)
        try:
            chat_session = await self._resolve_session(
                session_id=command["session_id"],
                user_id=user_id,
                first_message=command["message"],
                now=now,
            )
            base_seq = await self._next_seq_base(chat_session.id)
            user_message = AgentChatMessage(
                session_id=chat_session.id,
                seq=base_seq + 1,
                role=AgentChatMessageRole.USER,
                content=command["message"],
                cost=Decimal("0"),
            )
            orchestrator_result = await self._orchestrator.run(
                run_id=str(chat_session.id),
                user_message=command["message"],
            )
            usage = orchestrator_result.usage
            # Keep a plain local copy of the output: ORM attributes may be
            # expired by commit(), and reading an expired attribute on an
            # AsyncSession would trigger implicit IO.
            assistant_output = orchestrator_result.output
            turn_cost = Decimal(usage["cost"])
            assistant_message = AgentChatMessage(
                session_id=chat_session.id,
                seq=base_seq + 2,
                role=AgentChatMessageRole.ASSISTANT,
                content=assistant_output,
                input_tokens=int(usage["input_tokens"]),
                output_tokens=int(usage["output_tokens"]),
                cost=turn_cost,
            )
            self._session.add(user_message)
            self._session.add(assistant_message)

            chat_session.status = (
                AgentChatSessionStatus.FAILED
                if orchestrator_result.failed
                else AgentChatSessionStatus.COMPLETED
            )
            chat_session.last_activity_at = now
            chat_session.message_count = chat_session.message_count + 2
            chat_session.total_tokens = chat_session.total_tokens + int(
                usage["total_tokens"]
            )
            chat_session.total_cost = aggregate_session_cost(
                [Decimal(chat_session.total_cost), turn_cost]
            )
            # The turn is persisted even when orchestration failed, so the
            # FAILED status and the messages survive the error response.
            await self._session.commit()

            if orchestrator_result.failed:
                # Fail right after persisting; events would be discarded by
                # the error response anyway, so none are built.
                raise HTTPException(
                    status_code=502, detail="Agent orchestration failed"
                )

            await self._session.refresh(chat_session)
            await self._session.refresh(user_message)
            mapped_events = self._build_mapped_events(
                session_id=str(chat_session.id),
                message_id=str(user_message.id),
                user_message=command["message"],
                assistant_output=assistant_output,
                failed=False,
                error=None,
            )
            events = [AgentChatEvent.model_validate(item) for item in mapped_events]
            return AgentChatRunResponse(
                session_id=chat_session.id,
                output=assistant_output,
                events=events,
            )
        except HTTPException:
            await self._session.rollback()
            raise
        except SQLAlchemyError as exc:
            await self._session.rollback()
            logger.exception("Agent chat run failed")
            raise HTTPException(
                status_code=503, detail="Agent chat store unavailable"
            ) from exc
        except Exception as exc:  # noqa: BLE001
            await self._session.rollback()
            logger.exception(
                "Agent chat unexpected failure", error_type=type(exc).__name__
            )
            raise HTTPException(
                status_code=502, detail="Agent orchestration failed"
            ) from exc

    def _build_mapped_events(
        self,
        *,
        session_id: str,
        message_id: str,
        user_message: str,
        assistant_output: str,
        failed: bool,
        error: str | None,
    ) -> list[dict[str, object]]:
        """Build the protocol events that describe one turn.

        The list always starts with ``run_started`` and a ``message_delta``
        carrying the user message, and ends with ``run_failed`` (when
        ``failed`` is true) or ``run_completed``.

        Args:
            session_id: Session id, used as the run identifier.
            message_id: Id of the persisted user message.
            user_message: Text placed in the delta event.
            assistant_output: Output placed in the completion event.
            failed: Whether the run failed.
            error: Failure description; a generic text is used when empty.

        Returns:
            The events after mapping through the adapter, in order.
        """
        mapped_events = [
            self._adapter.to_protocol_event(
                {
                    "kind": "run_started",
                    "session_id": session_id,
                }
            ),
            self._adapter.to_protocol_event(
                {
                    "kind": "message_delta",
                    "message_id": message_id,
                    "delta": user_message,
                }
            ),
        ]

        if failed:
            mapped_events.append(
                self._adapter.to_protocol_event(
                    {
                        "kind": "run_failed",
                        "session_id": session_id,
                        "error": error or "orchestration failed",
                    }
                )
            )
            return mapped_events

        mapped_events.append(
            self._adapter.to_protocol_event(
                {
                    "kind": "run_completed",
                    "session_id": session_id,
                    "output": assistant_output,
                }
            )
        )
        return mapped_events

    async def _resolve_session(
        self,
        *,
        session_id: object | None,
        user_id: UUID,
        first_message: str,
        now: datetime,
    ) -> AgentChatSession:
        """Load the caller's session, or create one when no id is given.

        An existing session is selected ``FOR UPDATE`` and switched to
        ``RUNNING``. A new session is titled from the first message and
        flushed before it is returned.

        Args:
            session_id: Session to continue, or ``None`` for a new one.
            user_id: Owner the session must belong to.
            first_message: Message used to derive the title of a new session.
            now: Timestamp stored as the initial ``last_activity_at``.

        Returns:
            The existing or newly created session.

        Raises:
            HTTPException: 404 when no live session matches id and owner.
        """
        if session_id is not None:
            stmt = (
                select(AgentChatSession)
                .where(AgentChatSession.id == session_id)
                .where(AgentChatSession.user_id == user_id)
                .where(AgentChatSession.deleted_at.is_(None))
                .with_for_update()
                .limit(1)
            )
            result = await self._session.execute(stmt)
            existing = result.scalar_one_or_none()
            if existing is None:
                raise HTTPException(status_code=404, detail="Session not found")
            existing.status = AgentChatSessionStatus.RUNNING
            return existing

        title = build_session_title(first_message, now=now)
        created = AgentChatSession(
            user_id=user_id,
            title=title,
            status=AgentChatSessionStatus.RUNNING,
            last_activity_at=now,
            # run() reads these counters right after the flush; setting them
            # here avoids depending on server defaults being fetched back.
            message_count=0,
            total_tokens=0,
            total_cost=Decimal("0"),
        )
        self._session.add(created)
        await self._session.flush()
        return created

    async def _next_seq_base(self, session_id: object) -> int:
        """Return the highest message ``seq`` stored for the session.

        Args:
            session_id: Session whose messages are inspected.

        Returns:
            The current maximum ``seq``, or 0 when the session has no messages.
        """
        stmt = select(func.max(AgentChatMessage.seq)).where(
            AgentChatMessage.session_id == session_id
        )
        result = await self._session.scalar(stmt)
        return 0 if result is None else int(result)

    async def _intent_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Intent stage: tag the context with ``"default"`` and echo the message.

        Reports zero token usage and zero cost.
        """
        context["intent"] = "default"
        return {
            "content": message,
            "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
        }

    async def _execution_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Execution stage: echo the message with zero usage and cost."""
        return {
            "content": message,
            "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
        }

    async def _organization_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Organization stage: echo the message with zero usage and cost."""
        return {
            "content": message,
            "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
        }
+18
View File
@@ -74,6 +74,12 @@ async def login(
payload: LoginRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
await enforce_rate_limit(
scope="login",
identifier=payload.email,
limit=10,
window_seconds=60,
)
return await service.login(payload)
@@ -82,6 +88,12 @@ async def refresh(
payload: RefreshRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
await enforce_rate_limit(
scope="refresh",
identifier=payload.refresh_token,
limit=10,
window_seconds=60,
)
return await service.refresh(payload)
@@ -90,6 +102,12 @@ async def logout(
payload: LogoutRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="logout",
identifier=payload.refresh_token,
limit=10,
window_seconds=60,
)
await service.logout(payload.refresh_token)
return Response(status_code=204)
+2
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.agent_chat.router import router as agent_chat_router
from v1.auth.router import router as auth_router
from v1.infra.router import router as infra_router
from v1.profile.router import router as profile_router
@@ -12,6 +13,7 @@ router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(infra_router)
router.include_router(profile_router)
router.include_router(agent_chat_router)
@router.get("/health", response_model=HealthResponse)
+98
View File
@@ -0,0 +1,98 @@
from __future__ import annotations
import json
import socket
import threading
import time
from uuid import UUID
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
from v1.agent_chat.dependencies import get_agent_chat_service
from v1.agent_chat.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
from v1.agent_chat.service import AgentChatService
class FakeE2EAgentChatService(AgentChatService):
    """Stand-in service for the end-to-end test that echoes the message back."""

    def __init__(self) -> None:
        # Deliberately skip AgentChatService.__init__: this fake needs neither
        # a database session nor a user.
        pass

    async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
        """Return a canned started/delta/completed sequence for ``payload``."""
        fallback_id = UUID("00000000-0000-0000-0000-000000000001")
        session_id = payload.session_id or fallback_id
        run_id = str(session_id)
        started = AgentChatEvent(type="run.started", run_id=run_id)
        delta = AgentChatEvent(
            type="message.delta", message_id="m1", delta=payload.message
        )
        completed = AgentChatEvent(
            type="run.completed", run_id=run_id, output=payload.message
        )
        return AgentChatRunResponse(
            session_id=session_id,
            output=payload.message,
            events=[started, delta, completed],
        )
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(host: str, port: int) -> tuple[uvicorn.Server, threading.Thread]:
    """Run the application under uvicorn in a daemon thread.

    Args:
        host: Interface to bind.
        port: TCP port to bind.

    Returns:
        The server and the thread running it, once the port accepts
        connections.

    Raises:
        RuntimeError: Propagated from ``_wait_for_port`` when the server does
            not come up in time; the server is asked to exit first.
    """
    config = uvicorn.Config(app, host=host, port=port, log_level="info")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    try:
        _wait_for_port(host, port)
    except Exception:
        # Do not leave a half-started server running behind a failed test.
        server.should_exit = True
        thread.join(timeout=5)
        raise
    return server, thread
def test_agent_chat_flow_e2e() -> None:
    """Post a message to a live server and check the echoed run events.

    The real service is replaced by ``FakeE2EAgentChatService`` through a
    dependency override; the request goes over HTTP to a uvicorn instance.
    """
    app.dependency_overrides[get_agent_chat_service] = lambda: FakeE2EAgentChatService()
    try:
        host = "127.0.0.1"
        port = _find_free_port()
        # Started inside the outer try so the override is removed even when
        # the server fails to come up.
        server, thread = _start_server(host, port)
        try:
            with sync_playwright() as playwright:
                request_context = playwright.request.new_context(
                    base_url=f"http://{host}:{port}"
                )
                try:
                    response = request_context.post(
                        "/api/v1/agent-chat/run",
                        data=json.dumps({"message": "hello"}),
                        headers={"Content-Type": "application/json"},
                    )
                    assert response.status == 200
                    body = response.json()
                    assert body["output"] == "hello"
                    assert [event["type"] for event in body["events"]] == [
                        "run.started",
                        "message.delta",
                        "run.completed",
                    ]
                finally:
                    request_context.dispose()
        finally:
            server.should_exit = True
            thread.join(timeout=5)
    finally:
        app.dependency_overrides = {}
@@ -0,0 +1,38 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.service import select_recent_session
def test_recent_session_home_default_selection() -> None:
    """The session with the latest activity is selected as the default."""
    owner_id = UUID("00000000-0000-0000-0000-0000000000c1")
    newer_id = UUID("00000000-0000-0000-0000-0000000000a2")
    older = AgentChatSession(
        id=UUID("00000000-0000-0000-0000-0000000000a1"),
        user_id=owner_id,
        title="older",
        status=AgentChatSessionStatus.COMPLETED,
        last_activity_at=datetime(2026, 2, 25, 8, 0, tzinfo=timezone.utc),
        message_count=2,
        total_tokens=100,
        total_cost=Decimal("0.010000"),
    )
    newer = AgentChatSession(
        id=newer_id,
        user_id=owner_id,
        title="newer",
        status=AgentChatSessionStatus.RUNNING,
        last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
        message_count=3,
        total_tokens=120,
        total_cost=Decimal("0.020000"),
    )

    selected = select_recent_session([older, newer])

    assert selected is not None
    assert selected.id == newer_id
@@ -0,0 +1,97 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from types import MethodType
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.schemas import AgentChatRunRequest
from v1.agent_chat.service import AgentChatService
class _FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self.committed = False
self.rolled_back = False
def add(self, obj: object) -> None:
self.added.append(obj)
async def flush(self) -> None:
return None
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
async def refresh(self, obj: object) -> None:
if isinstance(obj, AgentChatSession) and obj.id is None:
obj.id = uuid4()
if isinstance(obj, AgentChatMessage) and obj.id is None:
obj.id = uuid4()
@pytest.mark.asyncio
async def test_run_persists_messages_and_emits_ordered_events() -> None:
    """``run`` stores a user/assistant pair after the base seq and commits."""
    expected_session_id = UUID("00000000-0000-0000-0000-000000000111")
    fake_db = _FakeAsyncSession()
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
    )

    async def _stub_resolve_session(
        self: AgentChatService,
        *,
        session_id: object | None,
        user_id: UUID,
        first_message: str,
        now: datetime,
    ) -> AgentChatSession:
        assert session_id is None
        assert first_message == "hello"
        return AgentChatSession(
            id=expected_session_id,
            user_id=user_id,
            title="hello",
            status=AgentChatSessionStatus.RUNNING,
            last_activity_at=now,
            message_count=0,
            total_tokens=0,
            total_cost=Decimal("0"),
            created_at=now,
            updated_at=now,
            deleted_at=None,
        )

    async def _stub_next_seq_base(self: AgentChatService, session_id: object) -> int:
        assert session_id == expected_session_id
        return 2

    # Replace the database-backed helpers so no real session is needed.
    service._resolve_session = MethodType(_stub_resolve_session, service)  # type: ignore[method-assign]
    service._next_seq_base = MethodType(_stub_next_seq_base, service)  # type: ignore[method-assign]

    response = await service.run(AgentChatRunRequest(message="hello"))

    assert fake_db.committed is True
    stored = [item for item in fake_db.added if isinstance(item, AgentChatMessage)]
    assert len(stored) == 2
    seqs = [message.seq for message in stored]
    roles = [message.role for message in stored]
    assert seqs == [3, 4]
    assert roles == [
        AgentChatMessageRole.USER,
        AgentChatMessageRole.ASSISTANT,
    ]
    event_types = [event.type for event in response.events]
    assert event_types == [
        "run.started",
        "message.delta",
        "run.completed",
    ]
@@ -0,0 +1,78 @@
from __future__ import annotations
from typing import Callable
from uuid import UUID
from fastapi.testclient import TestClient
from app import app
from v1.agent_chat.dependencies import get_agent_chat_service
from v1.agent_chat.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
from v1.agent_chat.service import AgentChatService
class FakeAgentChatService:
    """Route-test double that echoes the message with a fixed session id."""

    async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
        """Return a canned started/delta/completed response for ``payload``."""
        run_id = "00000000-0000-0000-0000-000000000001"
        started = AgentChatEvent(type="run.started", run_id=run_id)
        delta = AgentChatEvent(
            type="message.delta", message_id="m1", delta=payload.message
        )
        completed = AgentChatEvent(
            type="run.completed",
            run_id=run_id,
            output=payload.message,
        )
        return AgentChatRunResponse(
            session_id=UUID("00000000-0000-0000-0000-000000000001"),
            output=payload.message,
            events=[started, delta, completed],
        )
def _override_agent_chat_service(
service: FakeAgentChatService,
) -> Callable[[], AgentChatService]:
def _get_service() -> AgentChatService:
return service # type: ignore[return-value]
return _get_service
def test_run_route_returns_response() -> None:
    """A valid body yields 200 with the echoed output and ordered events."""
    app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
        FakeAgentChatService()
    )
    client = TestClient(app)
    try:
        response = client.post("/api/v1/agent-chat/run", json={"message": "hello"})
        assert response.status_code == 200
        body = response.json()
        assert body["output"] == "hello"
        event_types = [event["type"] for event in body["events"]]
        assert event_types == [
            "run.started",
            "message.delta",
            "run.completed",
        ]
    finally:
        # Always drop the override so other tests see the real dependency.
        app.dependency_overrides = {}
def test_run_route_validates_payload() -> None:
    """An empty message is rejected with 422 by request validation."""
    app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
        FakeAgentChatService()
    )
    client = TestClient(app)
    empty_body = {"message": ""}
    try:
        response = client.post("/api/v1/agent-chat/run", json=empty_body)
        assert response.status_code == 422
    finally:
        app.dependency_overrides = {}
@@ -0,0 +1,20 @@
from __future__ import annotations
from decimal import Decimal
from v1.agent_chat.service import aggregate_session_cost
def test_aggregate_session_cost_sums_non_negative_values() -> None:
    """Non-negative costs are added up exactly."""
    costs = [Decimal("0.010000"), Decimal("0.002500")]
    assert aggregate_session_cost(costs) == Decimal("0.012500")
def test_aggregate_session_cost_rejects_negative_value() -> None:
    """A negative cost makes the aggregation raise ``ValueError``."""
    raised = True
    try:
        aggregate_session_cost([Decimal("-0.010000")])
    except ValueError:
        pass
    else:
        raised = False
    assert raised is True
@@ -0,0 +1,42 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.service import select_recent_session
def test_select_recent_session_uses_last_activity_desc() -> None:
    """Of two sessions, the one with the later activity time is chosen."""
    owner_id = UUID("00000000-0000-0000-0000-0000000000a1")
    newer_id = UUID("00000000-0000-0000-0000-000000000002")
    older = AgentChatSession(
        id=UUID("00000000-0000-0000-0000-000000000001"),
        user_id=owner_id,
        title="older",
        status=AgentChatSessionStatus.COMPLETED,
        last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
        message_count=1,
        total_tokens=1,
        total_cost=Decimal("0"),
    )
    newer = AgentChatSession(
        id=newer_id,
        user_id=owner_id,
        title="newer",
        status=AgentChatSessionStatus.RUNNING,
        last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
        message_count=2,
        total_tokens=2,
        total_cost=Decimal("0"),
    )

    selected = select_recent_session([older, newer])

    assert selected is not None
    assert selected.id == newer_id
def test_select_recent_session_returns_none_for_empty_collection() -> None:
    """With no sessions there is nothing to select."""
    selected = select_recent_session([])
    assert selected is None
@@ -416,6 +416,108 @@ def test_logout_returns_no_content() -> None:
app.dependency_overrides = {}
def test_login_rate_limited_after_too_many_attempts() -> None:
    """After ten login attempts for one email, the next one gets a 429."""
    token_response = AuthTokenResponse(
        access_token="access",
        refresh_token="refresh",
        expires_in=3600,
        token_type="bearer",
        user=AuthUser(id="user-1", email="user@example.com"),
    )
    app.dependency_overrides[get_auth_service] = _override_auth_service(
        FakeAuthService(token_response)
    )
    client = TestClient(app)
    credentials = {"email": "user@example.com", "password": "wrongpw"}
    try:
        # The first ten attempts reach the service and are rejected normally.
        for _ in range(10):
            attempt = client.post("/api/v1/auth/login", json=credentials)
            assert attempt.status_code == 401

        blocked = client.post("/api/v1/auth/login", json=credentials)
        assert blocked.status_code == 429
        assert blocked.headers["content-type"].startswith("application/problem+json")
        assert blocked.json()["detail"] == "Too many requests"
    finally:
        app.dependency_overrides = {}
def test_refresh_rate_limited_after_too_many_attempts() -> None:
    """After ten refresh attempts with one token, the next one gets a 429."""
    token_response = AuthTokenResponse(
        access_token="access",
        refresh_token="refresh",
        expires_in=3600,
        token_type="bearer",
        user=AuthUser(id="user-1", email="user@example.com"),
    )
    app.dependency_overrides[get_auth_service] = _override_auth_service(
        FakeAuthService(token_response)
    )
    client = TestClient(app)
    body_json = {"refresh_token": "invalid"}
    try:
        # The first ten attempts reach the service and are rejected normally.
        for _ in range(10):
            attempt = client.post("/api/v1/auth/refresh", json=body_json)
            assert attempt.status_code == 401

        blocked = client.post("/api/v1/auth/refresh", json=body_json)
        assert blocked.status_code == 429
        assert blocked.headers["content-type"].startswith("application/problem+json")
        assert blocked.json()["detail"] == "Too many requests"
    finally:
        app.dependency_overrides = {}
def test_logout_rate_limited_after_too_many_attempts() -> None:
    """After ten logouts with one token, the next one gets a 429."""
    token_response = AuthTokenResponse(
        access_token="access",
        refresh_token="refresh",
        expires_in=3600,
        token_type="bearer",
        user=AuthUser(id="user-1", email="user@example.com"),
    )
    app.dependency_overrides[get_auth_service] = _override_auth_service(
        FakeAuthService(token_response)
    )
    client = TestClient(app)
    body_json = {"refresh_token": "refresh"}
    try:
        # The first ten logouts succeed with an empty response.
        for _ in range(10):
            accepted = client.post("/api/v1/auth/logout", json=body_json)
            assert accepted.status_code == 204

        blocked = client.post("/api/v1/auth/logout", json=body_json)
        assert blocked.status_code == 429
        assert blocked.headers["content-type"].startswith("application/problem+json")
        assert blocked.json()["detail"] == "Too many requests"
    finally:
        app.dependency_overrides = {}
def test_signup_start_validation_error_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
@@ -0,0 +1,40 @@
from __future__ import annotations
import pytest
from core.agent_chat.agui_adapter import AguiAdapter
def test_to_command_maps_payload_fields() -> None:
    """``to_command`` carries message and session id over unchanged."""
    payload = {
        "message": "hello",
        "session_id": "00000000-0000-0000-0000-000000000001",
    }
    adapter = AguiAdapter()

    command = adapter.to_command(payload)

    assert command["message"] == "hello"
    assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
def test_to_protocol_event_maps_internal_event() -> None:
    """An internal ``run_completed`` event maps to the protocol shape."""
    internal_event = {
        "kind": "run_completed",
        "session_id": "run-1",
        "output": "done",
    }
    adapter = AguiAdapter()

    mapped = adapter.to_protocol_event(internal_event)

    expected = {"type": "run.completed", "run_id": "run-1", "output": "done"}
    assert mapped == expected
def test_to_protocol_event_raises_for_invalid_event() -> None:
    """An unknown event kind is rejected with ``ValueError``."""
    adapter = AguiAdapter()
    unknown_event = {"kind": "unknown"}
    with pytest.raises(ValueError):
        adapter.to_protocol_event(unknown_event)
@@ -0,0 +1,30 @@
from __future__ import annotations
import pytest
from core.agent_chat.tools.asr_fun_asr import FunASRTool
def test_transcribe_uses_injected_dashscope_callable() -> None:
    """The tool forwards to the injected callable and adds the model name."""

    def _stub_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        assert filename == "voice.wav"
        assert audio_bytes == b"audio"
        return {"text": "你好", "request_id": "req-1"}

    tool = FunASRTool(transcribe_callable=_stub_transcribe)

    result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")

    assert result["text"] == "你好"
    assert result["request_id"] == "req-1"
    assert result["model"] == "fun-asr-realtime-2025-11-07"
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
    """A provider failure surfaces to the caller as ``RuntimeError``."""

    def _failing_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        raise RuntimeError("upstream timeout")

    tool = FunASRTool(transcribe_callable=_failing_transcribe)
    with pytest.raises(RuntimeError):
        tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
@@ -0,0 +1,82 @@
from __future__ import annotations
from decimal import Decimal
import pytest
from core.agent_chat.cost_tracker import CostTracker
def test_normalize_usage_and_cost_aggregation() -> None:
    """Both token key spellings are accepted and totals accumulate."""
    prompt_style_usage = {
        "prompt_tokens": 7,
        "completion_tokens": 5,
        "cost": "0.002500",
    }
    input_style_usage = {
        "input_tokens": 5,
        "output_tokens": 3,
        "cost": "0.003000",
        "currency": "USD",
    }
    tracker = CostTracker()
    tracker.add_usage(prompt_style_usage)
    tracker.add_usage(input_style_usage)

    snapshot = tracker.snapshot()

    assert snapshot["input_tokens"] == 12
    assert snapshot["output_tokens"] == 8
    assert snapshot["total_tokens"] == 20
    assert snapshot["cost"] == Decimal("0.005500")
    assert snapshot["currency"] == "USD"
def test_add_usage_rejects_negative_values() -> None:
    """Negative token counts and negative costs are both rejected."""
    tracker = CostTracker()
    for bad_usage in ({"input_tokens": -1}, {"cost": "-0.010000"}):
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
def test_snapshot_is_zero_before_any_usage() -> None:
    """A fresh tracker reports zero usage, zero cost and USD."""
    snapshot = CostTracker().snapshot()

    for counter in ("input_tokens", "output_tokens", "total_tokens"):
        assert snapshot[counter] == 0
    assert snapshot["cost"] == Decimal("0")
    assert snapshot["currency"] == "USD"
def test_add_usage_rejects_currency_mismatch() -> None:
    """Usage reported in another currency than the tracker's is rejected."""
    tracker = CostTracker(currency="USD")
    tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})

    cny_usage = {
        "input_tokens": 1,
        "output_tokens": 1,
        "cost": "0.001000",
        "currency": "CNY",
    }
    with pytest.raises(ValueError):
        tracker.add_usage(cny_usage)
def test_add_usage_rejects_non_integral_token_values() -> None:
    """Fractional and boolean token counts are rejected."""
    tracker = CostTracker()
    for bad_usage in ({"input_tokens": 1.5}, {"output_tokens": True}):
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
@@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from core.agent_chat.event_bridge import map_internal_event
def test_map_run_started_event() -> None:
    """``run_started`` maps to ``run.started`` with the session id as run id."""
    mapped = map_internal_event({"kind": "run_started", "session_id": "s1"})
    assert mapped == {"type": "run.started", "run_id": "s1"}
def test_map_message_delta_event() -> None:
    """``message_delta`` keeps its message id and delta text."""
    mapped = map_internal_event(
        {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
    )
    expected = {"type": "message.delta", "message_id": "m1", "delta": "hello"}
    assert mapped == expected
def test_map_tool_events() -> None:
    """Tool start/completion events map to their dotted protocol types."""
    mapped_started = map_internal_event(
        {
            "kind": "tool_started",
            "message_id": "m2",
            "tool_name": "asr_fun_asr",
        }
    )
    mapped_completed = map_internal_event(
        {
            "kind": "tool_completed",
            "message_id": "m2",
            "tool_name": "asr_fun_asr",
            "result": "ok",
        }
    )

    assert mapped_started["type"] == "tool.started"
    assert mapped_started["tool_name"] == "asr_fun_asr"
    assert mapped_completed["type"] == "tool.completed"
    assert mapped_completed["result"] == "ok"
def test_map_run_completed_event() -> None:
    """``run_completed`` maps to ``run.completed`` and keeps the output."""
    mapped = map_internal_event(
        {"kind": "run_completed", "session_id": "s1", "output": "done"}
    )
    expected = {"type": "run.completed", "run_id": "s1", "output": "done"}
    assert mapped == expected
def test_map_unknown_event_raises() -> None:
    """An unknown event kind is rejected with ``ValueError``."""
    unknown_event = {"kind": "unknown"}
    with pytest.raises(ValueError):
        map_internal_event(unknown_event)
def test_map_event_missing_required_field_raises_value_error() -> None:
    """A ``message_delta`` without its ``delta`` field is rejected."""
    incomplete_event = {"kind": "message_delta", "message_id": "m1"}
    with pytest.raises(ValueError):
        map_internal_event(incomplete_event)
@@ -0,0 +1,89 @@
from __future__ import annotations
import pytest
from core.agent_chat.multimodal import AttachmentInput, MultimodalProcessor
from core.agent_chat.storage_adapter import StorageAdapter
from core.agent_chat.tools.asr_fun_asr import FunASRTool
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
    """An audio attachment is stored under its checksum path and transcribed."""

    def _stub_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        assert audio_bytes == b"audio"
        assert filename == "voice.wav"
        return {"text": "hello world", "request_id": "req-1"}

    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(transcribe_callable=_stub_transcribe),
        max_file_size_mb=1,
    )
    voice_note = AttachmentInput(
        filename="voice.wav",
        mime_type="audio/wav",
        content=b"audio",
    )

    result = processor.process(
        user_id="u1",
        session_id="s1",
        message_seq=4,
        attachments=[voice_note],
    )

    assert len(result.attachments) == 1
    metadata = result.attachments[0]
    expected_path = (
        "agent-chat/u1/s1/4/"
        "6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
    )
    assert metadata["object_path"] == expected_path
    assert metadata["mime_type"] == "audio/wav"
    assert result.preview_texts == ["hello world"]
def test_multimodal_rejects_unsupported_mime_type() -> None:
    """An attachment with an unsupported MIME type raises ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
    )
    executable = AttachmentInput(
        filename="malware.exe",
        mime_type="application/octet-stream",
        content=b"bad",
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[executable],
        )
def test_multimodal_rejects_attachment_over_max_size() -> None:
    """An attachment one byte over the 1 MB limit raises ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
        max_file_size_mb=1,
    )
    too_big = AttachmentInput(
        filename="big.wav",
        mime_type="audio/wav",
        content=b"x" * (1024 * 1024 + 1),
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[too_big],
        )
@@ -0,0 +1,104 @@
from __future__ import annotations
from core.agent_chat.orchestrator import AgentChatOrchestrator
async def _intent_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("intent")
return {
"content": f"intent:{message}",
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
}
async def _execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
return {
"content": f"execution:{message}",
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
}
async def _organization_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("organization")
return {
"content": "final answer",
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
}
def test_orchestrator_runs_three_stages_in_order() -> None:
    """Stages run intent, execution, organization and usage is summed."""
    orchestrator = AgentChatOrchestrator(
        intent_stage=_intent_stage,
        execution_stage=_execution_stage,
        organization_stage=_organization_stage,
    )

    result = orchestrator.run_sync(run_id="run-1", user_message="hello")

    first_event, last_event = result.events[0], result.events[-1]
    assert result.context["sequence"] == ["intent", "execution", "organization"]
    assert result.output == "final answer"
    assert result.usage["total_tokens"] == 13
    assert first_event["type"] == "run.started"
    assert last_event["type"] == "run.completed"
async def _failing_execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
raise RuntimeError("boom")
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
    """A raising execution stage halts the pipeline and marks the run failed.

    The organization stage must not run, the last event must be ``run.failed``
    for this run id, and the error text must surface on both the event and the
    result object.
    """
    pipeline = AgentChatOrchestrator(
        intent_stage=_intent_stage,
        execution_stage=_failing_execution_stage,
        organization_stage=_organization_stage,
    )

    outcome = pipeline.run_sync(run_id="run-2", user_message="hello")

    # "organization" is absent: the pipeline stopped at the failing stage.
    assert outcome.context["sequence"] == ["intent", "execution"]

    failure_event = outcome.events[-1]
    assert failure_event["type"] == "run.failed"
    assert failure_event["run_id"] == "run-2"
    assert "boom" in (failure_event.get("error") or "")

    assert outcome.failed is True
    assert "boom" in (outcome.error or "")
def test_orchestrator_emits_stage_event_payload_shape() -> None:
    """Every event carries ``type`` and the run id; stage events are ordered."""
    pipeline = AgentChatOrchestrator(
        intent_stage=_intent_stage,
        execution_stage=_execution_stage,
        organization_stage=_organization_stage,
    )

    outcome = pipeline.run_sync(run_id="run-3", user_message="hello")

    for emitted in outcome.events:
        assert "type" in emitted
        assert emitted.get("run_id") == "run-3"

    # One "stage.completed" event per stage, in pipeline order.
    completed_stages = [
        emitted["stage"]
        for emitted in outcome.events
        if emitted["type"] == "stage.completed"
    ]
    assert completed_stages == ["intent", "execution", "organization"]
@@ -0,0 +1,23 @@
from __future__ import annotations
from datetime import datetime
from v1.agent_chat.service import build_session_title
def test_build_session_title_truncates_first_message() -> None:
    """A first message longer than 24 characters yields a 24-character title."""
    first_message = "这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要"
    fixed_now = datetime(2026, 2, 25, 10, 30)

    title = build_session_title(first_message, now=fixed_now)

    assert len(title) == 24
def test_build_session_title_falls_back_when_message_empty() -> None:
    """A whitespace-only message falls back to a timestamped default title.

    The expected title is the "新对话" prefix followed by the supplied ``now``
    value rendered at minute precision.
    """
    # Fixed timestamp so the expected title is deterministic.
    now = datetime(2026, 2, 25, 10, 30)
    title = build_session_title("\n ", now=now)
    assert title == "新对话 2026-02-25 10:30"
@@ -0,0 +1,37 @@
from __future__ import annotations
from core.agent_chat.storage_adapter import StorageAdapter
def test_build_object_path_uses_expected_pattern() -> None:
    """The object path is ``agent-chat/<user>/<session>/<seq>/<checksum>.<ext>``."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    path_parts = {
        "user_id": "u1",
        "session_id": "s1",
        "message_seq": 3,
        "checksum_sha256": "abc123",
        "extension": "wav",
    }

    object_path = adapter.build_object_path(**path_parts)

    assert object_path == "agent-chat/u1/s1/3/abc123.wav"
def test_build_attachment_metadata_contains_required_fields() -> None:
    """Each argument passed to the builder is returned under the same key."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    expected_fields = {
        "object_path": "agent-chat/u1/s1/3/abc123.wav",
        "mime_type": "audio/wav",
        "size": 1024,
        "checksum_sha256": "abc123",
        "origin": "user_upload",
        "preview_text": "hello",
    }

    metadata = adapter.build_attachment_metadata(**expected_fields)

    for field_name, expected_value in expected_fields.items():
        assert metadata[field_name] == expected_value
@@ -0,0 +1,138 @@
from __future__ import annotations
from pathlib import Path
import pytest
from core.agent_chat.crewai.template_loader import (
load_crewai_template,
load_tools_whitelist,
validate_workflow_stages,
)
def _write(path: Path, content: str) -> None:
    """Write ``content`` to ``path`` as UTF-8, creating parent directories first."""
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    path.write_text(content, encoding="utf-8")
def _prepare_static_root(root: Path) -> Path:
    """Populate ``root`` with a complete set of template fixture files.

    Writes, relative to ``root``:

    * ``crewai/agents.yaml``, ``crewai/tasks.yaml`` and ``crewai/workflow.yaml``
    * ``crewai/prompts/{intent,execution,organization}.md``
    * ``tools.yaml``

    Returns:
        The same ``root`` that was passed in, so the call can be used inline.
    """
    # NOTE(review): the YAML literals below rely on their internal indentation
    # for nesting; confirm it is intact in the checked-in file.
    _write(
        root / "crewai" / "agents.yaml",
        """
intent:
role: Intent Agent
execution:
role: Execution Agent
organization:
role: Organization Agent
""".strip(),
    )
    _write(
        root / "crewai" / "tasks.yaml",
        """
intent:
description: classify
execution:
description: run task
organization:
description: summarize
""".strip(),
    )
    # Workflow lists the three stages in intent -> execution -> organization order.
    _write(
        root / "crewai" / "workflow.yaml",
        """
stages:
- intent
- execution
- organization
""".strip(),
    )
    # One prompt file per stage; the content is "<stage> prompt".
    _write(root / "crewai" / "prompts" / "intent.md", "intent prompt")
    _write(root / "crewai" / "prompts" / "execution.md", "execution prompt")
    _write(root / "crewai" / "prompts" / "organization.md", "organization prompt")
    # tools.yaml sits directly under root, not under crewai/.
    _write(
        root / "tools.yaml",
        """
tools:
- asr_fun_asr
- doc_extract
""".strip(),
    )
    return root
def test_load_crewai_template_success_when_all_files_valid(tmp_path: Path) -> None:
    """A fully populated static root loads into a template with all sections."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")

    loaded = load_crewai_template(static_root)

    stage_names = {"intent", "execution", "organization"}
    assert set(loaded.agents.keys()) == stage_names
    assert set(loaded.tasks.keys()) == stage_names
    assert loaded.workflow["stages"] == ["intent", "execution", "organization"]
    # Each prompt file was written as "<stage> prompt" by the fixture helper.
    for stage in ("intent", "execution", "organization"):
        assert loaded.prompts[stage] == f"{stage} prompt"
    assert loaded.tools_whitelist == {"asr_fun_asr", "doc_extract"}
def test_load_crewai_template_raises_file_not_found_when_required_file_missing(
    tmp_path: Path,
) -> None:
    """Deleting ``crewai/tasks.yaml`` makes the loader raise ``FileNotFoundError``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")
    removed_file = static_root / "crewai" / "tasks.yaml"
    removed_file.unlink()

    with pytest.raises(FileNotFoundError):
        load_crewai_template(static_root)
def test_load_crewai_template_raises_value_error_when_workflow_stages_invalid(
    tmp_path: Path,
) -> None:
    """A workflow whose stages are out of order is rejected with ``ValueError``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")
    # Overwrite the valid workflow with "execution" and "intent" swapped.
    _write(
        static_root / "crewai" / "workflow.yaml",
        """
stages:
- execution
- intent
- organization
""".strip(),
    )
    with pytest.raises(ValueError):
        load_crewai_template(static_root)
def test_load_tools_whitelist_from_tools_yaml(tmp_path: Path) -> None:
    """The whitelist is the set of tool names written to ``tools.yaml``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")

    allowed_tools = load_tools_whitelist(static_root)

    assert allowed_tools == {"asr_fun_asr", "doc_extract"}
def test_validate_workflow_stages_accepts_exact_intent_execution_organization() -> None:
    """The canonical three-stage order passes validation without raising."""
    canonical_order = ["intent", "execution", "organization"]
    validate_workflow_stages(canonical_order)
def test_validate_workflow_stages_rejects_extra_or_missing_stage() -> None:
    """A stage list that is too short or too long raises ``ValueError``."""
    missing_last_stage = ["intent", "execution"]
    with_extra_stage = ["intent", "execution", "organization", "extra"]

    for invalid_stages in (missing_last_stage, with_extra_stage):
        with pytest.raises(ValueError):
            validate_workflow_stages(invalid_stages)
def test_load_tools_whitelist_rejects_non_string_item(tmp_path: Path) -> None:
    """A ``tools.yaml`` entry that is not a string is rejected with ``ValueError``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")
    # Overwrite the valid tools file; the bare 123 entry is the non-string item.
    _write(
        static_root / "tools.yaml",
        """
tools:
- asr_fun_asr
- 123
""".strip(),
    )
    with pytest.raises(ValueError):
        load_tools_whitelist(static_root)
@@ -0,0 +1,143 @@
from __future__ import annotations
from pathlib import Path
import pytest
from sqlalchemy import Column, String, Table, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from core.initialization import init_data
from models.llm import Llm
from models.llm_factory import LlmFactory
def test_llm_catalog_file_exists_and_has_required_fields() -> None:
    """The shipped ``llm_catalog.yaml`` loads with the expected sizes and keys."""
    # parents[3] is the ancestor directory that contains "src".
    base_dir = Path(__file__).resolve().parents[3]
    catalog_path = base_dir.joinpath(
        "src", "core", "config", "static", "agent_chat", "llm_catalog.yaml"
    )

    catalog = init_data.load_llm_catalog(catalog_path)

    assert len(catalog["factories"]) == 6
    assert len(catalog["llms"]) == 2
    first_factory = catalog["factories"][0]
    assert set(first_factory.keys()) == {"name", "request_url", "avatar"}
    first_llm = catalog["llms"][0]
    assert set(first_llm.keys()) == {"model_code", "factory_id"}
def test_load_llm_catalog_raises_on_invalid_structure(tmp_path: Path) -> None:
    """A catalog with incomplete entries makes the loader raise ``ValueError``.

    The factory entry written here carries only ``name`` and the llm entry only
    ``model_code``.
    """
    catalog_path = tmp_path / "llm_catalog.yaml"
    catalog_path.write_text(
        """
factories:
- name: qwen
llms:
- model_code: qwen3.5-flash
""".strip(),
        encoding="utf-8",
    )
    with pytest.raises(ValueError):
        init_data.load_llm_catalog(catalog_path)
@pytest.mark.asyncio
async def test_initialize_data_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
    """Calling ``initialize_data`` twice succeeds and does not duplicate rows.

    Both calls must return ``True`` and the database must end up with exactly
    6 ``LlmFactory`` rows and 2 ``Llm`` rows.

    The temporary ``auth.users`` table and the engine are released in a
    ``finally`` block, so a failing assertion cannot leave the extra table
    registered on the shared ``Base.metadata`` or the engine undisposed.
    """
    # Minimal stand-in for auth.users, registered on the shared metadata.
    # NOTE(review): presumably needed because a model references auth.users;
    # confirm against the model definitions.
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        session_maker = async_sessionmaker(
            bind=engine, class_=AsyncSession, expire_on_commit=False
        )
        async with engine.begin() as conn:
            # Attach a second in-memory database under the name "auth" so the
            # schema-qualified table can be created on SQLite.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)

        # Point the initializer at the in-memory database.
        monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)

        first = await init_data.initialize_data()
        second = await init_data.initialize_data()
        assert first is True
        assert second is True

        async with session_maker() as session:
            factory_count = await session.scalar(
                select(func.count()).select_from(LlmFactory)
            )
            llm_count = await session.scalar(select(func.count()).select_from(Llm))
        assert factory_count == 6
        assert llm_count == 2
    finally:
        # Always undo the global metadata change and release the engine.
        Base.metadata.remove(users_table)
        await engine.dispose()
@pytest.mark.asyncio
async def test_initialize_data_rolls_back_on_invalid_factory_mapping(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An llm that names an unknown factory aborts seeding and leaves no rows.

    ``load_llm_catalog`` is replaced with a stub whose single llm points at
    ``"missing_factory"``. ``initialize_data`` must raise ``RuntimeError`` and
    both the ``LlmFactory`` and ``Llm`` tables must stay empty.

    The temporary ``auth.users`` table and the engine are released in a
    ``finally`` block, so a failing assertion cannot leave the extra table
    registered on the shared ``Base.metadata`` or the engine undisposed.
    """
    # Minimal stand-in for auth.users, registered on the shared metadata.
    # NOTE(review): presumably needed because a model references auth.users;
    # confirm against the model definitions.
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        session_maker = async_sessionmaker(
            bind=engine, class_=AsyncSession, expire_on_commit=False
        )
        async with engine.begin() as conn:
            # Attach a second in-memory database under the name "auth" so the
            # schema-qualified table can be created on SQLite.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)

        monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
        # Catalog stub: one valid factory, one llm whose factory_id matches nothing.
        monkeypatch.setattr(
            init_data,
            "load_llm_catalog",
            lambda *_: {
                "factories": [
                    {
                        "name": "qwen",
                        "request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
                        "avatar": "https://cdn.example.com/qwen.png",
                    }
                ],
                "llms": [
                    {
                        "model_code": "qwen3.5-flash",
                        "factory_id": "missing_factory",
                    }
                ],
            },
        )

        with pytest.raises(RuntimeError):
            await init_data.initialize_data()

        # Nothing may have been persisted, not even the valid factory.
        async with session_maker() as session:
            factory_count = await session.scalar(
                select(func.count()).select_from(LlmFactory)
            )
            llm_count = await session.scalar(select(func.count()).select_from(Llm))
        assert factory_count == 0
        assert llm_count == 0
    finally:
        # Always undo the global metadata change and release the engine.
        Base.metadata.remove(users_table)
        await engine.dispose()
@@ -0,0 +1,17 @@
from __future__ import annotations
from pathlib import Path
def test_agent_chat_migration_exists_and_creates_expected_tables() -> None:
    """The agent-chat migration file exists and creates the four core tables.

    This is a text-level check: the migration source is read as a string and
    searched for literal ``create_table(`` fragments, so it depends on the
    exact line break and whitespace between ``create_table(`` and the table
    name in the migration file.
    """
    versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
    migration = versions_dir / "20260226_create_agent_chat_core_tables.py"
    assert migration.exists()
    content = migration.read_text(encoding="utf-8")
    assert 'create_table(\n "llm_factory"' in content
    assert 'create_table(\n "llms"' in content
    assert 'create_table(\n "sessions"' in content
    assert 'create_table(\n "messages"' in content
    # The string "tool_calls" must not appear anywhere in the migration.
    assert "tool_calls" not in content
@@ -0,0 +1,119 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
@pytest.fixture
async def db_engine():
    """Yield an in-memory SQLite async engine with all model tables created.

    A minimal ``auth.users`` table is registered on the shared
    ``Base.metadata`` for the lifetime of the fixture and an in-memory
    database is attached as ``auth`` so the schema-qualified table can be
    created.

    Setup and teardown share one ``try``/``finally`` so that a failure while
    creating the schema still removes the temporary table from the shared
    metadata and disposes the engine.
    """
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        async with engine.begin() as conn:
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Undo the global metadata change before releasing the engine.
        Base.metadata.remove(users_table)
        await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine``; roll back afterwards."""
    session_factory = async_sessionmaker(
        bind=db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as active_session:
        yield active_session
        # Discard whatever the test left uncommitted.
        await active_session.rollback()
@pytest.mark.asyncio
async def test_llm_factory_and_llm_relationship(db_session: AsyncSession) -> None:
    """An ``Llm`` row keeps the id of the ``LlmFactory`` it was created for."""
    provider = LlmFactory(
        name="qwen",
        request_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        avatar="https://cdn.example.com/qwen.png",
    )
    db_session.add(provider)
    # Flush before reading provider.id for the foreign key below.
    await db_session.flush()

    model = Llm(factory_id=provider.id, model_code="qwen3.5-flash")
    db_session.add(model)
    await db_session.commit()

    reloaded = await db_session.get(Llm, model.id)
    assert reloaded is not None
    assert reloaded.factory_id == provider.id
@pytest.mark.asyncio
async def test_session_status_supports_required_values(
    db_session: AsyncSession,
) -> None:
    """Each required session status survives a commit and refresh unchanged."""
    chat_session = AgentChatSession(
        user_id=uuid4(),
        title="test",
        status="pending",
    )
    db_session.add(chat_session)
    await db_session.commit()

    required_statuses = (
        AgentChatSessionStatus.PENDING,
        AgentChatSessionStatus.RUNNING,
        AgentChatSessionStatus.COMPLETED,
        AgentChatSessionStatus.FAILED,
    )
    for expected_status in required_statuses:
        chat_session.status = expected_status
        await db_session.commit()
        # Reload from the database so the comparison uses the stored value.
        await db_session.refresh(chat_session)
        assert chat_session.status == expected_status
@pytest.mark.asyncio
async def test_messages_role_supports_tool(db_session: AsyncSession) -> None:
    """A message with role ``"tool"`` can be stored and read back."""
    chat_session = AgentChatSession(
        user_id=uuid4(),
        title="tool test",
        status="pending",
    )
    db_session.add(chat_session)
    # Flush before reading chat_session.id for the message below.
    await db_session.flush()

    tool_message = AgentChatMessage(
        session_id=chat_session.id,
        seq=1,
        role="tool",
        content="tool output",
        cost=0,
    )
    db_session.add(tool_message)
    await db_session.commit()

    query = select(AgentChatMessage).where(
        AgentChatMessage.session_id == chat_session.id
    )
    stored = (await db_session.execute(query)).scalar_one()
    assert stored.role == "tool"
@@ -0,0 +1,34 @@
from __future__ import annotations
from pydantic import ValidationError
import pytest
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_social_prefixed_storage_env_populates_settings(
    monkeypatch: MonkeyPatch,
) -> None:
    """``SOCIAL_STORAGE__*`` environment variables populate ``settings.storage``."""
    storage_env = {
        "SOCIAL_STORAGE__PROVIDER": "supabase",
        "SOCIAL_STORAGE__BUCKET": "agent-chat-attachments",
        "SOCIAL_STORAGE__SIGNED_URL_TTL_SECONDS": "900",
        "SOCIAL_STORAGE__MAX_FILE_SIZE_MB": "25",
        "SOCIAL_STORAGE__RETENTION_DAYS": "45",
    }
    for variable, raw_value in storage_env.items():
        monkeypatch.setenv(variable, raw_value)

    storage = Settings().storage

    assert storage.provider == "supabase"
    assert storage.bucket == "agent-chat-attachments"
    # Numeric settings arrive as strings and are expected back as ints.
    assert storage.signed_url_ttl_seconds == 900
    assert storage.max_file_size_mb == 25
    assert storage.retention_days == 45
def test_storage_settings_validation_rejects_invalid_provider(
    monkeypatch: MonkeyPatch,
) -> None:
    """Building ``Settings`` with provider ``"s3"`` raises ``ValidationError``."""
    rejected_provider = "s3"
    monkeypatch.setenv("SOCIAL_STORAGE__PROVIDER", rejected_provider)

    with pytest.raises(ValidationError):
        Settings()
@@ -0,0 +1,196 @@
from __future__ import annotations
from decimal import Decimal
from uuid import uuid4
import pytest
from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.agent_chat.orchestrator import OrchestratorResult
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from v1.agent_chat.schemas import AgentChatRunRequest
from v1.agent_chat.service import AgentChatService
@pytest.fixture
async def db_engine():
    """Yield an in-memory SQLite async engine with all model tables created.

    A minimal ``auth.users`` table is registered on the shared
    ``Base.metadata`` for the lifetime of the fixture and an in-memory
    database is attached as ``auth`` so the schema-qualified table can be
    created.

    Setup and teardown share one ``try``/``finally`` so that a failure while
    creating the schema still removes the temporary table from the shared
    metadata and disposes the engine.
    """
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        async with engine.begin() as conn:
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Undo the global metadata change before releasing the engine.
        Base.metadata.remove(users_table)
        await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine``; roll back afterwards."""
    session_factory = async_sessionmaker(
        bind=db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as active_session:
        yield active_session
        # Discard whatever the test left uncommitted.
        await active_session.rollback()
@pytest.mark.asyncio
async def test_run_creates_session_and_persists_messages(
    db_session: AsyncSession,
) -> None:
    """A run without a session id creates a session and stores both messages."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    run_result = await chat_service.run(AgentChatRunRequest(message="hello"))

    assert run_result.session_id is not None
    assert run_result.output == "hello"
    emitted_types = [event.type for event in run_result.events]
    assert emitted_types == ["run.started", "message.delta", "run.completed"]

    stored_session = await db_session.get(AgentChatSession, run_result.session_id)
    assert stored_session is not None
    assert stored_session.message_count == 2
    assert stored_session.status.value == "completed"

    # Read the persisted messages back in sequence order.
    query = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == run_result.session_id)
        .order_by(AgentChatMessage.seq.asc())
    )
    stored_messages = (await db_session.execute(query)).scalars().all()
    assert len(stored_messages) == 2
    assert stored_messages[0].role.value == "user"
    assert stored_messages[1].role.value == "assistant"
@pytest.mark.asyncio
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
    """A second run with the same session id reuses it; the count reaches 4."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    opening_run = await chat_service.run(AgentChatRunRequest(message="first"))
    follow_up_request = AgentChatRunRequest(
        message="second", session_id=opening_run.session_id
    )
    follow_up_run = await chat_service.run(follow_up_request)

    assert follow_up_run.session_id == opening_run.session_id
    stored_session = await db_session.get(AgentChatSession, opening_run.session_id)
    assert stored_session is not None
    assert stored_session.message_count == 4
@pytest.mark.asyncio
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
    db_session: AsyncSession,
) -> None:
    """A result flagged ``failed`` becomes HTTP 502 and a ``failed`` session."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)

    class _FailingOrchestrator:
        """Stub that returns a failed result instead of raising."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            zero_usage = {
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
                "cost": Decimal("0"),
                "currency": "USD",
            }
            return OrchestratorResult(
                output="",
                usage=zero_usage,
                events=[],
                context={},
                failed=True,
                error="stage failed",
            )

    service._orchestrator = _FailingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as raised:
        await service.run(AgentChatRunRequest(message="hello"))
    assert raised.value.status_code == 502

    # The session created for this user must have been persisted as failed.
    query = select(AgentChatSession).where(AgentChatSession.user_id == user.id)
    stored_session = (await db_session.execute(query)).scalars().one()
    assert stored_session.status.value == "failed"
@pytest.mark.asyncio
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
    """A whitespace-only message is rejected with HTTP 422."""
    user = CurrentUser(id=uuid4())
    service = AgentChatService(session=db_session, current_user=user)
    # The rejection surfaces as an HTTPException raised by service.run.
    with pytest.raises(HTTPException) as exc_info:
        await service.run(AgentChatRunRequest(message=" "))
    assert exc_info.value.status_code == 422
@pytest.mark.asyncio
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
    """Running against a session id that was never created yields HTTP 404."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    with pytest.raises(HTTPException) as raised:
        # A fresh random UUID cannot match any stored session.
        await chat_service.run(
            AgentChatRunRequest(message="hello", session_id=uuid4())
        )

    assert raised.value.status_code == 404
@pytest.mark.asyncio
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
    db_session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``SQLAlchemyError`` raised by ``commit`` is reported as HTTP 503."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    async def _raise_on_commit() -> None:
        raise SQLAlchemyError("db down")

    # Every commit on this session now fails.
    monkeypatch.setattr(db_session, "commit", _raise_on_commit)

    with pytest.raises(HTTPException) as raised:
        await chat_service.run(AgentChatRunRequest(message="hello"))

    assert raised.value.status_code == 503
@pytest.mark.asyncio
async def test_run_returns_502_for_unexpected_exception(
    db_session: AsyncSession,
) -> None:
    """An arbitrary exception raised by the orchestrator becomes HTTP 502."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    class _CrashingOrchestrator:
        """Stub whose ``run`` raises instead of returning a result."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            raise RuntimeError("unexpected")

    chat_service._orchestrator = _CrashingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as raised:
        await chat_service.run(AgentChatRunRequest(message="hello"))

    assert raised.value.status_code == 502