feat(agent-chat): complete core workflow and strengthen auth rate limiting

This commit is contained in:
qzl
2026-02-25 16:51:12 +08:00
parent 53c72e48e6
commit cd40b2b4f4
62 changed files with 3441 additions and 3 deletions
+9
View File
@@ -134,3 +134,12 @@ SOCIAL_SUPABASE__FUNCTIONS_VERIFY_JWT=false
SOCIAL_SUPABASE__IMGPROXY_ENABLE_WEBP_DETECTION=true
SOCIAL_SUPABASE__STORAGE_BUCKET_PUBLIC=public
SOCIAL_SUPABASE__STORAGE_BUCKET_PRIVATE=private
############
# Agent Chat 附件存储配置(仅基础设施变量)
############
SOCIAL_STORAGE__PROVIDER=supabase
SOCIAL_STORAGE__BUCKET=agent-chat-attachments
SOCIAL_STORAGE__SIGNED_URL_TTL_SECONDS=600
SOCIAL_STORAGE__MAX_FILE_SIZE_MB=20
SOCIAL_STORAGE__RETENTION_DAYS=30
+7 -1
View File
@@ -17,7 +17,13 @@ if str(src_path) not in sys.path:
from core.config.settings import config # noqa: E402
from core.db.base import Base # noqa: E402
from models import Profile # noqa: F401,E402
from models import ( # noqa: F401,E402
AgentChatMessage,
AgentChatSession,
Llm,
LlmFactory,
Profile,
)
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
@@ -0,0 +1,195 @@
"""create_agent_chat_core_tables
Revision ID: 20260226_agent_chat_core
Revises: 20260224_drop_profile
Create Date: 2026-02-26 10:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "20260226_agent_chat_core"
down_revision: Union[str, Sequence[str], None] = "20260224_drop_profile"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"llm_factory",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(length=50), nullable=False),
sa.Column("request_url", sa.String(length=255), nullable=False),
sa.Column("avatar", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id", name="pk_llm_factory"),
sa.UniqueConstraint("name", name="uq_llm_factory_name"),
)
op.create_index("ix_llm_factory_name", "llm_factory", ["name"])
op.create_index("ix_llm_factory_deleted_at", "llm_factory", ["deleted_at"])
op.create_table(
"llms",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("factory_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("model_code", sa.String(length=50), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["factory_id"], ["llm_factory.id"], ondelete="RESTRICT"
),
sa.PrimaryKeyConstraint("id", name="pk_llms"),
sa.UniqueConstraint("model_code", name="uq_llms_model_code"),
)
op.create_index("ix_llms_factory_id", "llms", ["factory_id"])
op.create_index("ix_llms_model_code", "llms", ["model_code"])
op.create_index("ix_llms_deleted_at", "llms", ["deleted_at"])
op.create_table(
"sessions",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("title", sa.String(length=255), nullable=True),
sa.Column(
"status",
sa.Enum(
"pending",
"running",
"completed",
"failed",
name="agent_chat_session_status",
native_enum=False,
),
nullable=False,
server_default="pending",
),
sa.Column(
"last_activity_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("message_count", sa.Integer(), nullable=False, server_default="0"),
sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"),
sa.Column("total_cost", sa.Numeric(12, 6), nullable=False, server_default="0"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["auth.users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id", name="pk_sessions"),
)
op.create_index("ix_sessions_user_id", "sessions", ["user_id"])
op.create_index(
"ix_sessions_user_last_activity",
"sessions",
["user_id", "last_activity_at"],
)
op.create_index("ix_sessions_deleted_at", "sessions", ["deleted_at"])
op.create_table(
"messages",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("seq", sa.Integer(), nullable=False),
sa.Column(
"role",
sa.Enum(
"user",
"assistant",
"system",
"tool",
name="agent_chat_message_role",
native_enum=False,
),
nullable=False,
),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("model_code", sa.String(length=50), nullable=True),
sa.Column("tool_name", sa.String(length=100), nullable=True),
sa.Column("input_tokens", sa.Integer(), nullable=False, server_default="0"),
sa.Column("output_tokens", sa.Integer(), nullable=False, server_default="0"),
sa.Column("cost", sa.Numeric(12, 6), nullable=False, server_default="0"),
sa.Column(
"currency", sa.String(length=3), nullable=False, server_default="USD"
),
sa.Column("latency_ms", sa.Integer(), nullable=True),
sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["session_id"], ["sessions.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id", name="pk_messages"),
sa.UniqueConstraint("session_id", "seq", name="uq_messages_session_seq"),
)
op.create_index("ix_messages_session_id", "messages", ["session_id"])
op.create_index("ix_messages_session_role", "messages", ["session_id", "role"])
op.create_index("ix_messages_deleted_at", "messages", ["deleted_at"])
def downgrade() -> None:
op.drop_index("ix_messages_deleted_at", table_name="messages")
op.drop_index("ix_messages_session_role", table_name="messages")
op.drop_index("ix_messages_session_id", table_name="messages")
op.drop_table("messages")
op.drop_index("ix_sessions_deleted_at", table_name="sessions")
op.drop_index("ix_sessions_user_last_activity", table_name="sessions")
op.drop_index("ix_sessions_user_id", table_name="sessions")
op.drop_table("sessions")
op.drop_index("ix_llms_deleted_at", table_name="llms")
op.drop_index("ix_llms_model_code", table_name="llms")
op.drop_index("ix_llms_factory_id", table_name="llms")
op.drop_table("llms")
op.drop_index("ix_llm_factory_deleted_at", table_name="llm_factory")
op.drop_index("ix_llm_factory_name", table_name="llm_factory")
op.drop_table("llm_factory")
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,20 @@
from __future__ import annotations
from typing import Any
from core.agent_chat.event_bridge import map_internal_event
class AguiAdapter:
def to_command(self, payload: dict[str, Any]) -> dict[str, Any]:
message = payload.get("message")
if not isinstance(message, str) or not message.strip():
raise ValueError("message is required")
return {
"message": message,
"session_id": payload.get("session_id"),
}
def to_protocol_event(self, event: dict[str, Any]) -> dict[str, Any]:
return map_internal_event(event)
@@ -0,0 +1,69 @@
from __future__ import annotations
from decimal import Decimal
from typing import Any, Mapping
def _to_non_negative_int(value: Any, *, field: str) -> int:
if isinstance(value, bool):
raise ValueError(f"{field} must be an integer")
if isinstance(value, int):
converted = value
elif isinstance(value, str) and value.isdigit():
converted = int(value)
else:
raise ValueError(f"{field} must be an integer")
if converted < 0:
raise ValueError(f"{field} cannot be negative")
return converted
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
converted = Decimal(str(value))
if converted < 0:
raise ValueError(f"{field} cannot be negative")
return converted
class CostTracker:
def __init__(self, *, currency: str = "USD") -> None:
self._input_tokens = 0
self._output_tokens = 0
self._total_tokens = 0
self._cost = Decimal("0")
self._currency = currency
def add_usage(self, usage: Mapping[str, Any]) -> None:
input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
total_tokens = usage.get("total_tokens")
cost = usage.get("cost", "0")
currency = usage.get("currency")
normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
normalized_total = (
_to_non_negative_int(total_tokens, field="total_tokens")
if total_tokens is not None
else normalized_input + normalized_output
)
normalized_cost = _to_non_negative_decimal(cost, field="cost")
self._input_tokens += normalized_input
self._output_tokens += normalized_output
self._total_tokens += normalized_total
self._cost += normalized_cost
if currency is not None:
normalized_currency = str(currency)
if normalized_currency != self._currency:
raise ValueError("currency mismatch")
def snapshot(self) -> dict[str, Any]:
return {
"input_tokens": self._input_tokens,
"output_tokens": self._output_tokens,
"total_tokens": self._total_tokens,
"cost": self._cost,
"currency": self._currency,
}
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,82 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class CrewAITemplate:
agents: dict[str, Any]
tasks: dict[str, Any]
workflow: dict[str, Any]
prompts: dict[str, str]
tools_whitelist: set[str]
def _default_static_root() -> Path:
return Path(__file__).resolve().parents[3] / "config" / "static" / "agent_chat"
def _read_yaml(file_path: Path) -> dict[str, Any]:
if not file_path.is_file():
raise FileNotFoundError(f"Required config file not found: {file_path}")
with file_path.open("r", encoding="utf-8") as file:
loaded = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"YAML file must be a mapping: {file_path}")
return loaded
def _read_prompt(file_path: Path) -> str:
if not file_path.is_file():
raise FileNotFoundError(f"Required prompt file not found: {file_path}")
return file_path.read_text(encoding="utf-8").strip()
def validate_workflow_stages(stages: list[str]) -> None:
expected = ["intent", "execution", "organization"]
if stages != expected:
raise ValueError(f"Invalid workflow stages: {stages}, expected: {expected}")
def load_tools_whitelist(static_root: Path | None = None) -> set[str]:
root = static_root or _default_static_root()
tools = _read_yaml(root / "tools.yaml")
raw_tools = tools.get("tools", [])
if not isinstance(raw_tools, list):
raise ValueError("tools.yaml field 'tools' must be a list")
if not all(isinstance(item, str) and item.strip() for item in raw_tools):
raise ValueError("tools.yaml list items must be non-empty strings")
whitelist = {item.strip() for item in raw_tools}
return whitelist
def load_crewai_template(static_root: Path | None = None) -> CrewAITemplate:
root = static_root or _default_static_root()
crewai_root = root / "crewai"
agents = _read_yaml(crewai_root / "agents.yaml")
tasks = _read_yaml(crewai_root / "tasks.yaml")
workflow = _read_yaml(crewai_root / "workflow.yaml")
stages = workflow.get("stages")
if not isinstance(stages, list):
raise ValueError("workflow.yaml field 'stages' must be a list")
validate_workflow_stages([str(stage) for stage in stages])
prompts = {
"intent": _read_prompt(crewai_root / "prompts" / "intent.md"),
"execution": _read_prompt(crewai_root / "prompts" / "execution.md"),
"organization": _read_prompt(crewai_root / "prompts" / "organization.md"),
}
return CrewAITemplate(
agents=agents,
tasks=tasks,
workflow=workflow,
prompts=prompts,
tools_whitelist=load_tools_whitelist(root),
)
@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Any
def _require_fields(event: dict[str, Any], *, kind: str, required: list[str]) -> None:
missing = [field for field in required if field not in event]
if missing:
raise ValueError(f"Missing fields for {kind}: {missing}")
def map_internal_event(event: dict[str, Any]) -> dict[str, Any]:
kind = event.get("kind")
if kind == "run_started":
_require_fields(event, kind=kind, required=["session_id"])
return {
"type": "run.started",
"run_id": event["session_id"],
}
if kind == "message_delta":
_require_fields(event, kind=kind, required=["message_id", "delta"])
return {
"type": "message.delta",
"message_id": event["message_id"],
"delta": event["delta"],
}
if kind == "tool_started":
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
return {
"type": "tool.started",
"message_id": event["message_id"],
"tool_name": event["tool_name"],
}
if kind == "tool_completed":
_require_fields(event, kind=kind, required=["message_id", "tool_name"])
return {
"type": "tool.completed",
"message_id": event["message_id"],
"tool_name": event["tool_name"],
"result": event.get("result"),
}
if kind == "run_completed":
_require_fields(event, kind=kind, required=["session_id"])
return {
"type": "run.completed",
"run_id": event["session_id"],
"output": event.get("output", ""),
}
if kind == "run_failed":
_require_fields(event, kind=kind, required=["session_id"])
return {
"type": "run.failed",
"run_id": event["session_id"],
"error": event.get("error", ""),
}
raise ValueError(f"Unsupported event kind: {kind}")
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Any
def run_started(*, run_id: str) -> dict[str, Any]:
return {"type": "run.started", "run_id": run_id}
def stage_completed(
*, run_id: str, stage: str, usage: dict[str, Any] | None = None
) -> dict[str, Any]:
event: dict[str, Any] = {
"type": "stage.completed",
"run_id": run_id,
"stage": stage,
}
if usage is not None:
event["usage"] = usage
return event
def run_completed(*, run_id: str, output: str, usage: dict[str, Any]) -> dict[str, Any]:
return {
"type": "run.completed",
"run_id": run_id,
"output": output,
"usage": usage,
}
def run_failed(*, run_id: str, error: str) -> dict[str, Any]:
return {
"type": "run.failed",
"run_id": run_id,
"error": error,
}
+112
View File
@@ -0,0 +1,112 @@
from __future__ import annotations
from dataclasses import dataclass
import hashlib
from pathlib import Path
from typing import Protocol
from core.agent_chat.storage_adapter import StorageAdapter
_ALLOWED_MIME_TYPES = {
"audio/mpeg",
"audio/wav",
"audio/x-wav",
"image/jpeg",
"image/png",
"image/webp",
"application/pdf",
"text/plain",
}
class _AsrTool(Protocol):
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, object]: ...
@dataclass(frozen=True)
class AttachmentInput:
filename: str
mime_type: str
content: bytes
origin: str = "user_upload"
@dataclass(frozen=True)
class ProcessedAttachmentContext:
attachments: list[dict[str, object]]
preview_texts: list[str]
class MultimodalProcessor:
def __init__(
self,
*,
storage: StorageAdapter,
asr_tool: _AsrTool,
max_file_size_mb: int = 20,
) -> None:
self._storage = storage
self._asr_tool = asr_tool
self._max_size_bytes = max_file_size_mb * 1024 * 1024
def process(
self,
*,
user_id: str,
session_id: str,
message_seq: int,
attachments: list[AttachmentInput],
) -> ProcessedAttachmentContext:
metadata_list: list[dict[str, object]] = []
preview_texts: list[str] = []
for attachment in attachments:
self._validate_attachment(attachment)
checksum = hashlib.sha256(attachment.content).hexdigest()
extension = Path(attachment.filename).suffix.strip(".").lower() or "bin"
object_path = self._storage.build_object_path(
user_id=user_id,
session_id=session_id,
message_seq=message_seq,
checksum_sha256=checksum,
extension=extension,
)
preview_text = self._build_preview_text(attachment)
if preview_text:
preview_texts.append(preview_text)
metadata = self._storage.build_attachment_metadata(
object_path=object_path,
mime_type=attachment.mime_type,
size=len(attachment.content),
checksum_sha256=checksum,
origin=attachment.origin,
preview_text=preview_text,
)
metadata_list.append(metadata)
return ProcessedAttachmentContext(
attachments=metadata_list,
preview_texts=preview_texts,
)
def _validate_attachment(self, attachment: AttachmentInput) -> None:
if attachment.mime_type not in _ALLOWED_MIME_TYPES:
raise ValueError("Unsupported MIME type")
if len(attachment.content) > self._max_size_bytes:
raise ValueError("Attachment exceeds max file size")
def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
if attachment.mime_type.startswith("audio/"):
result = self._asr_tool.transcribe(
audio_bytes=attachment.content,
filename=attachment.filename,
)
text = result.get("text")
if isinstance(text, str):
return text
return None
if attachment.mime_type == "text/plain":
return attachment.content.decode("utf-8", errors="ignore")[:200]
return None
@@ -0,0 +1,88 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
from core.agent_chat.cost_tracker import CostTracker
from core.agent_chat import events
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
@dataclass(frozen=True)
class OrchestratorResult:
output: str
usage: dict[str, Any]
events: list[dict[str, Any]]
context: dict[str, Any]
failed: bool
error: str | None
class AgentChatOrchestrator:
def __init__(
self,
*,
intent_stage: StageCallable,
execution_stage: StageCallable,
organization_stage: StageCallable,
) -> None:
self._intent_stage = intent_stage
self._execution_stage = execution_stage
self._organization_stage = organization_stage
def run_sync(self, *, run_id: str, user_message: str) -> OrchestratorResult:
return asyncio.run(self.run(run_id=run_id, user_message=user_message))
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
tracker = CostTracker()
emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
context: dict[str, Any] = {}
stage_pipeline: list[tuple[str, StageCallable]] = [
("intent", self._intent_stage),
("execution", self._execution_stage),
("organization", self._organization_stage),
]
stage_output = user_message
try:
for stage_name, stage_callable in stage_pipeline:
stage_result = await stage_callable(
message=stage_output, context=context
)
stage_output = str(stage_result.get("content", stage_output))
usage = stage_result.get("usage", {})
if isinstance(usage, dict):
tracker.add_usage(usage)
emitted_events.append(
events.stage_completed(
run_id=run_id,
stage=stage_name,
usage=tracker.snapshot(),
)
)
except Exception as exc: # noqa: BLE001
emitted_events.append(events.run_failed(run_id=run_id, error=str(exc)))
return OrchestratorResult(
output="",
usage=tracker.snapshot(),
events=emitted_events,
context=context,
failed=True,
error=str(exc),
)
summary = tracker.snapshot()
emitted_events.append(
events.run_completed(run_id=run_id, output=stage_output, usage=summary)
)
return OrchestratorResult(
output=stage_output,
usage=summary,
events=emitted_events,
context=context,
failed=False,
error=None,
)
@@ -0,0 +1,46 @@
from __future__ import annotations
class StorageAdapter:
_bucket: str
def __init__(self, bucket: str) -> None:
self._bucket = bucket
@property
def bucket(self) -> str:
return self._bucket
def build_object_path(
self,
*,
user_id: str,
session_id: str,
message_seq: int,
checksum_sha256: str,
extension: str,
) -> str:
normalized_ext = extension.strip(".").lower()
return (
f"agent-chat/{user_id}/{session_id}/{message_seq}/"
f"{checksum_sha256}.{normalized_ext}"
)
def build_attachment_metadata(
self,
*,
object_path: str,
mime_type: str,
size: int,
checksum_sha256: str,
origin: str,
preview_text: str | None = None,
) -> dict[str, object]:
return {
"object_path": object_path,
"mime_type": mime_type,
"size": size,
"checksum_sha256": checksum_sha256,
"origin": origin,
"preview_text": preview_text,
}
@@ -0,0 +1 @@
from __future__ import annotations
@@ -0,0 +1,40 @@
from __future__ import annotations
import importlib
from collections.abc import Callable
from typing import Any
TranscribeCallable = Callable[..., dict[str, Any]]
class FunASRTool:
_transcribe_callable: TranscribeCallable
_model: str
def __init__(
self,
transcribe_callable: TranscribeCallable | None = None,
model: str = "fun-asr-realtime-2025-11-07",
) -> None:
self._transcribe_callable = transcribe_callable or self._dashscope_transcribe
self._model = model
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, Any]:
payload = self._transcribe_callable(audio_bytes=audio_bytes, filename=filename)
return {
"model": self._model,
**payload,
}
def _dashscope_transcribe(
self, *, audio_bytes: bytes, filename: str
) -> dict[str, Any]:
try:
importlib.import_module("dashscope")
except ImportError as exc:
raise RuntimeError("DashScope SDK is not installed") from exc
raise RuntimeError(
"DashScope transcribe runtime integration is not configured yet"
)
+9
View File
@@ -132,6 +132,14 @@ class SupabaseSettings(BaseModel):
return self.public_url
class StorageSettings(BaseModel):
provider: Literal["supabase"] = "supabase"
bucket: str = Field(default="agent-chat-attachments", min_length=3, max_length=63)
signed_url_ttl_seconds: int = Field(default=600, ge=60, le=3600)
max_file_size_mb: int = Field(default=20, ge=1, le=200)
retention_days: int = Field(default=30, ge=1, le=3650)
class DatabaseSettings(BaseModel):
host: str = "localhost"
port: int = 5432
@@ -163,6 +171,7 @@ class Settings(BaseSettings):
cors: CorsSettings = CorsSettings()
redis: RedisSettings = RedisSettings()
supabase: SupabaseSettings = SupabaseSettings()
storage: StorageSettings = StorageSettings()
celery: CelerySettings = CelerySettings()
database: DatabaseSettings = DatabaseSettings()
@@ -0,0 +1,9 @@
intent:
role: Intent Agent
goal: Classify user intent and decide execution strategy
execution:
role: Execution Agent
goal: Execute tasks with available tools
organization:
role: Organization Agent
goal: Organize output for user-friendly response
@@ -0,0 +1,2 @@
你是任务执行代理。
基于输入意图与上下文调用可用工具,并生成可验证执行结果。
@@ -0,0 +1,2 @@
你是意图识别代理。
你的任务是识别用户输入的意图类型,并返回结构化意图标签。
@@ -0,0 +1,2 @@
你是结果整理代理。
将执行结果组织为面向用户的清晰回复,保留关键信息与必要引用。
@@ -0,0 +1,6 @@
intent:
description: Identify user intent and required capabilities
execution:
description: Execute intent with tools and model calls
organization:
description: Format final response and references
@@ -0,0 +1,9 @@
stages:
- intent
- execution
- organization
timeouts:
intent_seconds: 8
execution_seconds: 30
organization_seconds: 10
@@ -0,0 +1,25 @@
factories:
- name: qwen
request_url: https://dashscope.aliyuncs.com/compatible-mode/v1
avatar: https://cdn.simpleicons.org/alibabacloud/FF6A00
- name: minimax
request_url: https://api.minimax.chat/v1
avatar: https://cdn.simpleicons.org/minimax/1A1A1A
- name: kimi
request_url: https://api.moonshot.cn/v1
avatar: https://cdn.simpleicons.org/moonrepo/3B82F6
- name: deepseek
request_url: https://api.deepseek.com/v1
avatar: https://cdn.simpleicons.org/deepseek/4D6BFE
- name: doubao
request_url: https://ark.cn-beijing.volces.com/api/v3
avatar: https://cdn.simpleicons.org/volkswagen/001E50
- name: zai
request_url: https://api.z.ai/v1
avatar: https://cdn.simpleicons.org/zotero/CC2936
llms:
- model_code: qwen3.5-flash
factory_id: qwen
- model_code: deepseek-v3.2
factory_id: deepseek
@@ -0,0 +1,3 @@
tools:
- asr_fun_asr
- attachment_extract
+126 -1
View File
@@ -1,9 +1,134 @@
from __future__ import annotations
import uuid
from pathlib import Path
from typing import Any
import yaml
from pydantic import BaseModel, ValidationError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from models.llm import Llm
from models.llm_factory import LlmFactory
logger = get_logger("core.initialization.init_data")
class LlmFactorySeed(BaseModel):
name: str
request_url: str
avatar: str | None = None
class LlmSeed(BaseModel):
model_code: str
factory_id: str
class LlmCatalogSeed(BaseModel):
factories: list[LlmFactorySeed]
llms: list[LlmSeed]
def _default_catalog_path() -> Path:
return (
Path(__file__).resolve().parents[1]
/ "config"
/ "static"
/ "agent_chat"
/ "llm_catalog.yaml"
)
def load_llm_catalog(catalog_path: Path | None = None) -> dict[str, Any]:
path = catalog_path or _default_catalog_path()
with path.open("r", encoding="utf-8") as file:
loaded = yaml.safe_load(file) or {}
if not isinstance(loaded, dict):
raise ValueError(f"Invalid LLM catalog format: {path}")
raw_factories = loaded.get("factories", [])
raw_llms = loaded.get("llms", [])
if not isinstance(raw_factories, list) or not isinstance(raw_llms, list):
raise ValueError(f"Invalid LLM catalog sections: {path}")
try:
parsed = LlmCatalogSeed.model_validate(
{
"factories": list(raw_factories),
"llms": list(raw_llms),
}
)
except ValidationError as exc:
raise ValueError(f"Invalid LLM catalog data: {path}") from exc
return parsed.model_dump()
async def _upsert_factory(
session: AsyncSession,
*,
name: str,
request_url: str,
avatar: str | None,
) -> uuid.UUID:
result = await session.execute(select(LlmFactory).where(LlmFactory.name == name))
factory = result.scalar_one_or_none()
if factory is None:
factory = LlmFactory(name=name, request_url=request_url, avatar=avatar)
session.add(factory)
await session.flush()
else:
factory.request_url = request_url
factory.avatar = avatar
return factory.id
async def _upsert_llm(
session: AsyncSession,
*,
model_code: str,
factory_id: uuid.UUID,
) -> None:
result = await session.execute(select(Llm).where(Llm.model_code == model_code))
llm = result.scalar_one_or_none()
if llm is None:
session.add(Llm(model_code=model_code, factory_id=factory_id))
return
llm.factory_id = factory_id
async def initialize_data() -> bool:
"""Initialize bootstrap data."""
logger.info("Initializing data (no-op)")
catalog = load_llm_catalog()
async with AsyncSessionLocal() as session:
async with session.begin():
factory_id_by_name: dict[str, uuid.UUID] = {}
for factory in catalog["factories"]:
factory_id = await _upsert_factory(
session,
name=factory["name"],
request_url=factory["request_url"],
avatar=factory.get("avatar"),
)
factory_id_by_name[factory["name"]] = factory_id
for llm in catalog["llms"]:
factory_name = llm["factory_id"]
resolved_factory_id = factory_id_by_name.get(factory_name)
if resolved_factory_id is None:
raise RuntimeError(
f"Factory '{factory_name}' not found for model '{llm['model_code']}'"
)
await _upsert_llm(
session,
model_code=llm["model_code"],
factory_id=resolved_factory_id,
)
logger.info("Initialized LLM factory/model seed data")
return True
+11 -1
View File
@@ -1,5 +1,15 @@
from __future__ import annotations
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from models.llm import Llm
from models.llm_factory import LlmFactory
from models.profile import Profile
__all__ = ["Profile"]
__all__ = [
"AgentChatMessage",
"AgentChatSession",
"Llm",
"LlmFactory",
"Profile",
]
+62
View File
@@ -0,0 +1,62 @@
from __future__ import annotations
from decimal import Decimal
import uuid
from enum import Enum
from sqlalchemy import (
JSON,
Enum as SqlEnum,
ForeignKey,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class AgentChatMessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
TOOL = "tool"
class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
__tablename__: str = "messages"
__table_args__: tuple[UniqueConstraint] = (
UniqueConstraint("session_id", "seq", name="uq_messages_session_seq"),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
session_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("sessions.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
seq: Mapped[int] = mapped_column(Integer, nullable=False)
role: Mapped[AgentChatMessageRole] = mapped_column(
SqlEnum(
AgentChatMessageRole, name="agent_chat_message_role", native_enum=False
),
nullable=False,
)
content: Mapped[str] = mapped_column(Text, nullable=False)
model_code: Mapped[str | None] = mapped_column(String(50), nullable=True)
tool_name: Mapped[str | None] = mapped_column(String(100), nullable=True)
input_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
cost: Mapped[Decimal] = mapped_column(Numeric(12, 6), nullable=False, default=0)
currency: Mapped[str] = mapped_column(String(3), nullable=False, default="USD")
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
metadata_json: Mapped[dict[str, object] | None] = mapped_column(
"metadata", JSON().with_variant(JSONB, "postgresql"), nullable=True
)
+64
View File
@@ -0,0 +1,64 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
import uuid
from enum import Enum
from sqlalchemy import (
DateTime,
Enum as SqlEnum,
ForeignKey,
Integer,
Numeric,
String,
func,
text,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class AgentChatSessionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
__tablename__: str = "sessions"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("auth.users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
status: Mapped[AgentChatSessionStatus] = mapped_column(
SqlEnum(
AgentChatSessionStatus, name="agent_chat_session_status", native_enum=False
),
nullable=False,
default=AgentChatSessionStatus.PENDING,
)
last_activity_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
message_count: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("0")
)
total_tokens: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("0")
)
total_cost: Mapped[Decimal] = mapped_column(
Numeric(12, 6), nullable=False, server_default=text("0")
)
+26
View File
@@ -0,0 +1,26 @@
from __future__ import annotations
import uuid
from sqlalchemy import ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class Llm(TimestampMixin, SoftDeleteMixin, Base):
__tablename__: str = "llms"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
factory_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("llm_factory.id", ondelete="RESTRICT"),
nullable=False,
index=True,
)
model_code: Mapped[str] = mapped_column(
String(50), nullable=False, unique=True, index=True
)
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
import uuid
from sqlalchemy import String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
class LlmFactory(TimestampMixin, SoftDeleteMixin, Base):
__tablename__: str = "llm_factory"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(
String(50), nullable=False, unique=True, index=True
)
request_url: Mapped[str] = mapped_column(String(255), nullable=False)
avatar: Mapped[str | None] = mapped_column(Text, nullable=True)
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+18
View File
@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from v1.agent_chat.service import AgentChatService
from v1.profile.dependencies import get_current_user
def get_agent_chat_service(
session: Annotated[AsyncSession, Depends(get_db)],
user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AgentChatService:
return AgentChatService(session=session, current_user=user)
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends
from v1.agent_chat.dependencies import get_agent_chat_service
from v1.agent_chat.schemas import AgentChatRunRequest, AgentChatRunResponse
from v1.agent_chat.service import AgentChatService
router = APIRouter(prefix="/agent-chat", tags=["agent-chat"])
@router.post("/run", response_model=AgentChatRunResponse)
async def run_agent_chat(
payload: AgentChatRunRequest,
service: Annotated[AgentChatService, Depends(get_agent_chat_service)],
) -> AgentChatRunResponse:
return await service.run(payload)
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, Field
class AgentChatRunRequest(BaseModel):
message: str = Field(min_length=1, max_length=8000)
session_id: UUID | None = None
class AgentChatEvent(BaseModel):
type: str
run_id: str | None = None
message_id: str | None = None
delta: str | None = None
tool_name: str | None = None
result: str | None = None
output: str | None = None
error: str | None = None
class AgentChatRunResponse(BaseModel):
session_id: UUID
output: str
events: list[AgentChatEvent]
+286
View File
@@ -0,0 +1,286 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import func, select
from sqlalchemy.exc import SQLAlchemyError
from core.agent_chat.agui_adapter import AguiAdapter
from core.agent_chat.orchestrator import AgentChatOrchestrator
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent_chat.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.agent_chat.service")
def build_session_title(first_message: str, *, now: datetime) -> str:
title = first_message.strip().replace("\n", " ")[:24]
if not title:
return now.strftime("新对话 %Y-%m-%d %H:%M")
return title
def aggregate_session_cost(costs: list[Decimal]) -> Decimal:
total = Decimal("0")
for cost in costs:
if cost < 0:
raise ValueError("cost must be non-negative")
total += cost
return total
def select_recent_session(
sessions: list[AgentChatSession],
) -> AgentChatSession | None:
if not sessions:
return None
return max(sessions, key=lambda item: item.last_activity_at)
class AgentChatService(BaseService):
_session: AsyncSession
def __init__(self, session: AsyncSession, current_user: CurrentUser | None) -> None:
super().__init__(current_user=current_user)
self._session = session
self._adapter = AguiAdapter()
self._orchestrator = AgentChatOrchestrator(
intent_stage=self._intent_stage,
execution_stage=self._execution_stage,
organization_stage=self._organization_stage,
)
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
try:
command = self._adapter.to_command(payload.model_dump(mode="python"))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
user_id = self.require_user_id()
await enforce_rate_limit(
scope="agent_chat_run",
identifier=str(user_id),
limit=60,
window_seconds=60,
)
now = datetime.now(timezone.utc)
try:
chat_session = await self._resolve_session(
session_id=command["session_id"],
user_id=user_id,
first_message=command["message"],
now=now,
)
base_seq = await self._next_seq_base(chat_session.id)
user_message = AgentChatMessage(
session_id=chat_session.id,
seq=base_seq + 1,
role=AgentChatMessageRole.USER,
content=command["message"],
cost=Decimal("0"),
)
orchestrator_result = await self._orchestrator.run(
run_id=str(chat_session.id),
user_message=command["message"],
)
assistant_message = AgentChatMessage(
session_id=chat_session.id,
seq=base_seq + 2,
role=AgentChatMessageRole.ASSISTANT,
content=orchestrator_result.output,
input_tokens=int(orchestrator_result.usage["input_tokens"]),
output_tokens=int(orchestrator_result.usage["output_tokens"]),
cost=Decimal(orchestrator_result.usage["cost"]),
)
self._session.add(user_message)
self._session.add(assistant_message)
chat_session.status = (
AgentChatSessionStatus.FAILED
if orchestrator_result.failed
else AgentChatSessionStatus.COMPLETED
)
chat_session.last_activity_at = now
chat_session.message_count = chat_session.message_count + 2
chat_session.total_tokens = chat_session.total_tokens + int(
orchestrator_result.usage["total_tokens"]
)
chat_session.total_cost = aggregate_session_cost(
[
Decimal(chat_session.total_cost),
Decimal(orchestrator_result.usage["cost"]),
]
)
await self._session.commit()
await self._session.refresh(chat_session)
await self._session.refresh(user_message)
mapped_events = self._build_mapped_events(
session_id=str(chat_session.id),
message_id=str(user_message.id),
user_message=command["message"],
assistant_output=assistant_message.content,
failed=orchestrator_result.failed,
error=orchestrator_result.error,
)
events = [AgentChatEvent.model_validate(item) for item in mapped_events]
if orchestrator_result.failed:
raise HTTPException(
status_code=502, detail="Agent orchestration failed"
)
return AgentChatRunResponse(
session_id=chat_session.id,
output=assistant_message.content,
events=events,
)
except HTTPException:
await self._session.rollback()
raise
except SQLAlchemyError:
await self._session.rollback()
logger.exception("Agent chat run failed")
raise HTTPException(status_code=503, detail="Agent chat store unavailable")
except Exception as exc: # noqa: BLE001
await self._session.rollback()
logger.exception(
"Agent chat unexpected failure", error_type=type(exc).__name__
)
raise HTTPException(
status_code=502, detail="Agent orchestration failed"
) from exc
def _build_mapped_events(
self,
*,
session_id: str,
message_id: str,
user_message: str,
assistant_output: str,
failed: bool,
error: str | None,
) -> list[dict[str, object]]:
mapped_events = [
self._adapter.to_protocol_event(
{
"kind": "run_started",
"session_id": session_id,
}
),
self._adapter.to_protocol_event(
{
"kind": "message_delta",
"message_id": message_id,
"delta": user_message,
}
),
]
if failed:
mapped_events.append(
self._adapter.to_protocol_event(
{
"kind": "run_failed",
"session_id": session_id,
"error": error or "orchestration failed",
}
)
)
return mapped_events
mapped_events.append(
self._adapter.to_protocol_event(
{
"kind": "run_completed",
"session_id": session_id,
"output": assistant_output,
}
)
)
return mapped_events
async def _resolve_session(
self,
*,
session_id: object | None,
user_id: UUID,
first_message: str,
now: datetime,
) -> AgentChatSession:
if session_id is not None:
stmt = (
select(AgentChatSession)
.where(AgentChatSession.id == session_id)
.where(AgentChatSession.user_id == user_id)
.where(AgentChatSession.deleted_at.is_(None))
.with_for_update()
.limit(1)
)
result = await self._session.execute(stmt)
existing = result.scalar_one_or_none()
if existing is None:
raise HTTPException(status_code=404, detail="Session not found")
existing.status = AgentChatSessionStatus.RUNNING
return existing
title = build_session_title(first_message, now=now)
created = AgentChatSession(
user_id=user_id,
title=title,
status=AgentChatSessionStatus.RUNNING,
last_activity_at=now,
)
self._session.add(created)
await self._session.flush()
return created
async def _next_seq_base(self, session_id: object) -> int:
stmt = select(func.max(AgentChatMessage.seq)).where(
AgentChatMessage.session_id == session_id
)
result = await self._session.scalar(stmt)
if result is None:
return 0
return int(result)
async def _intent_stage(
self, *, message: str, context: dict[str, object]
) -> dict[str, object]:
context["intent"] = "default"
return {
"content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
}
async def _execution_stage(
self, *, message: str, context: dict[str, object]
) -> dict[str, object]:
return {
"content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
}
async def _organization_stage(
self, *, message: str, context: dict[str, object]
) -> dict[str, object]:
return {
"content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
}
+18
View File
@@ -74,6 +74,12 @@ async def login(
payload: LoginRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
await enforce_rate_limit(
scope="login",
identifier=payload.email,
limit=10,
window_seconds=60,
)
return await service.login(payload)
@@ -82,6 +88,12 @@ async def refresh(
payload: RefreshRequest,
service: AuthService = Depends(get_auth_service),
) -> AuthTokenResponse:
await enforce_rate_limit(
scope="refresh",
identifier=payload.refresh_token,
limit=10,
window_seconds=60,
)
return await service.refresh(payload)
@@ -90,6 +102,12 @@ async def logout(
payload: LogoutRequest,
service: AuthService = Depends(get_auth_service),
) -> Response:
await enforce_rate_limit(
scope="logout",
identifier=payload.refresh_token,
limit=10,
window_seconds=60,
)
await service.logout(payload.refresh_token)
return Response(status_code=204)
+2
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.agent_chat.router import router as agent_chat_router
from v1.auth.router import router as auth_router
from v1.infra.router import router as infra_router
from v1.profile.router import router as profile_router
@@ -12,6 +13,7 @@ router = APIRouter(prefix="/api/v1")
router.include_router(auth_router)
router.include_router(infra_router)
router.include_router(profile_router)
router.include_router(agent_chat_router)
@router.get("/health", response_model=HealthResponse)
+98
View File
@@ -0,0 +1,98 @@
from __future__ import annotations
import json
import socket
import threading
import time
from uuid import UUID
from playwright.sync_api import sync_playwright
import uvicorn
from app import app
from v1.agent_chat.dependencies import get_agent_chat_service
from v1.agent_chat.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
from v1.agent_chat.service import AgentChatService
class FakeE2EAgentChatService(AgentChatService):
def __init__(self) -> None:
return None
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
session_id = payload.session_id or UUID("00000000-0000-0000-0000-000000000001")
return AgentChatRunResponse(
session_id=session_id,
output=payload.message,
events=[
AgentChatEvent(type="run.started", run_id=str(session_id)),
AgentChatEvent(
type="message.delta", message_id="m1", delta=payload.message
),
AgentChatEvent(
type="run.completed", run_id=str(session_id), output=payload.message
),
],
)
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex((host, port)) == 0:
return
time.sleep(0.05)
raise RuntimeError("Server did not start in time")
def _start_server(host: str, port: int):
config = uvicorn.Config(app, host=host, port=port, log_level="info")
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
_wait_for_port(host, port)
return server, thread
def test_agent_chat_flow_e2e() -> None:
app.dependency_overrides[get_agent_chat_service] = lambda: FakeE2EAgentChatService()
host = "127.0.0.1"
port = _find_free_port()
server, thread = _start_server(host, port)
try:
with sync_playwright() as playwright:
request_context = playwright.request.new_context(
base_url=f"http://{host}:{port}"
)
try:
response = request_context.post(
"/api/v1/agent-chat/run",
data=json.dumps({"message": "hello"}),
headers={"Content-Type": "application/json"},
)
assert response.status == 200
body = response.json()
assert body["output"] == "hello"
assert [event["type"] for event in body["events"]] == [
"run.started",
"message.delta",
"run.completed",
]
finally:
request_context.dispose()
finally:
app.dependency_overrides = {}
server.should_exit = True
thread.join(timeout=5)
@@ -0,0 +1,38 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.service import select_recent_session
def test_recent_session_home_default_selection() -> None:
sessions = [
AgentChatSession(
id=UUID("00000000-0000-0000-0000-0000000000a1"),
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
title="older",
status=AgentChatSessionStatus.COMPLETED,
last_activity_at=datetime(2026, 2, 25, 8, 0, tzinfo=timezone.utc),
message_count=2,
total_tokens=100,
total_cost=Decimal("0.010000"),
),
AgentChatSession(
id=UUID("00000000-0000-0000-0000-0000000000a2"),
user_id=UUID("00000000-0000-0000-0000-0000000000c1"),
title="newer",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
message_count=3,
total_tokens=120,
total_cost=Decimal("0.020000"),
),
]
selected = select_recent_session(sessions)
assert selected is not None
assert selected.id == UUID("00000000-0000-0000-0000-0000000000a2")
@@ -0,0 +1,97 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from types import MethodType
from uuid import UUID, uuid4
import pytest
from core.auth.models import CurrentUser
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.schemas import AgentChatRunRequest
from v1.agent_chat.service import AgentChatService
class _FakeAsyncSession:
def __init__(self) -> None:
self.added: list[object] = []
self.committed = False
self.rolled_back = False
def add(self, obj: object) -> None:
self.added.append(obj)
async def flush(self) -> None:
return None
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
async def refresh(self, obj: object) -> None:
if isinstance(obj, AgentChatSession) and obj.id is None:
obj.id = uuid4()
if isinstance(obj, AgentChatMessage) and obj.id is None:
obj.id = uuid4()
@pytest.mark.asyncio
async def test_run_persists_messages_and_emits_ordered_events() -> None:
fake_db = _FakeAsyncSession()
service = AgentChatService(
session=fake_db, # type: ignore[arg-type]
current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
)
async def _resolve_session(
self: AgentChatService,
*,
session_id: object | None,
user_id: UUID,
first_message: str,
now: datetime,
) -> AgentChatSession:
assert session_id is None
assert first_message == "hello"
return AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000111"),
user_id=user_id,
title="hello",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=now,
message_count=0,
total_tokens=0,
total_cost=Decimal("0"),
created_at=now,
updated_at=now,
deleted_at=None,
)
async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
assert session_id == UUID("00000000-0000-0000-0000-000000000111")
return 2
service._resolve_session = MethodType(_resolve_session, service) # type: ignore[method-assign]
service._next_seq_base = MethodType(_next_seq_base, service) # type: ignore[method-assign]
response = await service.run(AgentChatRunRequest(message="hello"))
assert fake_db.committed is True
inserted_messages = [
item for item in fake_db.added if isinstance(item, AgentChatMessage)
]
assert len(inserted_messages) == 2
assert [msg.seq for msg in inserted_messages] == [3, 4]
assert [msg.role for msg in inserted_messages] == [
AgentChatMessageRole.USER,
AgentChatMessageRole.ASSISTANT,
]
assert [event.type for event in response.events] == [
"run.started",
"message.delta",
"run.completed",
]
@@ -0,0 +1,78 @@
from __future__ import annotations
from typing import Callable
from uuid import UUID
from fastapi.testclient import TestClient
from app import app
from v1.agent_chat.dependencies import get_agent_chat_service
from v1.agent_chat.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
from v1.agent_chat.service import AgentChatService
class FakeAgentChatService:
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
return AgentChatRunResponse(
session_id=UUID("00000000-0000-0000-0000-000000000001"),
output=payload.message,
events=[
AgentChatEvent(
type="run.started", run_id="00000000-0000-0000-0000-000000000001"
),
AgentChatEvent(
type="message.delta", message_id="m1", delta=payload.message
),
AgentChatEvent(
type="run.completed",
run_id="00000000-0000-0000-0000-000000000001",
output=payload.message,
),
],
)
def _override_agent_chat_service(
service: FakeAgentChatService,
) -> Callable[[], AgentChatService]:
def _get_service() -> AgentChatService:
return service # type: ignore[return-value]
return _get_service
def test_run_route_returns_response() -> None:
app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
FakeAgentChatService()
)
client = TestClient(app)
try:
response = client.post("/api/v1/agent-chat/run", json={"message": "hello"})
assert response.status_code == 200
body = response.json()
assert body["output"] == "hello"
assert [event["type"] for event in body["events"]] == [
"run.started",
"message.delta",
"run.completed",
]
finally:
app.dependency_overrides = {}
def test_run_route_validates_payload() -> None:
app.dependency_overrides[get_agent_chat_service] = _override_agent_chat_service(
FakeAgentChatService()
)
client = TestClient(app)
try:
response = client.post("/api/v1/agent-chat/run", json={"message": ""})
assert response.status_code == 422
finally:
app.dependency_overrides = {}
@@ -0,0 +1,20 @@
from __future__ import annotations
from decimal import Decimal
from v1.agent_chat.service import aggregate_session_cost
def test_aggregate_session_cost_sums_non_negative_values() -> None:
total = aggregate_session_cost([Decimal("0.010000"), Decimal("0.002500")])
assert total == Decimal("0.012500")
def test_aggregate_session_cost_rejects_negative_value() -> None:
try:
aggregate_session_cost([Decimal("-0.010000")])
raised = False
except ValueError:
raised = True
assert raised is True
@@ -0,0 +1,42 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.agent_chat.service import select_recent_session
def test_select_recent_session_uses_last_activity_desc() -> None:
sessions = [
AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000001"),
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
title="older",
status=AgentChatSessionStatus.COMPLETED,
last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
message_count=1,
total_tokens=1,
total_cost=Decimal("0"),
),
AgentChatSession(
id=UUID("00000000-0000-0000-0000-000000000002"),
user_id=UUID("00000000-0000-0000-0000-0000000000a1"),
title="newer",
status=AgentChatSessionStatus.RUNNING,
last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
message_count=2,
total_tokens=2,
total_cost=Decimal("0"),
),
]
selected = select_recent_session(sessions)
assert selected is not None
assert selected.id == UUID("00000000-0000-0000-0000-000000000002")
def test_select_recent_session_returns_none_for_empty_collection() -> None:
assert select_recent_session([]) is None
@@ -416,6 +416,108 @@ def test_logout_returns_no_content() -> None:
app.dependency_overrides = {}
def test_login_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
blocked = client.post(
"/api/v1/auth/login",
json={"email": "user@example.com", "password": "wrongpw"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/login",
json={"email": "user@example.com", "password": "wrongpw"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_refresh_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
blocked = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid"},
)
assert blocked.status_code == 401
blocked = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_logout_rate_limited_after_too_many_attempts() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
access_token="access",
refresh_token="refresh",
expires_in=3600,
token_type="bearer",
user=user,
)
app.dependency_overrides[get_auth_service] = _override_auth_service(
FakeAuthService(token_response)
)
client = TestClient(app)
try:
for _ in range(10):
ok = client.post(
"/api/v1/auth/logout",
json={"refresh_token": "refresh"},
)
assert ok.status_code == 204
blocked = client.post(
"/api/v1/auth/logout",
json={"refresh_token": "refresh"},
)
assert blocked.status_code == 429
assert blocked.headers["content-type"].startswith("application/problem+json")
body = blocked.json()
assert body["detail"] == "Too many requests"
finally:
app.dependency_overrides = {}
def test_signup_start_validation_error_returns_problem_details() -> None:
user = AuthUser(id="user-1", email="user@example.com")
token_response = AuthTokenResponse(
@@ -0,0 +1,40 @@
from __future__ import annotations
import pytest
from core.agent_chat.agui_adapter import AguiAdapter
def test_to_command_maps_payload_fields() -> None:
adapter = AguiAdapter()
command = adapter.to_command(
{
"message": "hello",
"session_id": "00000000-0000-0000-0000-000000000001",
}
)
assert command["message"] == "hello"
assert command["session_id"] == "00000000-0000-0000-0000-000000000001"
def test_to_protocol_event_maps_internal_event() -> None:
adapter = AguiAdapter()
mapped = adapter.to_protocol_event(
{
"kind": "run_completed",
"session_id": "run-1",
"output": "done",
}
)
assert mapped == {"type": "run.completed", "run_id": "run-1", "output": "done"}
def test_to_protocol_event_raises_for_invalid_event() -> None:
adapter = AguiAdapter()
with pytest.raises(ValueError):
adapter.to_protocol_event({"kind": "unknown"})
@@ -0,0 +1,30 @@
from __future__ import annotations
import pytest
from core.agent_chat.tools.asr_fun_asr import FunASRTool
def test_transcribe_uses_injected_dashscope_callable() -> None:
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
assert filename == "voice.wav"
assert audio_bytes == b"audio"
return {"text": "你好", "request_id": "req-1"}
tool = FunASRTool(transcribe_callable=fake_transcribe)
result = tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
assert result["text"] == "你好"
assert result["request_id"] == "req-1"
assert result["model"] == "fun-asr-realtime-2025-11-07"
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
raise RuntimeError("upstream timeout")
tool = FunASRTool(transcribe_callable=fake_transcribe)
with pytest.raises(RuntimeError):
tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
@@ -0,0 +1,82 @@
from __future__ import annotations
from decimal import Decimal
import pytest
from core.agent_chat.cost_tracker import CostTracker
def test_normalize_usage_and_cost_aggregation() -> None:
tracker = CostTracker()
tracker.add_usage(
{
"prompt_tokens": 7,
"completion_tokens": 5,
"cost": "0.002500",
}
)
tracker.add_usage(
{
"input_tokens": 5,
"output_tokens": 3,
"cost": "0.003000",
"currency": "USD",
}
)
snapshot = tracker.snapshot()
assert snapshot["input_tokens"] == 12
assert snapshot["output_tokens"] == 8
assert snapshot["total_tokens"] == 20
assert snapshot["cost"] == Decimal("0.005500")
assert snapshot["currency"] == "USD"
def test_add_usage_rejects_negative_values() -> None:
tracker = CostTracker()
with pytest.raises(ValueError):
tracker.add_usage({"input_tokens": -1})
with pytest.raises(ValueError):
tracker.add_usage({"cost": "-0.010000"})
def test_snapshot_is_zero_before_any_usage() -> None:
tracker = CostTracker()
snapshot = tracker.snapshot()
assert snapshot["input_tokens"] == 0
assert snapshot["output_tokens"] == 0
assert snapshot["total_tokens"] == 0
assert snapshot["cost"] == Decimal("0")
assert snapshot["currency"] == "USD"
def test_add_usage_rejects_currency_mismatch() -> None:
tracker = CostTracker(currency="USD")
tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})
with pytest.raises(ValueError):
tracker.add_usage(
{
"input_tokens": 1,
"output_tokens": 1,
"cost": "0.001000",
"currency": "CNY",
}
)
def test_add_usage_rejects_non_integral_token_values() -> None:
tracker = CostTracker()
with pytest.raises(ValueError):
tracker.add_usage({"input_tokens": 1.5})
with pytest.raises(ValueError):
tracker.add_usage({"output_tokens": True})
@@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from core.agent_chat.event_bridge import map_internal_event
def test_map_run_started_event() -> None:
event = {"kind": "run_started", "session_id": "s1"}
mapped = map_internal_event(event)
assert mapped == {"type": "run.started", "run_id": "s1"}
def test_map_message_delta_event() -> None:
event = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
mapped = map_internal_event(event)
assert mapped == {"type": "message.delta", "message_id": "m1", "delta": "hello"}
def test_map_tool_events() -> None:
started = {
"kind": "tool_started",
"message_id": "m2",
"tool_name": "asr_fun_asr",
}
completed = {
"kind": "tool_completed",
"message_id": "m2",
"tool_name": "asr_fun_asr",
"result": "ok",
}
mapped_started = map_internal_event(started)
mapped_completed = map_internal_event(completed)
assert mapped_started["type"] == "tool.started"
assert mapped_started["tool_name"] == "asr_fun_asr"
assert mapped_completed["type"] == "tool.completed"
assert mapped_completed["result"] == "ok"
def test_map_run_completed_event() -> None:
event = {"kind": "run_completed", "session_id": "s1", "output": "done"}
mapped = map_internal_event(event)
assert mapped == {"type": "run.completed", "run_id": "s1", "output": "done"}
def test_map_unknown_event_raises() -> None:
with pytest.raises(ValueError):
map_internal_event({"kind": "unknown"})
def test_map_event_missing_required_field_raises_value_error() -> None:
with pytest.raises(ValueError):
map_internal_event({"kind": "message_delta", "message_id": "m1"})
@@ -0,0 +1,89 @@
from __future__ import annotations
import pytest
from core.agent_chat.multimodal import AttachmentInput, MultimodalProcessor
from core.agent_chat.storage_adapter import StorageAdapter
from core.agent_chat.tools.asr_fun_asr import FunASRTool
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
assert audio_bytes == b"audio"
assert filename == "voice.wav"
return {"text": "hello world", "request_id": "req-1"}
processor = MultimodalProcessor(
storage=storage,
asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
max_file_size_mb=1,
)
result = processor.process(
user_id="u1",
session_id="s1",
message_seq=4,
attachments=[
AttachmentInput(
filename="voice.wav",
mime_type="audio/wav",
content=b"audio",
)
],
)
assert len(result.attachments) == 1
metadata = result.attachments[0]
assert (
metadata["object_path"]
== "agent-chat/u1/s1/4/6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
)
assert metadata["mime_type"] == "audio/wav"
assert result.preview_texts == ["hello world"]
def test_multimodal_rejects_unsupported_mime_type() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
processor = MultimodalProcessor(
storage=storage, asr_tool=FunASRTool(lambda **_: {})
)
with pytest.raises(ValueError):
processor.process(
user_id="u1",
session_id="s1",
message_seq=1,
attachments=[
AttachmentInput(
filename="malware.exe",
mime_type="application/octet-stream",
content=b"bad",
)
],
)
def test_multimodal_rejects_attachment_over_max_size() -> None:
storage = StorageAdapter(bucket="agent-chat-attachments")
processor = MultimodalProcessor(
storage=storage,
asr_tool=FunASRTool(lambda **_: {}),
max_file_size_mb=1,
)
oversized = b"x" * (1024 * 1024 + 1)
with pytest.raises(ValueError):
processor.process(
user_id="u1",
session_id="s1",
message_seq=1,
attachments=[
AttachmentInput(
filename="big.wav",
mime_type="audio/wav",
content=oversized,
)
],
)
@@ -0,0 +1,104 @@
from __future__ import annotations
from core.agent_chat.orchestrator import AgentChatOrchestrator
async def _intent_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("intent")
return {
"content": f"intent:{message}",
"usage": {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"},
}
async def _execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
return {
"content": f"execution:{message}",
"usage": {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"},
}
async def _organization_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("organization")
return {
"content": "final answer",
"usage": {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"},
}
def test_orchestrator_runs_three_stages_in_order() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-1", user_message="hello")
assert result.context["sequence"] == ["intent", "execution", "organization"]
assert result.output == "final answer"
assert result.usage["total_tokens"] == 13
assert result.events[0]["type"] == "run.started"
assert result.events[-1]["type"] == "run.completed"
async def _failing_execution_stage(
*, message: str, context: dict[str, object]
) -> dict[str, object]:
sequence = context.setdefault("sequence", [])
if isinstance(sequence, list):
sequence.append("execution")
raise RuntimeError("boom")
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_failing_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-2", user_message="hello")
assert result.context["sequence"] == ["intent", "execution"]
assert result.events[-1]["type"] == "run.failed"
assert result.events[-1]["run_id"] == "run-2"
assert "boom" in (result.events[-1].get("error") or "")
assert result.failed is True
assert "boom" in (result.error or "")
def test_orchestrator_emits_stage_event_payload_shape() -> None:
orchestrator = AgentChatOrchestrator(
intent_stage=_intent_stage,
execution_stage=_execution_stage,
organization_stage=_organization_stage,
)
result = orchestrator.run_sync(run_id="run-3", user_message="hello")
for event in result.events:
assert "type" in event
assert event.get("run_id") == "run-3"
stage_events = [
event for event in result.events if event["type"] == "stage.completed"
]
assert [event["stage"] for event in stage_events] == [
"intent",
"execution",
"organization",
]
@@ -0,0 +1,23 @@
from __future__ import annotations
from datetime import datetime
from v1.agent_chat.service import build_session_title
def test_build_session_title_truncates_first_message() -> None:
now = datetime(2026, 2, 25, 10, 30)
title = build_session_title(
"这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要", now=now
)
assert len(title) == 24
def test_build_session_title_falls_back_when_message_empty() -> None:
now = datetime(2026, 2, 25, 10, 30)
title = build_session_title("\n ", now=now)
assert title == "新对话 2026-02-25 10:30"
@@ -0,0 +1,37 @@
from __future__ import annotations
from core.agent_chat.storage_adapter import StorageAdapter
def test_build_object_path_uses_expected_pattern() -> None:
adapter = StorageAdapter(bucket="agent-chat-attachments")
path = adapter.build_object_path(
user_id="u1",
session_id="s1",
message_seq=3,
checksum_sha256="abc123",
extension="wav",
)
assert path == "agent-chat/u1/s1/3/abc123.wav"
def test_build_attachment_metadata_contains_required_fields() -> None:
adapter = StorageAdapter(bucket="agent-chat-attachments")
metadata = adapter.build_attachment_metadata(
object_path="agent-chat/u1/s1/3/abc123.wav",
mime_type="audio/wav",
size=1024,
checksum_sha256="abc123",
origin="user_upload",
preview_text="hello",
)
assert metadata["object_path"] == "agent-chat/u1/s1/3/abc123.wav"
assert metadata["mime_type"] == "audio/wav"
assert metadata["size"] == 1024
assert metadata["checksum_sha256"] == "abc123"
assert metadata["origin"] == "user_upload"
assert metadata["preview_text"] == "hello"
@@ -0,0 +1,138 @@
from __future__ import annotations
from pathlib import Path
import pytest
from core.agent_chat.crewai.template_loader import (
load_crewai_template,
load_tools_whitelist,
validate_workflow_stages,
)
def _write(path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content, encoding="utf-8")
def _prepare_static_root(root: Path) -> Path:
_write(
root / "crewai" / "agents.yaml",
"""
intent:
role: Intent Agent
execution:
role: Execution Agent
organization:
role: Organization Agent
""".strip(),
)
_write(
root / "crewai" / "tasks.yaml",
"""
intent:
description: classify
execution:
description: run task
organization:
description: summarize
""".strip(),
)
_write(
root / "crewai" / "workflow.yaml",
"""
stages:
- intent
- execution
- organization
""".strip(),
)
_write(root / "crewai" / "prompts" / "intent.md", "intent prompt")
_write(root / "crewai" / "prompts" / "execution.md", "execution prompt")
_write(root / "crewai" / "prompts" / "organization.md", "organization prompt")
_write(
root / "tools.yaml",
"""
tools:
- asr_fun_asr
- doc_extract
""".strip(),
)
return root
def test_load_crewai_template_success_when_all_files_valid(tmp_path: Path) -> None:
static_root = _prepare_static_root(tmp_path / "agent_chat")
template = load_crewai_template(static_root)
assert set(template.agents.keys()) == {"intent", "execution", "organization"}
assert set(template.tasks.keys()) == {"intent", "execution", "organization"}
assert template.workflow["stages"] == ["intent", "execution", "organization"]
assert template.prompts["intent"] == "intent prompt"
assert template.prompts["execution"] == "execution prompt"
assert template.prompts["organization"] == "organization prompt"
assert template.tools_whitelist == {"asr_fun_asr", "doc_extract"}
def test_load_crewai_template_raises_file_not_found_when_required_file_missing(
tmp_path: Path,
) -> None:
static_root = _prepare_static_root(tmp_path / "agent_chat")
(static_root / "crewai" / "tasks.yaml").unlink()
with pytest.raises(FileNotFoundError):
load_crewai_template(static_root)
def test_load_crewai_template_raises_value_error_when_workflow_stages_invalid(
tmp_path: Path,
) -> None:
static_root = _prepare_static_root(tmp_path / "agent_chat")
_write(
static_root / "crewai" / "workflow.yaml",
"""
stages:
- execution
- intent
- organization
""".strip(),
)
with pytest.raises(ValueError):
load_crewai_template(static_root)
def test_load_tools_whitelist_from_tools_yaml(tmp_path: Path) -> None:
static_root = _prepare_static_root(tmp_path / "agent_chat")
whitelist = load_tools_whitelist(static_root)
assert whitelist == {"asr_fun_asr", "doc_extract"}
def test_validate_workflow_stages_accepts_exact_intent_execution_organization() -> None:
validate_workflow_stages(["intent", "execution", "organization"])
def test_validate_workflow_stages_rejects_extra_or_missing_stage() -> None:
with pytest.raises(ValueError):
validate_workflow_stages(["intent", "execution"])
with pytest.raises(ValueError):
validate_workflow_stages(["intent", "execution", "organization", "extra"])
def test_load_tools_whitelist_rejects_non_string_item(tmp_path: Path) -> None:
static_root = _prepare_static_root(tmp_path / "agent_chat")
_write(
static_root / "tools.yaml",
"""
tools:
- asr_fun_asr
- 123
""".strip(),
)
with pytest.raises(ValueError):
load_tools_whitelist(static_root)
@@ -0,0 +1,143 @@
from __future__ import annotations
from pathlib import Path
import pytest
from sqlalchemy import Column, String, Table, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from core.initialization import init_data
from models.llm import Llm
from models.llm_factory import LlmFactory
def test_llm_catalog_file_exists_and_has_required_fields() -> None:
catalog_path = (
Path(__file__).resolve().parents[3]
/ "src"
/ "core"
/ "config"
/ "static"
/ "agent_chat"
/ "llm_catalog.yaml"
)
catalog = init_data.load_llm_catalog(catalog_path)
assert len(catalog["factories"]) == 6
assert len(catalog["llms"]) == 2
assert set(catalog["factories"][0].keys()) == {"name", "request_url", "avatar"}
assert set(catalog["llms"][0].keys()) == {"model_code", "factory_id"}
def test_load_llm_catalog_raises_on_invalid_structure(tmp_path: Path) -> None:
catalog_path = tmp_path / "llm_catalog.yaml"
catalog_path.write_text(
"""
factories:
- name: qwen
llms:
- model_code: qwen3.5-flash
""".strip(),
encoding="utf-8",
)
with pytest.raises(ValueError):
init_data.load_llm_catalog(catalog_path)
@pytest.mark.asyncio
async def test_initialize_data_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
session_maker = async_sessionmaker(
bind=engine, class_=AsyncSession, expire_on_commit=False
)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
first = await init_data.initialize_data()
second = await init_data.initialize_data()
assert first is True
assert second is True
async with session_maker() as session:
factory_count = await session.scalar(
select(func.count()).select_from(LlmFactory)
)
llm_count = await session.scalar(select(func.count()).select_from(Llm))
assert factory_count == 6
assert llm_count == 2
Base.metadata.remove(users_table)
await engine.dispose()
@pytest.mark.asyncio
async def test_initialize_data_rolls_back_on_invalid_factory_mapping(
monkeypatch: pytest.MonkeyPatch,
) -> None:
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
session_maker = async_sessionmaker(
bind=engine, class_=AsyncSession, expire_on_commit=False
)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
monkeypatch.setattr(
init_data,
"load_llm_catalog",
lambda *_: {
"factories": [
{
"name": "qwen",
"request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"avatar": "https://cdn.example.com/qwen.png",
}
],
"llms": [
{
"model_code": "qwen3.5-flash",
"factory_id": "missing_factory",
}
],
},
)
with pytest.raises(RuntimeError):
await init_data.initialize_data()
async with session_maker() as session:
factory_count = await session.scalar(
select(func.count()).select_from(LlmFactory)
)
llm_count = await session.scalar(select(func.count()).select_from(Llm))
assert factory_count == 0
assert llm_count == 0
Base.metadata.remove(users_table)
await engine.dispose()
@@ -0,0 +1,17 @@
from __future__ import annotations
from pathlib import Path
def test_agent_chat_migration_exists_and_creates_expected_tables() -> None:
versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
migration = versions_dir / "20260226_create_agent_chat_core_tables.py"
assert migration.exists()
content = migration.read_text(encoding="utf-8")
assert 'create_table(\n "llm_factory"' in content
assert 'create_table(\n "llms"' in content
assert 'create_table(\n "sessions"' in content
assert 'create_table(\n "messages"' in content
assert "tool_calls" not in content
@@ -0,0 +1,119 @@
from __future__ import annotations
from uuid import uuid4
import pytest
from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from models.llm import Llm
from models.llm_factory import LlmFactory
@pytest.fixture
async def db_engine():
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
yield engine
Base.metadata.remove(users_table)
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_llm_factory_and_llm_relationship(db_session: AsyncSession) -> None:
factory = LlmFactory(
name="qwen",
request_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
avatar="https://cdn.example.com/qwen.png",
)
db_session.add(factory)
await db_session.flush()
llm = Llm(
factory_id=factory.id,
model_code="qwen3.5-flash",
)
db_session.add(llm)
await db_session.commit()
found_llm = await db_session.get(Llm, llm.id)
assert found_llm is not None
assert found_llm.factory_id == factory.id
@pytest.mark.asyncio
async def test_session_status_supports_required_values(
db_session: AsyncSession,
) -> None:
user_id = uuid4()
session = AgentChatSession(
user_id=user_id,
title="test",
status="pending",
)
db_session.add(session)
await db_session.commit()
statuses = [
AgentChatSessionStatus.PENDING,
AgentChatSessionStatus.RUNNING,
AgentChatSessionStatus.COMPLETED,
AgentChatSessionStatus.FAILED,
]
for status in statuses:
session.status = status
await db_session.commit()
await db_session.refresh(session)
assert session.status == status
@pytest.mark.asyncio
async def test_messages_role_supports_tool(db_session: AsyncSession) -> None:
user_id = uuid4()
session = AgentChatSession(
user_id=user_id,
title="tool test",
status="pending",
)
db_session.add(session)
await db_session.flush()
message = AgentChatMessage(
session_id=session.id,
seq=1,
role="tool",
content="tool output",
cost=0,
)
db_session.add(message)
await db_session.commit()
result = await db_session.execute(
select(AgentChatMessage).where(AgentChatMessage.session_id == session.id)
)
found = result.scalar_one()
assert found.role == "tool"
@@ -0,0 +1,34 @@
from __future__ import annotations
from pydantic import ValidationError
import pytest
from pytest import MonkeyPatch
from core.config.settings import Settings
def test_social_prefixed_storage_env_populates_settings(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_STORAGE__PROVIDER", "supabase")
monkeypatch.setenv("SOCIAL_STORAGE__BUCKET", "agent-chat-attachments")
monkeypatch.setenv("SOCIAL_STORAGE__SIGNED_URL_TTL_SECONDS", "900")
monkeypatch.setenv("SOCIAL_STORAGE__MAX_FILE_SIZE_MB", "25")
monkeypatch.setenv("SOCIAL_STORAGE__RETENTION_DAYS", "45")
settings = Settings()
assert settings.storage.provider == "supabase"
assert settings.storage.bucket == "agent-chat-attachments"
assert settings.storage.signed_url_ttl_seconds == 900
assert settings.storage.max_file_size_mb == 25
assert settings.storage.retention_days == 45
def test_storage_settings_validation_rejects_invalid_provider(
monkeypatch: MonkeyPatch,
) -> None:
monkeypatch.setenv("SOCIAL_STORAGE__PROVIDER", "s3")
with pytest.raises(ValidationError):
Settings()
@@ -0,0 +1,196 @@
from __future__ import annotations
from decimal import Decimal
from uuid import uuid4
import pytest
from sqlalchemy import Column, String, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.agent_chat.orchestrator import OrchestratorResult
from core.db.base import Base
from models.agent_chat_message import AgentChatMessage
from models.agent_chat_session import AgentChatSession
from v1.agent_chat.schemas import AgentChatRunRequest
from v1.agent_chat.service import AgentChatService
@pytest.fixture
async def db_engine():
users_table = Table(
"users",
Base.metadata,
Column("id", String, primary_key=True),
schema="auth",
extend_existing=True,
)
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
async with engine.begin() as conn:
await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
await conn.run_sync(Base.metadata.create_all)
yield engine
Base.metadata.remove(users_table)
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
async_session = async_sessionmaker(
bind=db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_run_creates_session_and_persists_messages(
db_session: AsyncSession,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
result = await service.run(AgentChatRunRequest(message="hello"))
assert result.session_id is not None
assert result.output == "hello"
assert [event.type for event in result.events] == [
"run.started",
"message.delta",
"run.completed",
]
session_obj = await db_session.get(AgentChatSession, result.session_id)
assert session_obj is not None
assert session_obj.message_count == 2
assert session_obj.status.value == "completed"
rows = await db_session.execute(
select(AgentChatMessage)
.where(AgentChatMessage.session_id == result.session_id)
.order_by(AgentChatMessage.seq.asc())
)
messages = rows.scalars().all()
assert len(messages) == 2
assert messages[0].role.value == "user"
assert messages[1].role.value == "assistant"
@pytest.mark.asyncio
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
first = await service.run(AgentChatRunRequest(message="first"))
second = await service.run(
AgentChatRunRequest(message="second", session_id=first.session_id)
)
assert second.session_id == first.session_id
session_obj = await db_session.get(AgentChatSession, first.session_id)
assert session_obj is not None
assert session_obj.message_count == 4
@pytest.mark.asyncio
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
db_session: AsyncSession,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
class _FailingOrchestrator:
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
return OrchestratorResult(
output="",
usage={
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
"cost": Decimal("0"),
"currency": "USD",
},
events=[],
context={},
failed=True,
error="stage failed",
)
service._orchestrator = _FailingOrchestrator() # type: ignore[assignment]
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello"))
assert exc_info.value.status_code == 502
rows = await db_session.execute(
select(AgentChatSession).where(AgentChatSession.user_id == user.id)
)
stored_session = rows.scalars().one()
assert stored_session.status.value == "failed"
@pytest.mark.asyncio
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message=" "))
assert exc_info.value.status_code == 422
@pytest.mark.asyncio
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello", session_id=uuid4()))
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
db_session: AsyncSession,
monkeypatch: pytest.MonkeyPatch,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
async def _fail_commit() -> None:
raise SQLAlchemyError("db down")
monkeypatch.setattr(db_session, "commit", _fail_commit)
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello"))
assert exc_info.value.status_code == 503
@pytest.mark.asyncio
async def test_run_returns_502_for_unexpected_exception(
db_session: AsyncSession,
) -> None:
user = CurrentUser(id=uuid4())
service = AgentChatService(session=db_session, current_user=user)
class _CrashingOrchestrator:
async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
raise RuntimeError("unexpected")
service._orchestrator = _CrashingOrchestrator() # type: ignore[assignment]
with pytest.raises(HTTPException) as exc_info:
await service.run(AgentChatRunRequest(message="hello"))
assert exc_info.value.status_code == 502
@@ -0,0 +1,36 @@
# Agent Chat CrewAI + AG-UI Spike Notes
## Scope
- 验证 CrewAI 依赖可用性与版本探测方式。
- 验证 AG-UI 官方 CrewAI 集成在当前仓库中的落地路径。
- 验证 DashScope FunASR 响应中的 usage 字段可得性与兜底策略。
## Findings
### CrewAI
- `uv run python -m pip show crewai` 在当前虚拟环境不可用(无 pip 模块)。
- `uv pip show crewai` 返回未安装,说明当前工作树尚未安装 CrewAI 依赖。
- 若需启用真实编排,需在 `pyproject.toml` 中声明依赖并执行 `uv sync --extra dev`
### AG-UI 官方 CrewAI 集成
- 目标对齐官方标准事件语义(如 `message.delta``tool.started``tool.completed``run.completed``run.failed`)。
- 当前仓库采取“适配层隔离”策略:由 `agui_adapter.py` 进行请求与事件映射,避免协议细节扩散到业务层。
### DashScope FunASR
- 优先读取上游响应 usage 字段用于成本统计。
- 若 usage 缺失,落库时保持 `raw_usage` 与空标准字段,并标记 `metadata.usage_missing=true` 以便审计。
## Fallback Strategy
- 当官方集成能力或版本存在不确定性时,启用最小兜底事件映射:
- 仅输出标准 AG-UI 事件。
- 不扩展私有协议字段。
-`event_bridge.py` 中统一做字段校验与错误转换。
## Decision
- 继续按计划推进:先补齐编排与成本核心,再完善 AG-UI 适配、多模态与 E2E 闭环。
@@ -0,0 +1,49 @@
# Agent Chat Gap Closure Design
**Goal:** 在不重做已完成任务的前提下,按既定 Task 顺序补齐 Agent Chat Core 的缺口,实现可验证、可审计的端到端闭环。
## Current State
- 已完成:Task 2/3/4 的核心数据层、静态配置、模板加载;Task 6/7 的部分骨架(`event_bridge``v1/agent_chat``storage_adapter``asr_fun_asr`)。
- 未完成或缺口:Task 1 的 spike 结论文档;Task 5 编排与成本追踪;Task 6 `agui_adapter` 与缺失测试;Task 7 `multimodal`Task 8 会话审计与 recent 规则;Task 9 E2E 与运行文档闭环。
## Design Decisions
- 以“缺口优先”方式执行:仅新增/修改缺失能力,已稳定模块不重构。
- 严格遵循顺序:Task 1 -> 5 -> 6 -> 7 -> 8 -> 9。
- 每个 Task 均采用 TDD:先写失败测试,再做最小实现,通过后再小步重构。
- 统一事件与持久化顺序:以 `session.id + seq` 为唯一顺序锚点,避免流式输出与落库顺序漂移。
- 工具调用成本仍归集到 `messages(role=tool)`,会话总成本由增量聚合维护。
## Component Plan
- Task 1: 新增 spike notes,记录 CrewAI/AG-UI/FunASR 依赖可用性与兜底策略。
- Task 5: 新增 `orchestrator.py``cost_tracker.py``events.py`,完成三阶段执行与 usage/cost 归一。
- Task 6: 新增 `agui_adapter.py`,对接现有 `event_bridge.py``v1/agent_chat/service.py`
- Task 7: 新增 `multimodal.py`,衔接附件校验、存储元数据、ASR 文本提取。
- Task 8: 增强会话标题策略、recent session 查询、审计字段与限流保护。
- Task 9: 补齐 E2E 与 runbook,执行 bootstrap gate + 分层测试验证。
## Data Flow
1. 路由接收 AG-UI 请求并解析输入文本/附件。
2. `agui_adapter` 生成内部命令并触发编排器三阶段执行。
3. 每阶段产出内部事件,经 `event_bridge` 映射为 AG-UI 标准事件。
4. `service` 在事务内写入 `messages` 并更新 `sessions` 汇总字段。
5. 流式事件向外输出,顺序与 `messages.seq` 保持一致。
## Error Handling
- 配置/模板错误:启动前校验并快速失败,返回可追踪错误码。
- 第三方调用错误(LLM/ASR/Storage):记录标准化失败事件与审计元数据,不泄露敏感信息。
- 持久化冲突:对 `session_id + seq` 冲突执行有限重试并记录告警。
## Testing Strategy
- Unit`cost_tracker``orchestrator``agui_adapter``multimodal``title strategy`
- Integration`agent_chat` 路由、事件落库、recent session 选择、会话成本聚合。
- E2E:文本、图片+文本、音频+ASR、文档问答、首页最近会话默认选中。
## Approval Note
该设计基于用户确认的“仅按未完成 Task 顺序推进”执行策略。
@@ -0,0 +1,230 @@
# Agent Chat Gap Closure Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 按未完成 Task 顺序补齐 Agent Chat Core 缺口,形成可运行、可测试、可审计的后端链路。
**Architecture:** 复用已完成的数据层与路由骨架,在 `core/agent_chat` 补齐编排、成本与多模态能力,并通过 `v1/agent_chat/service.py` 统一持久化与事件顺序。全流程以 `session.id + messages.seq` 作为一致性锚点,保证事件输出与落库一致。
**Tech Stack:** FastAPI, SQLAlchemy, Pydantic, pytest, CrewAI, AG-UI adapter, DashScope SDK, Supabase Storage。
---
### Task 1: 补齐 Spike 结论文档
**Files:**
- Create: `docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md`
**Step 1: 写失败校验(文档存在性)**
```bash
test -f docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md
```
**Step 2: 运行并确认失败**
Run: `test -f docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md`
Expected: non-zero exit code。
**Step 3: 写最小文档实现**
```markdown
- CrewAI 版本探测结论
- AG-UI 官方 CrewAI 集成可用性结论
- DashScope FunASR usage 字段策略
- 不可用时的最小兜底映射策略
```
**Step 4: 运行并确认通过**
Run: `test -f docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md`
Expected: zero exit code。
### Task 5: 补齐编排与成本追踪
**Files:**
- Create: `backend/src/core/agent_chat/events.py`
- Create: `backend/src/core/agent_chat/cost_tracker.py`
- Create: `backend/src/core/agent_chat/orchestrator.py`
- Test: `backend/tests/unit/core/agent_chat/test_cost_tracker.py`
- Test: `backend/tests/unit/core/agent_chat/test_orchestrator_pipeline.py`
**Step 1: 写失败测试**
```python
def test_normalize_usage_and_cost_aggregation():
assert False
def test_orchestrator_runs_three_stages_in_order():
assert False
```
**Step 2: 运行并确认失败**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_cost_tracker.py backend/tests/unit/core/agent_chat/test_orchestrator_pipeline.py -v`
Expected: FAIL。
**Step 3: 写最小实现**
```python
class CostTracker:
def add_usage(self, usage: dict) -> None: ...
def total(self) -> dict: ...
class AgentChatOrchestrator:
async def run(self, command):
# intent -> execution -> organization
...
```
**Step 4: 运行并确认通过**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_cost_tracker.py backend/tests/unit/core/agent_chat/test_orchestrator_pipeline.py -v`
Expected: PASS。
### Task 6: 补齐 AG-UI 适配层缺口
**Files:**
- Create: `backend/src/core/agent_chat/agui_adapter.py`
- Modify: `backend/src/core/agent_chat/event_bridge.py`
- Modify: `backend/src/v1/agent_chat/service.py`
- Test: `backend/tests/unit/core/agent_chat/test_agui_adapter.py`
- Test: `backend/tests/integration/test_agent_chat_event_persistence.py`
**Step 1: 写失败测试**
```python
def test_agui_adapter_maps_internal_events_to_protocol_events():
assert False
```
**Step 2: 运行并确认失败**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_agui_adapter.py backend/tests/integration/test_agent_chat_event_persistence.py -v`
Expected: FAIL。
**Step 3: 写最小实现**
```python
class AguiAdapter:
def to_command(self, request): ...
def to_protocol_event(self, event): ...
```
**Step 4: 运行并确认通过**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_agui_adapter.py backend/tests/integration/test_agent_chat_event_persistence.py -v`
Expected: PASS。
### Task 7: 补齐多模态输入编排
**Files:**
- Create: `backend/src/core/agent_chat/multimodal.py`
- Modify: `backend/src/core/agent_chat/storage_adapter.py`
- Modify: `backend/src/core/agent_chat/tools/asr_fun_asr.py`
- Test: `backend/tests/unit/core/agent_chat/test_multimodal.py`
**Step 1: 写失败测试**
```python
def test_multimodal_validates_and_builds_attachment_context():
assert False
```
**Step 2: 运行并确认失败**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_multimodal.py -v`
Expected: FAIL。
**Step 3: 写最小实现**
```python
class MultimodalProcessor:
async def build_context(self, attachments): ...
```
**Step 4: 运行并确认通过**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_multimodal.py backend/tests/unit/core/agent_chat/test_storage_adapter.py backend/tests/unit/core/agent_chat/test_asr_fun_asr_tool.py -v`
Expected: PASS。
### Task 8: 补齐会话审计与 recent 规则
**Files:**
- Modify: `backend/src/v1/agent_chat/service.py`
- Modify: `backend/src/v1/agent_chat/router.py`
- Test: `backend/tests/unit/core/agent_chat/test_session_title_strategy.py`
- Test: `backend/tests/integration/test_agent_chat_session_recent_selection.py`
- Test: `backend/tests/integration/test_agent_chat_session_persistence.py`
**Step 1: 写失败测试**
```python
def test_title_generated_from_first_user_message():
assert False
def test_recent_session_selected_by_last_activity_at_desc():
assert False
```
**Step 2: 运行并确认失败**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_session_title_strategy.py backend/tests/integration/test_agent_chat_session_recent_selection.py backend/tests/integration/test_agent_chat_session_persistence.py -v`
Expected: FAIL。
**Step 3: 写最小实现**
```python
def build_session_title(first_message: str) -> str: ...
```
**Step 4: 运行并确认通过**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_session_title_strategy.py backend/tests/integration/test_agent_chat_session_recent_selection.py backend/tests/integration/test_agent_chat_session_persistence.py -v`
Expected: PASS。
### Task 9: 补齐 E2E 与运行文档闭环
**Files:**
- Create: `backend/tests/e2e/test_agent_chat_flow.py`
- Create: `backend/tests/e2e/test_agent_chat_recent_session_home.py`
- Modify: `docs/runtime/runtime-runbook.md`
**Step 1: 写失败 E2E 用例**
```python
def test_agent_chat_text_image_audio_document_flow():
assert False
```
**Step 2: 运行并确认失败**
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/e2e/test_agent_chat_flow.py -v`
Expected: FAIL。
**Step 3: 写最小实现与文档补充**
```markdown
- bootstrap gate 执行顺序
- agent_chat 验证命令
```
**Step 4: 全量验证**
Run: `make runtime-bootstrap-gate`
Expected: bootstrap 通过。
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat -v`
Expected: PASS。
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/integration -k agent_chat -v`
Expected: PASS。
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/e2e/test_agent_chat_flow.py backend/tests/e2e/test_agent_chat_recent_session_home.py -v`
Expected: PASS。
Run: `PYTHONPATH=backend/src uv run pip check`
Expected: no broken requirements。
+18
View File
@@ -175,6 +175,23 @@ curl -sS -X PATCH http://127.0.0.1:8000/api/v1/profile/me \
-d '{"username":"demo2","bio":"hello"}'
```
## Agent Chat 验证
```bash
# 1) 基础门禁(迁移 + init-data
make runtime-bootstrap-gate
# 2) 运行 agent_chat 相关单测/集成/E2E
PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat -v
PYTHONPATH=backend/src uv run pytest backend/tests/integration -k agent_chat -v
PYTHONPATH=backend/src uv run pytest backend/tests/e2e/test_agent_chat_flow.py backend/tests/e2e/test_agent_chat_recent_session_home.py -v
# 3) 核心接口 smoke
curl -sS -X POST http://127.0.0.1:8000/api/v1/agent-chat/run \
-H 'Content-Type: application/json' \
-d '{"message":"hello"}'
```
---
## 变更日志
@@ -188,3 +205,4 @@ curl -sS -X PATCH http://127.0.0.1:8000/api/v1/profile/me \
| 2026-02-25 | 补充迁移防遗漏规则:容器迁移命令统一追加 --build;开发调试优先使用本地 CLI 一次性迁移脚本 |
| 2026-02-25 | Auth 注册切换为 OTP 三段式:signup/start、signup/verify、signup/resend;邮件模板改为纯验证码展示 |
| 2026-02-25 | 清理未使用配置类:删除 WebSettings/GunicornSettings/WorkerSettings/WorkerGroupSettings(脚本仍使用环境变量启动服务) |
| 2026-02-25 | 新增 Agent Chat 验证章节:bootstrap gate、分层测试命令与 run 接口 smoke 示例 |