Merge branch 'feature/agent-chat-core' into dev
This commit is contained in:
@@ -134,3 +134,12 @@ SOCIAL_SUPABASE__FUNCTIONS_VERIFY_JWT=false
|
||||
SOCIAL_SUPABASE__IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
SOCIAL_SUPABASE__STORAGE_BUCKET_PUBLIC=public
|
||||
SOCIAL_SUPABASE__STORAGE_BUCKET_PRIVATE=private
|
||||
|
||||
############
|
||||
# Agent Chat 附件存储配置(仅基础设施变量)
|
||||
############
|
||||
SOCIAL_STORAGE__PROVIDER=supabase
|
||||
SOCIAL_STORAGE__BUCKET=agent-chat-attachments
|
||||
SOCIAL_STORAGE__SIGNED_URL_TTL_SECONDS=600
|
||||
SOCIAL_STORAGE__MAX_FILE_SIZE_MB=20
|
||||
SOCIAL_STORAGE__RETENTION_DAYS=30
|
||||
|
||||
@@ -17,7 +17,13 @@ if str(src_path) not in sys.path:
|
||||
|
||||
from core.config.settings import config # noqa: E402
|
||||
from core.db.base import Base # noqa: E402
|
||||
from models import Profile # noqa: F401,E402
|
||||
from models import ( # noqa: F401,E402
|
||||
AgentChatMessage,
|
||||
AgentChatSession,
|
||||
Llm,
|
||||
LlmFactory,
|
||||
Profile,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
"""create_agent_chat_core_tables
|
||||
|
||||
Revision ID: 20260226_agent_chat_core
|
||||
Revises: 20260224_drop_profile
|
||||
Create Date: 2026-02-26 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# Alembic revision identifiers for this migration.
revision: str = "20260226_agent_chat_core"
# Parent revision this migration is applied on top of.
down_revision: Union[str, Sequence[str], None] = "20260224_drop_profile"
# No branch label and no cross-revision dependency are declared.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _audit_columns() -> list:
    """Build the ``created_at`` / ``updated_at`` / ``deleted_at`` column trio.

    Fresh ``sa.Column`` objects are created on every call so that each table
    receives its own instances.
    """
    stamped = [
        sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )
        for column_name in ("created_at", "updated_at")
    ]
    stamped.append(sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True))
    return stamped


def _create_indexes(table_name: str, index_specs: tuple) -> None:
    """Create each ``(index_name, column_names)`` index on ``table_name`` in order."""
    for index_name, column_names in index_specs:
        op.create_index(index_name, table_name, column_names)


def upgrade() -> None:
    """Create the agent-chat core tables and their indexes.

    Tables are created in dependency order: ``llm_factory``, ``llms`` (foreign
    key to ``llm_factory``), ``sessions`` (foreign key to ``auth.users``) and
    ``messages`` (foreign key to ``sessions``).
    """
    # --- llm_factory -------------------------------------------------------
    op.create_table(
        "llm_factory",
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("name", sa.String(length=50), nullable=False),
        sa.Column("request_url", sa.String(length=255), nullable=False),
        sa.Column("avatar", sa.Text(), nullable=True),
        *_audit_columns(),
        sa.PrimaryKeyConstraint("id", name="pk_llm_factory"),
        sa.UniqueConstraint("name", name="uq_llm_factory_name"),
    )
    _create_indexes(
        "llm_factory",
        (
            ("ix_llm_factory_name", ["name"]),
            ("ix_llm_factory_deleted_at", ["deleted_at"]),
        ),
    )

    # --- llms --------------------------------------------------------------
    op.create_table(
        "llms",
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("factory_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("model_code", sa.String(length=50), nullable=False),
        *_audit_columns(),
        sa.ForeignKeyConstraint(
            ["factory_id"], ["llm_factory.id"], ondelete="RESTRICT"
        ),
        sa.PrimaryKeyConstraint("id", name="pk_llms"),
        sa.UniqueConstraint("model_code", name="uq_llms_model_code"),
    )
    _create_indexes(
        "llms",
        (
            ("ix_llms_factory_id", ["factory_id"]),
            ("ix_llms_model_code", ["model_code"]),
            ("ix_llms_deleted_at", ["deleted_at"]),
        ),
    )

    # --- sessions ----------------------------------------------------------
    session_statuses = ("pending", "running", "completed", "failed")
    op.create_table(
        "sessions",
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("title", sa.String(length=255), nullable=True),
        sa.Column(
            "status",
            sa.Enum(
                *session_statuses,
                name="agent_chat_session_status",
                native_enum=False,
            ),
            nullable=False,
            server_default="pending",
        ),
        sa.Column(
            "last_activity_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("message_count", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("total_cost", sa.Numeric(12, 6), nullable=False, server_default="0"),
        *_audit_columns(),
        sa.ForeignKeyConstraint(["user_id"], ["auth.users.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id", name="pk_sessions"),
    )
    _create_indexes(
        "sessions",
        (
            ("ix_sessions_user_id", ["user_id"]),
            ("ix_sessions_user_last_activity", ["user_id", "last_activity_at"]),
            ("ix_sessions_deleted_at", ["deleted_at"]),
        ),
    )

    # --- messages ----------------------------------------------------------
    message_roles = ("user", "assistant", "system", "tool")
    op.create_table(
        "messages",
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("session_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("seq", sa.Integer(), nullable=False),
        sa.Column(
            "role",
            sa.Enum(
                *message_roles,
                name="agent_chat_message_role",
                native_enum=False,
            ),
            nullable=False,
        ),
        sa.Column("content", sa.Text(), nullable=False),
        sa.Column("model_code", sa.String(length=50), nullable=True),
        sa.Column("tool_name", sa.String(length=100), nullable=True),
        sa.Column("input_tokens", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("output_tokens", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("cost", sa.Numeric(12, 6), nullable=False, server_default="0"),
        sa.Column(
            "currency", sa.String(length=3), nullable=False, server_default="USD"
        ),
        sa.Column("latency_ms", sa.Integer(), nullable=True),
        sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        *_audit_columns(),
        sa.ForeignKeyConstraint(["session_id"], ["sessions.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id", name="pk_messages"),
        sa.UniqueConstraint("session_id", "seq", name="uq_messages_session_seq"),
    )
    _create_indexes(
        "messages",
        (
            ("ix_messages_session_id", ["session_id"]),
            ("ix_messages_session_role", ["session_id", "role"]),
            ("ix_messages_deleted_at", ["deleted_at"]),
        ),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the agent-chat tables in reverse creation order.

    For every table its indexes are dropped first (newest index first), then the
    table itself: ``messages``, ``sessions``, ``llms``, ``llm_factory``.
    """
    teardown_plan = (
        (
            "messages",
            (
                "ix_messages_deleted_at",
                "ix_messages_session_role",
                "ix_messages_session_id",
            ),
        ),
        (
            "sessions",
            (
                "ix_sessions_deleted_at",
                "ix_sessions_user_last_activity",
                "ix_sessions_user_id",
            ),
        ),
        (
            "llms",
            (
                "ix_llms_deleted_at",
                "ix_llms_model_code",
                "ix_llms_factory_id",
            ),
        ),
        (
            "llm_factory",
            (
                "ix_llm_factory_deleted_at",
                "ix_llm_factory_name",
            ),
        ),
    )
    for table_name, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(index_name, table_name=table_name)
        op.drop_table(table_name)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.agent_chat.event_bridge import map_internal_event
|
||||
|
||||
|
||||
class AguiAdapter:
    """Translate between inbound request payloads and internal commands/events."""

    def to_command(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Extract the chat command from ``payload``.

        Returns a dict with the untouched ``message`` and the optional
        ``session_id`` (``None`` when absent).

        Raises:
            ValueError: if ``message`` is missing, not a string, or blank.
        """
        text = payload.get("message")
        if isinstance(text, str) and text.strip():
            command: dict[str, Any] = {"message": text}
            command["session_id"] = payload.get("session_id")
            return command
        raise ValueError("message is required")

    def to_protocol_event(self, event: dict[str, Any]) -> dict[str, Any]:
        """Convert an internal event by delegating to ``map_internal_event``."""
        protocol_event = map_internal_event(event)
        return protocol_event
|
||||
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
def _to_non_negative_int(value: Any, *, field: str) -> int:
    """Coerce ``value`` to an ``int`` that is zero or greater.

    Accepts plain integers and strings for which ``str.isdigit()`` is true.

    Raises:
        ValueError: ``"<field> must be an integer"`` for booleans and any other
            unsupported input, ``"<field> cannot be negative"`` for negative ints.
    """
    # bool is a subclass of int, so it must be excluded explicitly.
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    is_digit_text = isinstance(value, str) and value.isdigit()
    if not (is_plain_int or is_digit_text):
        raise ValueError(f"{field} must be an integer")

    number = value if is_plain_int else int(value)
    if number < 0:
        raise ValueError(f"{field} cannot be negative")
    return number
|
||||
|
||||
|
||||
def _to_non_negative_decimal(value: Any, *, field: str) -> Decimal:
    """Coerce ``value`` to a finite ``Decimal`` that is zero or greater.

    The value is converted through ``str(value)`` so floats keep their printed
    representation instead of their binary expansion.

    Raises:
        ValueError: if the value cannot be parsed as a decimal number, is NaN or
            infinite, or is negative.
    """
    # Imported locally so the module-level import block stays untouched.
    from decimal import InvalidOperation

    try:
        converted = Decimal(str(value))
    except InvalidOperation as exc:
        # Decimal signals InvalidOperation (an ArithmeticError) for garbage such
        # as "abc" or "None"; surface it as ValueError like the int helper does.
        raise ValueError(f"{field} must be a decimal number") from exc

    # NaN would blow up in the comparison below and Infinity would silently
    # pass it, so reject both explicitly.
    if not converted.is_finite():
        raise ValueError(f"{field} must be a finite number")
    if converted < 0:
        raise ValueError(f"{field} cannot be negative")
    return converted
|
||||
|
||||
|
||||
class CostTracker:
    """Accumulate token counts and cost over several usage records.

    All records must share the currency given at construction time.
    """

    def __init__(self, *, currency: str = "USD") -> None:
        """Start with zeroed counters in ``currency``."""
        self._input_tokens = 0
        self._output_tokens = 0
        self._total_tokens = 0
        self._cost = Decimal("0")
        self._currency = currency

    def add_usage(self, usage: Mapping[str, Any]) -> None:
        """Validate one usage record and add it to the running totals.

        Recognised keys: ``input_tokens`` (fallback ``prompt_tokens``),
        ``output_tokens`` (fallback ``completion_tokens``), ``total_tokens``
        (defaults to input + output), ``cost`` (defaults to ``"0"``) and
        ``currency`` (optional).

        Every check runs before any counter is touched, so a rejected record
        leaves the tracker exactly as it was.

        Raises:
            ValueError: on a currency mismatch or an invalid/negative number.
        """
        # Reject a foreign currency first: previously this check ran after the
        # totals had already been incremented, so the bad record was counted.
        currency = usage.get("currency")
        if currency is not None and str(currency) != self._currency:
            raise ValueError("currency mismatch")

        input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0))
        output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0))
        total_tokens = usage.get("total_tokens")
        cost = usage.get("cost", "0")

        normalized_input = _to_non_negative_int(input_tokens, field="input_tokens")
        normalized_output = _to_non_negative_int(output_tokens, field="output_tokens")
        if total_tokens is None:
            normalized_total = normalized_input + normalized_output
        else:
            normalized_total = _to_non_negative_int(total_tokens, field="total_tokens")
        normalized_cost = _to_non_negative_decimal(cost, field="cost")

        # Everything validated: commit the increments together.
        self._input_tokens += normalized_input
        self._output_tokens += normalized_output
        self._total_tokens += normalized_total
        self._cost += normalized_cost

    def snapshot(self) -> dict[str, Any]:
        """Return the current totals as a new dict (``cost`` is a ``Decimal``)."""
        return {
            "input_tokens": self._input_tokens,
            "output_tokens": self._output_tokens,
            "total_tokens": self._total_tokens,
            "cost": self._cost,
            "currency": self._currency,
        }
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CrewAITemplate:
    """Immutable bundle of the static CrewAI configuration.

    Instances are produced by ``load_crewai_template``.
    """

    # Parsed content of crewai/agents.yaml.
    agents: dict[str, Any]
    # Parsed content of crewai/tasks.yaml.
    tasks: dict[str, Any]
    # Parsed content of crewai/workflow.yaml (its "stages" list is validated on load).
    workflow: dict[str, Any]
    # Prompt text keyed by stage name: "intent", "execution", "organization".
    prompts: dict[str, str]
    # Tool names read from tools.yaml.
    tools_whitelist: set[str]
|
||||
|
||||
|
||||
def _default_static_root() -> Path:
|
||||
return Path(__file__).resolve().parents[3] / "config" / "static" / "agent_chat"
|
||||
|
||||
|
||||
def _read_yaml(file_path: Path) -> dict[str, Any]:
    """Load ``file_path`` with ``yaml.safe_load`` and return it as a mapping.

    An empty (or otherwise falsy) document is returned as ``{}``.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing regular file.
        ValueError: if the document's top level is not a mapping.
    """
    if file_path.is_file():
        with file_path.open("r", encoding="utf-8") as stream:
            parsed = yaml.safe_load(stream)
        document = parsed if parsed else {}
        if isinstance(document, dict):
            return document
        raise ValueError(f"YAML file must be a mapping: {file_path}")
    raise FileNotFoundError(f"Required config file not found: {file_path}")
|
||||
|
||||
|
||||
def _read_prompt(file_path: Path) -> str:
    """Return the UTF-8 text of ``file_path`` with surrounding whitespace removed.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing regular file.
    """
    if file_path.is_file():
        raw_text = file_path.read_text(encoding="utf-8")
        return raw_text.strip()
    raise FileNotFoundError(f"Required prompt file not found: {file_path}")
|
||||
|
||||
|
||||
def validate_workflow_stages(stages: list[str]) -> None:
    """Ensure ``stages`` is exactly ``intent`` -> ``execution`` -> ``organization``.

    Raises:
        ValueError: if the list differs in content or order.
    """
    expected = ["intent", "execution", "organization"]
    if stages == expected:
        return
    raise ValueError(f"Invalid workflow stages: {stages}, expected: {expected}")
|
||||
|
||||
|
||||
def load_tools_whitelist(static_root: Path | None = None) -> set[str]:
    """Read ``tools.yaml`` under ``static_root`` and return the tool names as a set.

    ``static_root`` falls back to ``_default_static_root()``. A missing ``tools``
    key yields an empty set; names are stripped of surrounding whitespace.

    Raises:
        ValueError: if ``tools`` is not a list or holds a non-string/blank item.
    """
    base_dir = static_root if static_root else _default_static_root()
    entries = _read_yaml(base_dir / "tools.yaml").get("tools", [])
    if not isinstance(entries, list):
        raise ValueError("tools.yaml field 'tools' must be a list")

    names: set[str] = set()
    for entry in entries:
        if not isinstance(entry, str) or not entry.strip():
            raise ValueError("tools.yaml list items must be non-empty strings")
        names.add(entry.strip())
    return names
|
||||
|
||||
|
||||
def load_crewai_template(static_root: Path | None = None) -> CrewAITemplate:
    """Load the CrewAI YAML files, stage prompts and tool whitelist.

    Files are read from ``<static_root>/crewai`` (``static_root`` falls back to
    ``_default_static_root()``); the whitelist comes from ``load_tools_whitelist``.

    Raises:
        ValueError: if ``workflow.yaml`` has no ``stages`` list or the stages are
            not the expected sequence.
    """
    base_dir = static_root if static_root else _default_static_root()
    crewai_dir = base_dir / "crewai"

    agents_doc = _read_yaml(crewai_dir / "agents.yaml")
    tasks_doc = _read_yaml(crewai_dir / "tasks.yaml")
    workflow_doc = _read_yaml(crewai_dir / "workflow.yaml")

    declared_stages = workflow_doc.get("stages")
    if not isinstance(declared_stages, list):
        raise ValueError("workflow.yaml field 'stages' must be a list")
    validate_workflow_stages([str(entry) for entry in declared_stages])

    # One prompt file per stage, read in pipeline order.
    prompt_dir = crewai_dir / "prompts"
    prompt_texts = {
        stage_name: _read_prompt(prompt_dir / f"{stage_name}.md")
        for stage_name in ("intent", "execution", "organization")
    }

    whitelist = load_tools_whitelist(base_dir)
    return CrewAITemplate(
        agents=agents_doc,
        tasks=tasks_doc,
        workflow=workflow_doc,
        prompts=prompt_texts,
        tools_whitelist=whitelist,
    )
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _require_fields(event: dict[str, Any], *, kind: str, required: list[str]) -> None:
    """Raise ``ValueError`` listing every name in ``required`` absent from ``event``."""
    absent = []
    for name in required:
        if name not in event:
            absent.append(name)
    if absent:
        raise ValueError(f"Missing fields for {kind}: {absent}")
|
||||
|
||||
|
||||
def map_internal_event(event: dict[str, Any]) -> dict[str, Any]:
    """Translate an internal event (keyed by ``kind``) into a protocol event.

    Each supported kind has required fields that are copied to the result
    (``session_id`` is emitted as ``run_id``) and at most one optional field
    with a default.

    Raises:
        ValueError: if a required field is missing or the kind is unsupported.
    """
    kind = event.get("kind")

    # (internal kind, protocol type, required fields, optional (key, default))
    routes = (
        ("run_started", "run.started", ("session_id",), None),
        ("message_delta", "message.delta", ("message_id", "delta"), None),
        ("tool_started", "tool.started", ("message_id", "tool_name"), None),
        ("tool_completed", "tool.completed", ("message_id", "tool_name"), ("result", None)),
        ("run_completed", "run.completed", ("session_id",), ("output", "")),
        ("run_failed", "run.failed", ("session_id",), ("error", "")),
    )

    for internal_kind, protocol_type, required, optional in routes:
        if kind != internal_kind:
            continue

        absent = [name for name in required if name not in event]
        if absent:
            raise ValueError(f"Missing fields for {kind}: {absent}")

        mapped: dict[str, Any] = {"type": protocol_type}
        for name in required:
            target_key = "run_id" if name == "session_id" else name
            mapped[target_key] = event[name]
        if optional is not None:
            optional_key, fallback = optional
            mapped[optional_key] = event.get(optional_key, fallback)
        return mapped

    raise ValueError(f"Unsupported event kind: {kind}")
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def run_started(*, run_id: str) -> dict[str, Any]:
    """Build the ``run.started`` event for ``run_id``."""
    event: dict[str, Any] = {"type": "run.started"}
    event["run_id"] = run_id
    return event
|
||||
|
||||
|
||||
def stage_completed(
|
||||
*, run_id: str, stage: str, usage: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
event: dict[str, Any] = {
|
||||
"type": "stage.completed",
|
||||
"run_id": run_id,
|
||||
"stage": stage,
|
||||
}
|
||||
if usage is not None:
|
||||
event["usage"] = usage
|
||||
return event
|
||||
|
||||
|
||||
def run_completed(*, run_id: str, output: str, usage: dict[str, Any]) -> dict[str, Any]:
    """Build the ``run.completed`` event carrying the final output and usage."""
    return dict(type="run.completed", run_id=run_id, output=output, usage=usage)
|
||||
|
||||
|
||||
def run_failed(*, run_id: str, error: str) -> dict[str, Any]:
    """Build the ``run.failed`` event carrying the error text."""
    return dict(type="run.failed", run_id=run_id, error=error)
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from core.agent_chat.storage_adapter import StorageAdapter
|
||||
|
||||
# MIME types accepted for attachments; MultimodalProcessor._validate_attachment
# rejects anything not listed here (the comparison is an exact string match).
_ALLOWED_MIME_TYPES = {
    "audio/mpeg",
    "audio/wav",
    "audio/x-wav",
    "image/jpeg",
    "image/png",
    "image/webp",
    "application/pdf",
    "text/plain",
}
|
||||
|
||||
|
||||
class _AsrTool(Protocol):
    """Structural type for the speech-to-text dependency of ``MultimodalProcessor``."""

    # Returns a mapping; the processor reads its "text" entry as the transcript.
    def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, object]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AttachmentInput:
    """One attachment handed to ``MultimodalProcessor.process``."""

    # Original file name; only its suffix is used (as the stored extension).
    filename: str
    # Declared MIME type, checked against the allowed set.
    mime_type: str
    # Raw file bytes.
    content: bytes
    # Provenance label copied into the attachment metadata.
    origin: str = "user_upload"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProcessedAttachmentContext:
    """Result of ``MultimodalProcessor.process``."""

    # One metadata dict per processed attachment, in input order.
    attachments: list[dict[str, object]]
    # Non-empty preview texts (transcripts / text excerpts), in input order.
    preview_texts: list[str]
|
||||
|
||||
|
||||
class MultimodalProcessor:
    """Validate attachments and derive their storage path, metadata and preview text."""

    def __init__(
        self,
        *,
        storage: StorageAdapter,
        asr_tool: _AsrTool,
        max_file_size_mb: int = 20,
    ) -> None:
        """Keep the collaborators and convert the size limit from MiB to bytes."""
        self._storage = storage
        self._asr_tool = asr_tool
        self._max_size_bytes = max_file_size_mb * 1024 * 1024

    def process(
        self,
        *,
        user_id: str,
        session_id: str,
        message_seq: int,
        attachments: list[AttachmentInput],
    ) -> ProcessedAttachmentContext:
        """Process ``attachments`` in order and collect metadata plus previews.

        For each attachment: validate it, hash its content with SHA-256, ask the
        storage adapter for an object path (extension taken from the file name,
        ``"bin"`` when there is none), build a preview and the metadata record.

        Raises:
            ValueError: if an attachment has a disallowed MIME type or is too large.
        """
        records: list[dict[str, object]] = []
        previews: list[str] = []

        for item in attachments:
            self._validate_attachment(item)

            digest = hashlib.sha256(item.content).hexdigest()
            suffix = Path(item.filename).suffix.strip(".").lower()
            target_path = self._storage.build_object_path(
                user_id=user_id,
                session_id=session_id,
                message_seq=message_seq,
                checksum_sha256=digest,
                extension=suffix if suffix else "bin",
            )

            # Only non-empty previews are collected, but the metadata record
            # always carries whatever preview value was produced.
            preview = self._build_preview_text(item)
            if preview:
                previews.append(preview)

            records.append(
                self._storage.build_attachment_metadata(
                    object_path=target_path,
                    mime_type=item.mime_type,
                    size=len(item.content),
                    checksum_sha256=digest,
                    origin=item.origin,
                    preview_text=preview,
                )
            )

        return ProcessedAttachmentContext(attachments=records, preview_texts=previews)

    def _validate_attachment(self, attachment: AttachmentInput) -> None:
        """Raise ``ValueError`` for a disallowed MIME type or an oversized payload."""
        mime_allowed = attachment.mime_type in _ALLOWED_MIME_TYPES
        if not mime_allowed:
            raise ValueError("Unsupported MIME type")
        within_limit = len(attachment.content) <= self._max_size_bytes
        if not within_limit:
            raise ValueError("Attachment exceeds max file size")

    def _build_preview_text(self, attachment: AttachmentInput) -> str | None:
        """Return a preview for audio (transcript) or plain text (first 200 chars).

        Other MIME types, and transcriptions without a string ``text`` entry,
        produce ``None``.
        """
        mime = attachment.mime_type
        if mime.startswith("audio/"):
            transcript = self._asr_tool.transcribe(
                audio_bytes=attachment.content,
                filename=attachment.filename,
            ).get("text")
            return transcript if isinstance(transcript, str) else None
        if mime != "text/plain":
            return None
        # Undecodable bytes are dropped rather than raising.
        return attachment.content.decode("utf-8", errors="ignore")[:200]
|
||||
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from core.agent_chat.cost_tracker import CostTracker
|
||||
from core.agent_chat import events
|
||||
|
||||
# A pipeline stage: awaited by the orchestrator with ``message=`` and ``context=``
# keyword arguments and expected to resolve to a dict whose optional "content"
# and "usage" entries are read.
StageCallable = Callable[..., Awaitable[dict[str, Any]]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class OrchestratorResult:
    """Outcome of one ``AgentChatOrchestrator.run`` invocation."""

    # Output of the last stage; empty string when the run failed.
    output: str
    # CostTracker snapshot taken when the run finished or failed.
    usage: dict[str, Any]
    # Events emitted during the run, in emission order.
    events: list[dict[str, Any]]
    # Context dict that was passed to every stage.
    context: dict[str, Any]
    # True when a stage raised an exception.
    failed: bool
    # String form of that exception, otherwise None.
    error: str | None
|
||||
|
||||
|
||||
class AgentChatOrchestrator:
    """Run the intent -> execution -> organization pipeline for one chat turn."""

    def __init__(
        self,
        *,
        intent_stage: StageCallable,
        execution_stage: StageCallable,
        organization_stage: StageCallable,
    ) -> None:
        """Store the three stage callables; they are awaited in that order."""
        self._intent_stage = intent_stage
        self._execution_stage = execution_stage
        self._organization_stage = organization_stage

    def run_sync(self, *, run_id: str, user_message: str) -> OrchestratorResult:
        """Blocking wrapper around :meth:`run` using ``asyncio.run``.

        ``asyncio.run`` cannot be used while another event loop is running in
        the same thread, so call this from synchronous code only.
        """
        return asyncio.run(self.run(run_id=run_id, user_message=user_message))

    async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
        """Execute all stages and return the output, usage and emitted events.

        Each stage receives the previous stage's output as ``message`` together
        with a shared, mutable ``context`` dict. A stage result may provide
        ``content`` (the next message) and ``usage`` (added to the cost tracker).
        After every stage a ``stage.completed`` event with the cumulative usage
        snapshot is recorded.

        Any exception raised by a stage or by usage validation ends the run: a
        ``run.failed`` event is recorded and a failed result is returned instead
        of the exception propagating.
        """
        tracker = CostTracker()
        emitted_events: list[dict[str, Any]] = [events.run_started(run_id=run_id)]
        context: dict[str, Any] = {}

        stage_pipeline: list[tuple[str, StageCallable]] = [
            ("intent", self._intent_stage),
            ("execution", self._execution_stage),
            ("organization", self._organization_stage),
        ]

        stage_output = user_message
        try:
            for stage_name, stage_callable in stage_pipeline:
                stage_result = await stage_callable(
                    message=stage_output, context=context
                )

                # A missing or explicit-None content keeps the previous output.
                # Stringifying None used to feed the literal text "None" into
                # the next stage and into the final output.
                produced = stage_result.get("content")
                if produced is not None:
                    stage_output = str(produced)

                usage = stage_result.get("usage", {})
                if isinstance(usage, dict):
                    tracker.add_usage(usage)

                emitted_events.append(
                    events.stage_completed(
                        run_id=run_id,
                        stage=stage_name,
                        usage=tracker.snapshot(),
                    )
                )
        except Exception as exc:  # noqa: BLE001
            # Deliberate catch-all boundary: failures are reported through the
            # result object and the event stream rather than raised.
            failure_text = str(exc)
            emitted_events.append(events.run_failed(run_id=run_id, error=failure_text))
            return OrchestratorResult(
                output="",
                usage=tracker.snapshot(),
                events=emitted_events,
                context=context,
                failed=True,
                error=failure_text,
            )

        summary = tracker.snapshot()
        emitted_events.append(
            events.run_completed(run_id=run_id, output=stage_output, usage=summary)
        )
        return OrchestratorResult(
            output=stage_output,
            usage=summary,
            events=emitted_events,
            context=context,
            failed=False,
            error=None,
        )
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class StorageAdapter:
|
||||
_bucket: str
|
||||
|
||||
def __init__(self, bucket: str) -> None:
|
||||
self._bucket = bucket
|
||||
|
||||
@property
|
||||
def bucket(self) -> str:
|
||||
return self._bucket
|
||||
|
||||
def build_object_path(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
message_seq: int,
|
||||
checksum_sha256: str,
|
||||
extension: str,
|
||||
) -> str:
|
||||
normalized_ext = extension.strip(".").lower()
|
||||
return (
|
||||
f"agent-chat/{user_id}/{session_id}/{message_seq}/"
|
||||
f"{checksum_sha256}.{normalized_ext}"
|
||||
)
|
||||
|
||||
def build_attachment_metadata(
|
||||
self,
|
||||
*,
|
||||
object_path: str,
|
||||
mime_type: str,
|
||||
size: int,
|
||||
checksum_sha256: str,
|
||||
origin: str,
|
||||
preview_text: str | None = None,
|
||||
) -> dict[str, object]:
|
||||
return {
|
||||
"object_path": object_path,
|
||||
"mime_type": mime_type,
|
||||
"size": size,
|
||||
"checksum_sha256": checksum_sha256,
|
||||
"origin": origin,
|
||||
"preview_text": preview_text,
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
TranscribeCallable = Callable[..., dict[str, Any]]
|
||||
|
||||
|
||||
class FunASRTool:
|
||||
_transcribe_callable: TranscribeCallable
|
||||
_model: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transcribe_callable: TranscribeCallable | None = None,
|
||||
model: str = "fun-asr-realtime-2025-11-07",
|
||||
) -> None:
|
||||
self._transcribe_callable = transcribe_callable or self._dashscope_transcribe
|
||||
self._model = model
|
||||
|
||||
def transcribe(self, *, audio_bytes: bytes, filename: str) -> dict[str, Any]:
|
||||
payload = self._transcribe_callable(audio_bytes=audio_bytes, filename=filename)
|
||||
return {
|
||||
"model": self._model,
|
||||
**payload,
|
||||
}
|
||||
|
||||
def _dashscope_transcribe(
|
||||
self, *, audio_bytes: bytes, filename: str
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
importlib.import_module("dashscope")
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("DashScope SDK is not installed") from exc
|
||||
|
||||
raise RuntimeError(
|
||||
"DashScope transcribe runtime integration is not configured yet"
|
||||
)
|
||||
@@ -132,6 +132,14 @@ class SupabaseSettings(BaseModel):
|
||||
return self.public_url
|
||||
|
||||
|
||||
class StorageSettings(BaseModel):
    """Infrastructure settings for attachment storage."""

    # Only the "supabase" backend is accepted.
    provider: Literal["supabase"] = "supabase"
    # Bucket name; must be 3-63 characters long.
    bucket: str = Field(default="agent-chat-attachments", min_length=3, max_length=63)
    # Allowed range 60-3600 seconds (presumably the signed-URL lifetime; confirm with the consumer).
    signed_url_ttl_seconds: int = Field(default=600, ge=60, le=3600)
    # Allowed range 1-200 (megabytes, per the field name).
    max_file_size_mb: int = Field(default=20, ge=1, le=200)
    # Allowed range 1-3650 (days, per the field name).
    retention_days: int = Field(default=30, ge=1, le=3650)
|
||||
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
@@ -163,6 +171,7 @@ class Settings(BaseSettings):
|
||||
cors: CorsSettings = CorsSettings()
|
||||
redis: RedisSettings = RedisSettings()
|
||||
supabase: SupabaseSettings = SupabaseSettings()
|
||||
storage: StorageSettings = StorageSettings()
|
||||
celery: CelerySettings = CelerySettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
intent:
|
||||
role: Intent Agent
|
||||
goal: Classify user intent and decide execution strategy
|
||||
execution:
|
||||
role: Execution Agent
|
||||
goal: Execute tasks with available tools
|
||||
organization:
|
||||
role: Organization Agent
|
||||
goal: Organize output for user-friendly response
|
||||
@@ -0,0 +1,2 @@
|
||||
你是任务执行代理。
|
||||
基于输入意图与上下文调用可用工具,并生成可验证执行结果。
|
||||
@@ -0,0 +1,2 @@
|
||||
你是意图识别代理。
|
||||
你的任务是识别用户输入的意图类型,并返回结构化意图标签。
|
||||
@@ -0,0 +1,2 @@
|
||||
你是结果整理代理。
|
||||
将执行结果组织为面向用户的清晰回复,保留关键信息与必要引用。
|
||||
@@ -0,0 +1,6 @@
|
||||
intent:
|
||||
description: Identify user intent and required capabilities
|
||||
execution:
|
||||
description: Execute intent with tools and model calls
|
||||
organization:
|
||||
description: Format final response and references
|
||||
@@ -0,0 +1,9 @@
|
||||
stages:
|
||||
- intent
|
||||
- execution
|
||||
- organization
|
||||
|
||||
timeouts:
|
||||
intent_seconds: 8
|
||||
execution_seconds: 30
|
||||
organization_seconds: 10
|
||||
@@ -0,0 +1,25 @@
|
||||
factories:
|
||||
- name: qwen
|
||||
request_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
avatar: https://cdn.simpleicons.org/alibabacloud/FF6A00
|
||||
- name: minimax
|
||||
request_url: https://api.minimax.chat/v1
|
||||
avatar: https://cdn.simpleicons.org/minimax/1A1A1A
|
||||
- name: kimi
|
||||
request_url: https://api.moonshot.cn/v1
|
||||
avatar: https://cdn.simpleicons.org/moonrepo/3B82F6
|
||||
- name: deepseek
|
||||
request_url: https://api.deepseek.com/v1
|
||||
avatar: https://cdn.simpleicons.org/deepseek/4D6BFE
|
||||
- name: doubao
|
||||
request_url: https://ark.cn-beijing.volces.com/api/v3
|
||||
avatar: https://cdn.simpleicons.org/volkswagen/001E50
|
||||
- name: zai
|
||||
request_url: https://api.z.ai/v1
|
||||
avatar: https://cdn.simpleicons.org/zotero/CC2936
|
||||
|
||||
llms:
|
||||
- model_code: qwen3.5-flash
|
||||
factory_id: qwen
|
||||
- model_code: deepseek-v3.2
|
||||
factory_id: deepseek
|
||||
@@ -0,0 +1,3 @@
|
||||
tools:
|
||||
- asr_fun_asr
|
||||
- attachment_extract
|
||||
@@ -1,9 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
logger = get_logger("core.initialization.init_data")
|
||||
|
||||
|
||||
class LlmFactorySeed(BaseModel):
    """One ``factories`` entry of the LLM catalog file."""

    # Factory name; used as the lookup key when upserting.
    name: str
    # Value written to the factory's request_url column.
    request_url: str
    # Optional avatar value; None when the entry omits it.
    avatar: str | None = None
|
||||
|
||||
|
||||
class LlmSeed(BaseModel):
    """One ``llms`` entry of the LLM catalog file."""

    # Pydantic v2 treats the "model_" prefix as a protected namespace and, on
    # releases where that is the default, warns about the `model_code` field.
    # Clearing the protected namespaces keeps the field name without the warning.
    model_config = {"protected_namespaces": ()}

    # Model code; used as the lookup key when upserting.
    model_code: str
    # Despite the name this holds the factory *name*; it is resolved to the
    # factory's UUID during initialization.
    factory_id: str
|
||||
|
||||
|
||||
class LlmCatalogSeed(BaseModel):
    """Validated shape of the whole LLM catalog file."""

    # Entries of the top-level "factories" list.
    factories: list[LlmFactorySeed]
    # Entries of the top-level "llms" list.
    llms: list[LlmSeed]
|
||||
|
||||
|
||||
def _default_catalog_path() -> Path:
    """Return ``<ancestor>/config/static/agent_chat/llm_catalog.yaml``.

    ``<ancestor>`` is the directory one level above the one containing this
    module.
    """
    ancestor = Path(__file__).resolve().parents[1]
    return ancestor.joinpath("config", "static", "agent_chat", "llm_catalog.yaml")
|
||||
|
||||
|
||||
def load_llm_catalog(catalog_path: Path | None = None) -> dict[str, Any]:
|
||||
path = catalog_path or _default_catalog_path()
|
||||
with path.open("r", encoding="utf-8") as file:
|
||||
loaded = yaml.safe_load(file) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"Invalid LLM catalog format: {path}")
|
||||
raw_factories = loaded.get("factories", [])
|
||||
raw_llms = loaded.get("llms", [])
|
||||
if not isinstance(raw_factories, list) or not isinstance(raw_llms, list):
|
||||
raise ValueError(f"Invalid LLM catalog sections: {path}")
|
||||
try:
|
||||
parsed = LlmCatalogSeed.model_validate(
|
||||
{
|
||||
"factories": list(raw_factories),
|
||||
"llms": list(raw_llms),
|
||||
}
|
||||
)
|
||||
except ValidationError as exc:
|
||||
raise ValueError(f"Invalid LLM catalog data: {path}") from exc
|
||||
|
||||
return parsed.model_dump()
|
||||
|
||||
|
||||
async def _upsert_factory(
    session: AsyncSession,
    *,
    name: str,
    request_url: str,
    avatar: str | None,
) -> uuid.UUID:
    """Create the factory called ``name`` or update it in place.

    A new row is flushed straight away so that its id is available to return.
    The lookup matches on ``name`` only and applies no soft-delete filter.

    Returns:
        The id of the created or updated factory.
    """
    lookup = select(LlmFactory).where(LlmFactory.name == name)
    existing = (await session.execute(lookup)).scalar_one_or_none()

    if existing is not None:
        existing.request_url = request_url
        existing.avatar = avatar
        return existing.id

    created = LlmFactory(name=name, request_url=request_url, avatar=avatar)
    session.add(created)
    await session.flush()
    return created.id
|
||||
|
||||
|
||||
async def _upsert_llm(
    session: AsyncSession,
    *,
    model_code: str,
    factory_id: uuid.UUID,
) -> None:
    """Create the model ``model_code`` or re-point it at ``factory_id``.

    The lookup matches on ``model_code`` only. Nothing is flushed here; the
    caller's transaction persists the change.
    """
    lookup = select(Llm).where(Llm.model_code == model_code)
    existing = (await session.execute(lookup)).scalar_one_or_none()
    if existing is not None:
        existing.factory_id = factory_id
    else:
        session.add(Llm(model_code=model_code, factory_id=factory_id))
|
||||
|
||||
|
||||
async def initialize_data() -> bool:
    """Seed LLM factories and models from the bundled catalog.

    Factories are upserted first so that every model entry can be resolved
    from a factory name to a factory id. All writes share one transaction, so
    an unknown factory reference discards everything written in this call.

    Returns:
        True once the seed data has been written.

    Raises:
        ValueError: If the catalog file is malformed (from ``load_llm_catalog``).
        RuntimeError: If a model names a factory that is not in the catalog.
    """
    # The previous "(no-op)" wording predates the seeding below and was misleading.
    logger.info("Initializing LLM factory/model seed data")
    catalog = load_llm_catalog()

    async with AsyncSessionLocal() as session:
        async with session.begin():
            factory_id_by_name: dict[str, uuid.UUID] = {}
            for factory in catalog["factories"]:
                factory_id = await _upsert_factory(
                    session,
                    name=factory["name"],
                    request_url=factory["request_url"],
                    avatar=factory.get("avatar"),
                )
                factory_id_by_name[factory["name"]] = factory_id

            for llm in catalog["llms"]:
                # The seed's "factory_id" field carries the factory *name*.
                factory_name = llm["factory_id"]
                resolved_factory_id = factory_id_by_name.get(factory_name)
                if resolved_factory_id is None:
                    raise RuntimeError(
                        f"Factory '{factory_name}' not found for model '{llm['model_code']}'"
                    )
                await _upsert_llm(
                    session,
                    model_code=llm["model_code"],
                    factory_id=resolved_factory_id,
                )

    logger.info("Initialized LLM factory/model seed data")
    return True
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
from models.profile import Profile
|
||||
|
||||
__all__ = ["Profile"]
|
||||
__all__ = [
|
||||
"AgentChatMessage",
|
||||
"AgentChatSession",
|
||||
"Llm",
|
||||
"LlmFactory",
|
||||
"Profile",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Enum as SqlEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class AgentChatMessageRole(str, Enum):
    """Author of a chat message; members are plain strings."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"
|
||||
|
||||
|
||||
class AgentChatMessage(TimestampMixin, SoftDeleteMixin, Base):
    """One message of an agent-chat session, stored in the ``messages`` table.

    ``seq`` orders messages inside a session; the ``(session_id, seq)`` pair
    is unique.
    """

    __tablename__: str = "messages"
    __table_args__: tuple[UniqueConstraint] = (
        UniqueConstraint("session_id", "seq", name="uq_messages_session_seq"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Owning session; the foreign key cascades deletes from ``sessions``.
    session_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sessions.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    seq: Mapped[int] = mapped_column(Integer, nullable=False)
    # Stored as a string column rather than a native database enum.
    role: Mapped[AgentChatMessageRole] = mapped_column(
        SqlEnum(
            AgentChatMessageRole,
            name="agent_chat_message_role",
            native_enum=False,
        ),
        nullable=False,
    )
    content: Mapped[str] = mapped_column(Text, nullable=False)
    model_code: Mapped[str | None] = mapped_column(String(50), nullable=True)
    tool_name: Mapped[str | None] = mapped_column(String(100), nullable=True)
    # Usage accounting; each value defaults to zero on insert.
    input_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    output_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    cost: Mapped[Decimal] = mapped_column(Numeric(12, 6), nullable=False, default=0)
    currency: Mapped[str] = mapped_column(String(3), nullable=False, default="USD")
    latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # The column is called "metadata"; the attribute uses a different name.
    # JSONB is used on PostgreSQL, generic JSON on other dialects.
    metadata_json: Mapped[dict[str, object] | None] = mapped_column(
        "metadata",
        JSON().with_variant(JSONB, "postgresql"),
        nullable=True,
    )
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
DateTime,
|
||||
Enum as SqlEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class AgentChatSessionStatus(str, Enum):
    """Lifecycle state of a chat session; members are plain strings."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class AgentChatSession(TimestampMixin, SoftDeleteMixin, Base):
    """A user's agent-chat conversation, stored in the ``sessions`` table.

    Alongside its identity and status, the row keeps running totals
    (message count, tokens, cost) for the session.
    """

    __tablename__: str = "sessions"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Owner; references ``auth.users.id`` and cascades on delete.
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("auth.users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    title: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # Stored as a string column rather than a native database enum.
    status: Mapped[AgentChatSessionStatus] = mapped_column(
        SqlEnum(
            AgentChatSessionStatus,
            name="agent_chat_session_status",
            native_enum=False,
        ),
        nullable=False,
        default=AgentChatSessionStatus.PENDING,
    )
    last_activity_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Running totals; the database supplies the initial zero.
    message_count: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        server_default=text("0"),
    )
    total_tokens: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        server_default=text("0"),
    )
    total_cost: Mapped[Decimal] = mapped_column(
        Numeric(12, 6),
        nullable=False,
        server_default=text("0"),
    )
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class Llm(TimestampMixin, SoftDeleteMixin, Base):
    """A single model offered by a factory, stored in the ``llms`` table."""

    __tablename__: str = "llms"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Owning factory; RESTRICT prevents deleting a factory that still has models.
    factory_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("llm_factory.id", ondelete="RESTRICT"),
        nullable=False,
        index=True,
    )
    # Unique model code, at most 50 characters.
    model_code: Mapped[str] = mapped_column(
        String(50),
        nullable=False,
        unique=True,
        index=True,
    )
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
|
||||
|
||||
class LlmFactory(TimestampMixin, SoftDeleteMixin, Base):
    """An LLM provider ("factory"), stored in the ``llm_factory`` table."""

    __tablename__: str = "llm_factory"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Unique factory name, at most 50 characters.
    name: Mapped[str] = mapped_column(
        String(50),
        nullable=False,
        unique=True,
        index=True,
    )
    request_url: Mapped[str] = mapped_column(String(255), nullable=False)
    avatar: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
@@ -0,0 +1 @@
|
||||
from __future__ import annotations
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
from v1.profile.dependencies import get_current_user
|
||||
|
||||
|
||||
def get_agent_chat_service(
    session: Annotated[AsyncSession, Depends(get_db)],
    user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AgentChatService:
    """FastAPI dependency that builds the agent-chat service for a request.

    The database session comes from ``get_db`` and the user from
    ``get_current_user``.
    """
    service = AgentChatService(session=session, current_user=user)
    return service
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from v1.agent_chat.dependencies import get_agent_chat_service
|
||||
from v1.agent_chat.schemas import AgentChatRunRequest, AgentChatRunResponse
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
router = APIRouter(prefix="/agent-chat", tags=["agent-chat"])
|
||||
|
||||
|
||||
@router.post("/run", response_model=AgentChatRunResponse)
async def run_agent_chat(
    payload: AgentChatRunRequest,
    service: Annotated[AgentChatService, Depends(get_agent_chat_service)],
) -> AgentChatRunResponse:
    """Handle ``POST /agent-chat/run`` by delegating to the service."""
    response = await service.run(payload)
    return response
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentChatRunRequest(BaseModel):
    """Request body for one agent-chat run."""

    # User text; must be between 1 and 8000 characters.
    message: str = Field(min_length=1, max_length=8000)
    # Existing session to continue; when omitted the service creates a new one.
    session_id: UUID | None = None
|
||||
|
||||
|
||||
class AgentChatEvent(BaseModel):
    """One protocol event returned with a run.

    Only ``type`` is required; every other field is optional and defaults to
    None.
    """

    # Event name; the tests in this change use "run.started",
    # "message.delta" and "run.completed".
    type: str
    run_id: str | None = None
    message_id: str | None = None
    delta: str | None = None
    tool_name: str | None = None
    result: str | None = None
    output: str | None = None
    error: str | None = None
|
||||
|
||||
|
||||
class AgentChatRunResponse(BaseModel):
    """Response body for one agent-chat run."""

    # Session the run was recorded under.
    session_id: UUID
    # Assistant output text.
    output: str
    # Protocol events emitted for the run, in order.
    events: list[AgentChatEvent]
|
||||
@@ -0,0 +1,286 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.agent_chat.agui_adapter import AguiAdapter
|
||||
from core.agent_chat.orchestrator import AgentChatOrchestrator
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.auth.rate_limit import enforce_rate_limit
|
||||
from v1.agent_chat.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = get_logger("v1.agent_chat.service")
|
||||
|
||||
|
||||
def build_session_title(first_message: str, *, now: datetime) -> str:
    """Derive a short, single-line session title from the first user message.

    Every run of whitespace (spaces, tabs, carriage returns, newlines) is
    collapsed to one space, and the result is cut to 24 characters.

    Args:
        first_message: Text of the message that opens the session.
        now: Timestamp used to build the fallback title.

    Returns:
        The normalised title, or a timestamped fallback title when the message
        has no visible characters.
    """
    # split() with no argument drops leading/trailing whitespace and treats any
    # whitespace run as one separator, so CR and tab characters cannot leak into
    # the title. rstrip() removes a space the 24-character cut may leave behind.
    title = " ".join(first_message.split())[:24].rstrip()
    if not title:
        return now.strftime("新对话 %Y-%m-%d %H:%M")
    return title
|
||||
|
||||
|
||||
def aggregate_session_cost(costs: list[Decimal]) -> Decimal:
    """Add up cost values, refusing negative entries.

    Args:
        costs: Individual cost amounts.

    Returns:
        The sum of ``costs``; ``Decimal("0")`` for an empty list.

    Raises:
        ValueError: If any amount is below zero.
    """
    if any(amount < 0 for amount in costs):
        raise ValueError("cost must be non-negative")
    return sum(costs, Decimal("0"))
|
||||
|
||||
|
||||
def select_recent_session(
    sessions: list[AgentChatSession],
) -> AgentChatSession | None:
    """Pick the session with the latest ``last_activity_at``.

    Returns:
        The most recently active session, or None when ``sessions`` is empty.
        On a tie the earliest such entry in the list wins.
    """
    return max(sessions, key=lambda candidate: candidate.last_activity_at, default=None)
|
||||
|
||||
|
||||
class AgentChatService(BaseService):
    """Execute agent-chat runs for the current user and persist their results.

    Each call to ``run`` resolves (or creates) a session, runs the
    orchestrator over the three stage callbacks defined on this class, stores
    the user and assistant messages, and updates the session's running totals.
    """

    _session: AsyncSession

    def __init__(self, session: AsyncSession, current_user: CurrentUser | None) -> None:
        """Bind the service to a database session and the requesting user."""
        super().__init__(current_user=current_user)
        self._session = session
        self._adapter = AguiAdapter()
        self._orchestrator = AgentChatOrchestrator(
            intent_stage=self._intent_stage,
            execution_stage=self._execution_stage,
            organization_stage=self._organization_stage,
        )

    async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
        """Run one chat turn and return the assistant output with its events.

        The user message is stored at ``seq = base + 1`` and the assistant
        message at ``base + 2``. A failed orchestration is still committed
        (session marked FAILED) before the 502 is raised.

        Raises:
            HTTPException: 422 when the adapter rejects the payload, 404 when
                ``session_id`` matches no live session of the user, 503 on a
                database error, and 502 when orchestration fails or any other
                unexpected error occurs.
        """
        try:
            command = self._adapter.to_command(payload.model_dump(mode="python"))
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        user_id = self.require_user_id()
        await enforce_rate_limit(
            scope="agent_chat_run",
            identifier=str(user_id),
            limit=60,
            window_seconds=60,
        )
        now = datetime.now(timezone.utc)

        try:
            chat_session = await self._resolve_session(
                session_id=command["session_id"],
                user_id=user_id,
                first_message=command["message"],
                now=now,
            )

            base_seq = await self._next_seq_base(chat_session.id)
            user_message = AgentChatMessage(
                session_id=chat_session.id,
                seq=base_seq + 1,
                role=AgentChatMessageRole.USER,
                content=command["message"],
                cost=Decimal("0"),
            )
            orchestrator_result = await self._orchestrator.run(
                run_id=str(chat_session.id),
                user_message=command["message"],
            )
            # Keep the output in a local: `assistant_message` is not refreshed
            # after commit, so reading its attributes afterwards could trigger
            # an implicit lazy load (an error on AsyncSession) if the session
            # expires objects on commit.
            assistant_output = orchestrator_result.output
            usage = orchestrator_result.usage
            assistant_message = AgentChatMessage(
                session_id=chat_session.id,
                seq=base_seq + 2,
                role=AgentChatMessageRole.ASSISTANT,
                content=assistant_output,
                input_tokens=int(usage["input_tokens"]),
                output_tokens=int(usage["output_tokens"]),
                cost=Decimal(usage["cost"]),
            )
            self._session.add(user_message)
            self._session.add(assistant_message)

            chat_session.status = (
                AgentChatSessionStatus.FAILED
                if orchestrator_result.failed
                else AgentChatSessionStatus.COMPLETED
            )
            chat_session.last_activity_at = now
            chat_session.message_count = chat_session.message_count + 2
            chat_session.total_tokens = chat_session.total_tokens + int(
                usage["total_tokens"]
            )
            chat_session.total_cost = aggregate_session_cost(
                [
                    Decimal(chat_session.total_cost),
                    Decimal(usage["cost"]),
                ]
            )

            await self._session.commit()
            await self._session.refresh(chat_session)
            await self._session.refresh(user_message)

            mapped_events = self._build_mapped_events(
                session_id=str(chat_session.id),
                message_id=str(user_message.id),
                user_message=command["message"],
                assistant_output=assistant_output,
                failed=orchestrator_result.failed,
                error=orchestrator_result.error,
            )
            events = [AgentChatEvent.model_validate(item) for item in mapped_events]
            if orchestrator_result.failed:
                raise HTTPException(
                    status_code=502, detail="Agent orchestration failed"
                )
            return AgentChatRunResponse(
                session_id=chat_session.id,
                output=assistant_output,
                events=events,
            )
        except HTTPException:
            await self._session.rollback()
            raise
        except SQLAlchemyError as exc:
            await self._session.rollback()
            logger.exception("Agent chat run failed")
            # Chain the cause explicitly, as the generic handler below does.
            raise HTTPException(
                status_code=503, detail="Agent chat store unavailable"
            ) from exc
        except Exception as exc:  # noqa: BLE001
            await self._session.rollback()
            logger.exception(
                "Agent chat unexpected failure", error_type=type(exc).__name__
            )
            raise HTTPException(
                status_code=502, detail="Agent orchestration failed"
            ) from exc

    def _build_mapped_events(
        self,
        *,
        session_id: str,
        message_id: str,
        user_message: str,
        assistant_output: str,
        failed: bool,
        error: str | None,
    ) -> list[dict[str, object]]:
        """Build the ordered protocol events for one run.

        The list always starts with ``run_started`` and ``message_delta`` and
        ends with either ``run_failed`` or ``run_completed``. Each event is
        passed through the adapter's ``to_protocol_event``.
        """
        mapped_events = [
            self._adapter.to_protocol_event(
                {
                    "kind": "run_started",
                    "session_id": session_id,
                }
            ),
            self._adapter.to_protocol_event(
                {
                    "kind": "message_delta",
                    "message_id": message_id,
                    "delta": user_message,
                }
            ),
        ]
        if failed:
            mapped_events.append(
                self._adapter.to_protocol_event(
                    {
                        "kind": "run_failed",
                        "session_id": session_id,
                        # Fall back to a generic reason when none was reported.
                        "error": error or "orchestration failed",
                    }
                )
            )
            return mapped_events

        mapped_events.append(
            self._adapter.to_protocol_event(
                {
                    "kind": "run_completed",
                    "session_id": session_id,
                    "output": assistant_output,
                }
            )
        )
        return mapped_events

    async def _resolve_session(
        self,
        *,
        session_id: object | None,
        user_id: UUID,
        first_message: str,
        now: datetime,
    ) -> AgentChatSession:
        """Load the caller's session for update, or create a new one.

        An existing session must belong to ``user_id`` and not be
        soft-deleted; it is selected ``FOR UPDATE`` and marked RUNNING. Without
        a ``session_id`` a new RUNNING session is added and flushed, titled
        from ``first_message``.

        Raises:
            HTTPException: 404 when ``session_id`` matches no such session.
        """
        if session_id is not None:
            stmt = (
                select(AgentChatSession)
                .where(AgentChatSession.id == session_id)
                .where(AgentChatSession.user_id == user_id)
                .where(AgentChatSession.deleted_at.is_(None))
                .with_for_update()
                .limit(1)
            )
            result = await self._session.execute(stmt)
            existing = result.scalar_one_or_none()
            if existing is None:
                raise HTTPException(status_code=404, detail="Session not found")
            existing.status = AgentChatSessionStatus.RUNNING
            return existing

        title = build_session_title(first_message, now=now)

        created = AgentChatSession(
            user_id=user_id,
            title=title,
            status=AgentChatSessionStatus.RUNNING,
            last_activity_at=now,
        )
        self._session.add(created)
        await self._session.flush()
        return created

    async def _next_seq_base(self, session_id: object) -> int:
        """Return the highest ``seq`` stored for the session, or 0 if none."""
        stmt = select(func.max(AgentChatMessage.seq)).where(
            AgentChatMessage.session_id == session_id
        )
        result = await self._session.scalar(stmt)
        if result is None:
            return 0
        return int(result)

    async def _intent_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Intent stage: tag the context as "default" and echo the message.

        Reports zero token usage and zero cost.
        """
        context["intent"] = "default"
        return {
            "content": message,
            "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
        }

    async def _execution_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Execution stage: echo the message with zero usage and cost."""
        return {
            "content": message,
            "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
        }

    async def _organization_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Organization stage: echo the message with zero usage and cost."""
        return {
            "content": message,
            "usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
        }
|
||||
@@ -74,6 +74,12 @@ async def login(
|
||||
payload: LoginRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
await enforce_rate_limit(
|
||||
scope="login",
|
||||
identifier=payload.email,
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
)
|
||||
return await service.login(payload)
|
||||
|
||||
|
||||
@@ -82,6 +88,12 @@ async def refresh(
|
||||
payload: RefreshRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> AuthTokenResponse:
|
||||
await enforce_rate_limit(
|
||||
scope="refresh",
|
||||
identifier=payload.refresh_token,
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
)
|
||||
return await service.refresh(payload)
|
||||
|
||||
|
||||
@@ -90,6 +102,12 @@ async def logout(
|
||||
payload: LogoutRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> Response:
|
||||
await enforce_rate_limit(
|
||||
scope="logout",
|
||||
identifier=payload.refresh_token,
|
||||
limit=10,
|
||||
window_seconds=60,
|
||||
)
|
||||
await service.logout(payload.refresh_token)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from fastapi import APIRouter
|
||||
|
||||
from core.http.models import HealthResponse
|
||||
from v1.agent_chat.router import router as agent_chat_router
|
||||
from v1.auth.router import router as auth_router
|
||||
from v1.infra.router import router as infra_router
|
||||
from v1.profile.router import router as profile_router
|
||||
@@ -12,6 +13,7 @@ router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(auth_router)
|
||||
router.include_router(infra_router)
|
||||
router.include_router(profile_router)
|
||||
router.include_router(agent_chat_router)
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
import uvicorn
|
||||
|
||||
from app import app
|
||||
from v1.agent_chat.dependencies import get_agent_chat_service
|
||||
from v1.agent_chat.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
class FakeE2EAgentChatService(AgentChatService):
    """Echo service used by the end-to-end test instead of the real one."""

    def __init__(self) -> None:
        # Intentionally does not call AgentChatService.__init__.
        pass

    async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
        """Echo the message back with the three standard events."""
        session_id = payload.session_id
        if session_id is None:
            session_id = UUID("00000000-0000-0000-0000-000000000001")
        run_id = str(session_id)
        started = AgentChatEvent(type="run.started", run_id=run_id)
        delta = AgentChatEvent(
            type="message.delta", message_id="m1", delta=payload.message
        )
        completed = AgentChatEvent(
            type="run.completed", run_id=run_id, output=payload.message
        )
        return AgentChatRunResponse(
            session_id=session_id,
            output=payload.message,
            events=[started, delta, completed],
        )
|
||||
|
||||
|
||||
def _find_free_port() -> int:
    """Return a TCP port number the OS reports as free on 127.0.0.1.

    Binding to port 0 lets the OS choose; the probe socket is closed before
    returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _, port = probe.getsockname()
        return port
    finally:
        probe.close()
|
||||
|
||||
|
||||
def _wait_for_port(host: str, port: int, timeout: float = 5.0) -> None:
    """Block until a TCP connection to ``host:port`` succeeds.

    Args:
        host: Address to connect to.
        port: TCP port to probe.
        timeout: Maximum number of seconds to keep trying.

    Raises:
        RuntimeError: If no connection succeeds before the timeout.
    """
    # A monotonic clock keeps the deadline correct if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            # Bound each attempt so one hanging connect cannot outlive the deadline.
            sock.settimeout(remaining)
            if sock.connect_ex((host, port)) == 0:
                return
        time.sleep(0.05)
    raise RuntimeError("Server did not start in time")
|
||||
|
||||
|
||||
def _start_server(host: str, port: int):
    """Run the app under uvicorn in a daemon thread and wait until it listens.

    Returns:
        A ``(server, thread)`` tuple; set ``server.should_exit`` to stop it.
    """
    server = uvicorn.Server(
        uvicorn.Config(app, host=host, port=port, log_level="info")
    )
    worker = threading.Thread(target=server.run, daemon=True)
    worker.start()
    _wait_for_port(host, port)
    return server, worker
|
||||
|
||||
|
||||
def test_agent_chat_flow_e2e() -> None:
    """POST to a live uvicorn server and check the echoed output and events."""
    host = "127.0.0.1"
    port = _find_free_port()
    app.dependency_overrides[get_agent_chat_service] = lambda: FakeE2EAgentChatService()

    try:
        # The server is started inside the try so that a startup timeout still
        # clears the dependency override instead of leaking it to other tests.
        server, thread = _start_server(host, port)
        try:
            with sync_playwright() as playwright:
                request_context = playwright.request.new_context(
                    base_url=f"http://{host}:{port}"
                )
                try:
                    response = request_context.post(
                        "/api/v1/agent-chat/run",
                        data=json.dumps({"message": "hello"}),
                        headers={"Content-Type": "application/json"},
                    )
                    assert response.status == 200
                    body = response.json()
                    assert body["output"] == "hello"
                    assert [event["type"] for event in body["events"]] == [
                        "run.started",
                        "message.delta",
                        "run.completed",
                    ]
                finally:
                    request_context.dispose()
        finally:
            server.should_exit = True
            thread.join(timeout=5)
    finally:
        app.dependency_overrides = {}
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent_chat.service import select_recent_session
|
||||
|
||||
|
||||
def test_recent_session_home_default_selection() -> None:
    """The session with the later ``last_activity_at`` is selected."""
    owner = UUID("00000000-0000-0000-0000-0000000000c1")
    older = AgentChatSession(
        id=UUID("00000000-0000-0000-0000-0000000000a1"),
        user_id=owner,
        title="older",
        status=AgentChatSessionStatus.COMPLETED,
        last_activity_at=datetime(2026, 2, 25, 8, 0, tzinfo=timezone.utc),
        message_count=2,
        total_tokens=100,
        total_cost=Decimal("0.010000"),
    )
    newer = AgentChatSession(
        id=UUID("00000000-0000-0000-0000-0000000000a2"),
        user_id=owner,
        title="newer",
        status=AgentChatSessionStatus.RUNNING,
        last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
        message_count=3,
        total_tokens=120,
        total_cost=Decimal("0.020000"),
    )

    selected = select_recent_session([older, newer])

    assert selected is not None
    assert selected.id == UUID("00000000-0000-0000-0000-0000000000a2")
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from types import MethodType
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent_chat.schemas import AgentChatRunRequest
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
class _FakeAsyncSession:
    """In-memory stand-in for ``AsyncSession`` that records what was called."""

    def __init__(self) -> None:
        self.added: list[object] = []
        self.committed = False
        self.rolled_back = False

    def add(self, obj: object) -> None:
        """Remember the object in insertion order."""
        self.added.append(obj)

    async def flush(self) -> None:
        """Do nothing."""

    async def commit(self) -> None:
        """Record that a commit happened."""
        self.committed = True

    async def rollback(self) -> None:
        """Record that a rollback happened."""
        self.rolled_back = True

    async def refresh(self, obj: object) -> None:
        """Give sessions and messages an id if they do not have one yet."""
        if isinstance(obj, (AgentChatSession, AgentChatMessage)) and obj.id is None:
            obj.id = uuid4()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_persists_messages_and_emits_ordered_events() -> None:
    """``run`` adds user + assistant messages at the next seqs and commits."""
    session_uuid = UUID("00000000-0000-0000-0000-000000000111")
    fake_db = _FakeAsyncSession()
    service = AgentChatService(
        session=fake_db,  # type: ignore[arg-type]
        current_user=CurrentUser(id=UUID("00000000-0000-0000-0000-000000000001")),
    )

    async def _resolve_session(
        self: AgentChatService,
        *,
        session_id: object | None,
        user_id: UUID,
        first_message: str,
        now: datetime,
    ) -> AgentChatSession:
        # Stub: the request carries no session id, so a fresh session is handed back.
        assert session_id is None
        assert first_message == "hello"
        return AgentChatSession(
            id=session_uuid,
            user_id=user_id,
            title="hello",
            status=AgentChatSessionStatus.RUNNING,
            last_activity_at=now,
            message_count=0,
            total_tokens=0,
            total_cost=Decimal("0"),
            created_at=now,
            updated_at=now,
            deleted_at=None,
        )

    async def _next_seq_base(self: AgentChatService, session_id: object) -> int:
        # Stub: pretend two messages already exist in the session.
        assert session_id == session_uuid
        return 2

    service._resolve_session = MethodType(_resolve_session, service)  # type: ignore[method-assign]
    service._next_seq_base = MethodType(_next_seq_base, service)  # type: ignore[method-assign]

    response = await service.run(AgentChatRunRequest(message="hello"))

    assert fake_db.committed is True
    inserted_messages = [
        added for added in fake_db.added if isinstance(added, AgentChatMessage)
    ]
    assert len(inserted_messages) == 2
    assert [message.seq for message in inserted_messages] == [3, 4]
    assert [message.role for message in inserted_messages] == [
        AgentChatMessageRole.USER,
        AgentChatMessageRole.ASSISTANT,
    ]
    assert [event.type for event in response.events] == [
        "run.started",
        "message.delta",
        "run.completed",
    ]
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app import app
|
||||
from v1.agent_chat.dependencies import get_agent_chat_service
|
||||
from v1.agent_chat.schemas import (
|
||||
AgentChatEvent,
|
||||
AgentChatRunRequest,
|
||||
AgentChatRunResponse,
|
||||
)
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
class FakeAgentChatService:
    """Route-test double that echoes the message under a fixed session id."""

    async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
        """Return the message as output together with the three standard events."""
        fixed_id = "00000000-0000-0000-0000-000000000001"
        started = AgentChatEvent(type="run.started", run_id=fixed_id)
        delta = AgentChatEvent(
            type="message.delta", message_id="m1", delta=payload.message
        )
        completed = AgentChatEvent(
            type="run.completed",
            run_id=fixed_id,
            output=payload.message,
        )
        return AgentChatRunResponse(
            session_id=UUID(fixed_id),
            output=payload.message,
            events=[started, delta, completed],
        )
|
||||
|
||||
|
||||
def _override_agent_chat_service(
    service: FakeAgentChatService,
) -> Callable[[], AgentChatService]:
    """Wrap ``service`` in a zero-argument callable that returns it.

    The result is used as a FastAPI dependency override.
    """
    return lambda: service  # type: ignore[return-value]
|
||||
|
||||
|
||||
def test_run_route_returns_response() -> None:
    """A valid request returns 200 with the echoed output and ordered events."""
    provider = _override_agent_chat_service(FakeAgentChatService())
    app.dependency_overrides[get_agent_chat_service] = provider

    client = TestClient(app)
    try:
        response = client.post("/api/v1/agent-chat/run", json={"message": "hello"})
        assert response.status_code == 200
        body = response.json()
        assert body["output"] == "hello"
        event_types = [event["type"] for event in body["events"]]
        assert event_types == [
            "run.started",
            "message.delta",
            "run.completed",
        ]
    finally:
        app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_run_route_validates_payload() -> None:
    """An empty ``message`` is rejected with HTTP 422."""
    provider = _override_agent_chat_service(FakeAgentChatService())
    app.dependency_overrides[get_agent_chat_service] = provider

    client = TestClient(app)
    try:
        empty_payload = {"message": ""}
        response = client.post("/api/v1/agent-chat/run", json=empty_payload)
        assert response.status_code == 422
    finally:
        # Drop the override so it cannot leak into other tests.
        app.dependency_overrides = {}
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from v1.agent_chat.service import aggregate_session_cost
|
||||
|
||||
|
||||
def test_aggregate_session_cost_sums_non_negative_values() -> None:
    """Two non-negative costs aggregate to their exact decimal sum."""
    costs = [Decimal("0.010000"), Decimal("0.002500")]

    assert aggregate_session_cost(costs) == Decimal("0.012500")
|
||||
|
||||
|
||||
def test_aggregate_session_cost_rejects_negative_value() -> None:
    """A negative cost must make ``aggregate_session_cost`` raise ``ValueError``."""
    # Imported here because this module's import block does not include pytest.
    import pytest

    # pytest.raises reports "DID NOT RAISE" on failure, instead of the opaque
    # "assert False is True" produced by a manual flag.
    with pytest.raises(ValueError):
        aggregate_session_cost([Decimal("-0.010000")])
|
||||
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from v1.agent_chat.service import select_recent_session
|
||||
|
||||
|
||||
def test_select_recent_session_uses_last_activity_desc() -> None:
    """Of two sessions, the one with the later ``last_activity_at`` is selected."""
    owner_id = UUID("00000000-0000-0000-0000-0000000000a1")
    newer_id = UUID("00000000-0000-0000-0000-000000000002")

    older = AgentChatSession(
        id=UUID("00000000-0000-0000-0000-000000000001"),
        user_id=owner_id,
        title="older",
        status=AgentChatSessionStatus.COMPLETED,
        last_activity_at=datetime(2026, 2, 25, 9, 0, tzinfo=timezone.utc),
        message_count=1,
        total_tokens=1,
        total_cost=Decimal("0"),
    )
    newer = AgentChatSession(
        id=newer_id,
        user_id=owner_id,
        title="newer",
        status=AgentChatSessionStatus.RUNNING,
        last_activity_at=datetime(2026, 2, 25, 10, 0, tzinfo=timezone.utc),
        message_count=2,
        total_tokens=2,
        total_cost=Decimal("0"),
    )

    selected = select_recent_session([older, newer])

    assert selected is not None
    assert selected.id == newer_id
|
||||
|
||||
|
||||
def test_select_recent_session_returns_none_for_empty_collection() -> None:
    """With no sessions to choose from, the selector returns ``None``."""
    selected = select_recent_session([])

    assert selected is None
|
||||
@@ -416,6 +416,108 @@ def test_logout_returns_no_content() -> None:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_login_rate_limited_after_too_many_attempts() -> None:
    """Ten login attempts get 401; the eleventh gets a 429 problem-details reply."""
    login_url = "/api/v1/auth/login"
    credentials = {"email": "user@example.com", "password": "wrongpw"}

    tokens = AuthTokenResponse(
        access_token="access",
        refresh_token="refresh",
        expires_in=3600,
        token_type="bearer",
        user=AuthUser(id="user-1", email="user@example.com"),
    )
    app.dependency_overrides[get_auth_service] = _override_auth_service(
        FakeAuthService(tokens)
    )

    client = TestClient(app)
    try:
        for _ in range(10):
            attempt = client.post(login_url, json=credentials)
            assert attempt.status_code == 401

        blocked = client.post(login_url, json=credentials)
        assert blocked.status_code == 429
        assert blocked.headers["content-type"].startswith("application/problem+json")
        assert blocked.json()["detail"] == "Too many requests"
    finally:
        app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_refresh_rate_limited_after_too_many_attempts() -> None:
    """Ten refresh attempts get 401; the eleventh gets a 429 problem-details reply."""
    refresh_url = "/api/v1/auth/refresh"
    bad_token = {"refresh_token": "invalid"}

    tokens = AuthTokenResponse(
        access_token="access",
        refresh_token="refresh",
        expires_in=3600,
        token_type="bearer",
        user=AuthUser(id="user-1", email="user@example.com"),
    )
    app.dependency_overrides[get_auth_service] = _override_auth_service(
        FakeAuthService(tokens)
    )

    client = TestClient(app)
    try:
        for _ in range(10):
            attempt = client.post(refresh_url, json=bad_token)
            assert attempt.status_code == 401

        blocked = client.post(refresh_url, json=bad_token)
        assert blocked.status_code == 429
        assert blocked.headers["content-type"].startswith("application/problem+json")
        assert blocked.json()["detail"] == "Too many requests"
    finally:
        app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_logout_rate_limited_after_too_many_attempts() -> None:
    """Ten logouts succeed with 204; the eleventh gets a 429 problem-details reply."""
    logout_url = "/api/v1/auth/logout"
    logout_body = {"refresh_token": "refresh"}

    tokens = AuthTokenResponse(
        access_token="access",
        refresh_token="refresh",
        expires_in=3600,
        token_type="bearer",
        user=AuthUser(id="user-1", email="user@example.com"),
    )
    app.dependency_overrides[get_auth_service] = _override_auth_service(
        FakeAuthService(tokens)
    )

    client = TestClient(app)
    try:
        for _ in range(10):
            accepted = client.post(logout_url, json=logout_body)
            assert accepted.status_code == 204

        blocked = client.post(logout_url, json=logout_body)
        assert blocked.status_code == 429
        assert blocked.headers["content-type"].startswith("application/problem+json")
        assert blocked.json()["detail"] == "Too many requests"
    finally:
        app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_signup_start_validation_error_returns_problem_details() -> None:
|
||||
user = AuthUser(id="user-1", email="user@example.com")
|
||||
token_response = AuthTokenResponse(
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.agui_adapter import AguiAdapter
|
||||
|
||||
|
||||
def test_to_command_maps_payload_fields() -> None:
    """``to_command`` exposes the payload's ``message`` and ``session_id`` unchanged."""
    session_id = "00000000-0000-0000-0000-000000000001"
    payload = {"message": "hello", "session_id": session_id}

    command = AguiAdapter().to_command(payload)

    assert command["message"] == "hello"
    assert command["session_id"] == session_id
|
||||
|
||||
|
||||
def test_to_protocol_event_maps_internal_event() -> None:
    """A ``run_completed`` internal event maps to the dotted protocol form with ``run_id``."""
    internal = {"kind": "run_completed", "session_id": "run-1", "output": "done"}
    expected = {"type": "run.completed", "run_id": "run-1", "output": "done"}

    adapter = AguiAdapter()

    assert adapter.to_protocol_event(internal) == expected
|
||||
|
||||
|
||||
def test_to_protocol_event_raises_for_invalid_event() -> None:
    """An event with an unrecognised ``kind`` is rejected with ``ValueError``."""
    unknown_event = {"kind": "unknown"}
    adapter = AguiAdapter()

    with pytest.raises(ValueError):
        adapter.to_protocol_event(unknown_event)
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_transcribe_uses_injected_dashscope_callable() -> None:
    """The injected callable receives the audio and filename; the result carries text, request id and model."""

    def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        # Fail loudly if the tool alters either keyword argument on the way in.
        assert filename == "voice.wav"
        assert audio_bytes == b"audio"
        return {"text": "你好", "request_id": "req-1"}

    asr_tool = FunASRTool(transcribe_callable=fake_transcribe)
    result = asr_tool.transcribe(audio_bytes=b"audio", filename="voice.wav")

    expected_fields = {
        "text": "你好",
        "request_id": "req-1",
        "model": "fun-asr-realtime-2025-11-07",
    }
    for field, value in expected_fields.items():
        assert result[field] == value
|
||||
|
||||
|
||||
def test_transcribe_raises_runtime_error_when_provider_fails() -> None:
    """A ``RuntimeError`` from the injected callable surfaces as ``RuntimeError`` from the tool."""

    def failing_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        raise RuntimeError("upstream timeout")

    asr_tool = FunASRTool(transcribe_callable=failing_transcribe)

    with pytest.raises(RuntimeError):
        asr_tool.transcribe(audio_bytes=b"audio", filename="voice.wav")
|
||||
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.cost_tracker import CostTracker
|
||||
|
||||
|
||||
def test_normalize_usage_and_cost_aggregation() -> None:
    """Usage given as prompt/completion or input/output tokens is summed into one snapshot."""
    prompt_style_usage = {
        "prompt_tokens": 7,
        "completion_tokens": 5,
        "cost": "0.002500",
    }
    input_style_usage = {
        "input_tokens": 5,
        "output_tokens": 3,
        "cost": "0.003000",
        "currency": "USD",
    }

    tracker = CostTracker()
    tracker.add_usage(prompt_style_usage)
    tracker.add_usage(input_style_usage)

    snapshot = tracker.snapshot()

    assert snapshot["input_tokens"] == 12
    assert snapshot["output_tokens"] == 8
    assert snapshot["total_tokens"] == 20
    assert snapshot["cost"] == Decimal("0.005500")
    assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_negative_values() -> None:
    """Negative token counts and negative costs are each rejected with ``ValueError``."""
    tracker = CostTracker()
    negative_payloads = (
        {"input_tokens": -1},
        {"cost": "-0.010000"},
    )

    for bad_usage in negative_payloads:
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
|
||||
|
||||
|
||||
def test_snapshot_is_zero_before_any_usage() -> None:
    """A fresh tracker reports zero tokens, zero cost and the ``USD`` currency."""
    snapshot = CostTracker().snapshot()

    for counter in ("input_tokens", "output_tokens", "total_tokens"):
        assert snapshot[counter] == 0
    assert snapshot["cost"] == Decimal("0")
    assert snapshot["currency"] == "USD"
|
||||
|
||||
|
||||
def test_add_usage_rejects_currency_mismatch() -> None:
    """A tracker created for USD refuses usage reported in CNY."""
    tracker = CostTracker(currency="USD")
    tracker.add_usage({"input_tokens": 1, "output_tokens": 1, "cost": "0.001000"})

    cny_usage = {
        "input_tokens": 1,
        "output_tokens": 1,
        "cost": "0.001000",
        "currency": "CNY",
    }
    with pytest.raises(ValueError):
        tracker.add_usage(cny_usage)
|
||||
|
||||
|
||||
def test_add_usage_rejects_non_integral_token_values() -> None:
    """Float and boolean token counts are each rejected with ``ValueError``."""
    tracker = CostTracker()
    non_integral_payloads = (
        {"input_tokens": 1.5},
        {"output_tokens": True},
    )

    for bad_usage in non_integral_payloads:
        with pytest.raises(ValueError):
            tracker.add_usage(bad_usage)
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.event_bridge import map_internal_event
|
||||
|
||||
|
||||
def test_map_run_started_event() -> None:
    """``run_started`` maps to ``run.started`` with the session id exposed as ``run_id``."""
    internal = {"kind": "run_started", "session_id": "s1"}

    assert map_internal_event(internal) == {"type": "run.started", "run_id": "s1"}
|
||||
|
||||
|
||||
def test_map_message_delta_event() -> None:
    """``message_delta`` maps to ``message.delta`` and keeps message id and delta."""
    internal = {"kind": "message_delta", "message_id": "m1", "delta": "hello"}
    expected = {"type": "message.delta", "message_id": "m1", "delta": "hello"}

    assert map_internal_event(internal) == expected
|
||||
|
||||
|
||||
def test_map_tool_events() -> None:
    """Tool start/completion events map to ``tool.started`` / ``tool.completed``."""
    tool_name = "asr_fun_asr"
    started_event = {
        "kind": "tool_started",
        "message_id": "m2",
        "tool_name": tool_name,
    }
    completed_event = {
        "kind": "tool_completed",
        "message_id": "m2",
        "tool_name": tool_name,
        "result": "ok",
    }

    mapped_started = map_internal_event(started_event)
    mapped_completed = map_internal_event(completed_event)

    assert mapped_started["type"] == "tool.started"
    assert mapped_started["tool_name"] == tool_name
    assert mapped_completed["type"] == "tool.completed"
    assert mapped_completed["result"] == "ok"
|
||||
|
||||
|
||||
def test_map_run_completed_event() -> None:
    """``run_completed`` maps to ``run.completed`` with ``run_id`` and ``output``."""
    internal = {"kind": "run_completed", "session_id": "s1", "output": "done"}
    expected = {"type": "run.completed", "run_id": "s1", "output": "done"}

    assert map_internal_event(internal) == expected
|
||||
|
||||
|
||||
def test_map_unknown_event_raises() -> None:
    """An unrecognised ``kind`` is rejected with ``ValueError``."""
    unknown_event = {"kind": "unknown"}

    with pytest.raises(ValueError):
        map_internal_event(unknown_event)
|
||||
|
||||
|
||||
def test_map_event_missing_required_field_raises_value_error() -> None:
    """A ``message_delta`` event without its ``delta`` field raises ``ValueError``."""
    incomplete_event = {"kind": "message_delta", "message_id": "m1"}

    with pytest.raises(ValueError):
        map_internal_event(incomplete_event)
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.multimodal import AttachmentInput, MultimodalProcessor
|
||||
from core.agent_chat.storage_adapter import StorageAdapter
|
||||
from core.agent_chat.tools.asr_fun_asr import FunASRTool
|
||||
|
||||
|
||||
def test_multimodal_processes_audio_and_builds_attachment_context() -> None:
    """One WAV attachment yields one metadata entry and the transcript as preview text."""

    def fake_transcribe(*, audio_bytes: bytes, filename: str) -> dict[str, str]:
        # Confirms the processor forwards the attachment content and name as given.
        assert audio_bytes == b"audio"
        assert filename == "voice.wav"
        return {"text": "hello world", "request_id": "req-1"}

    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(transcribe_callable=fake_transcribe),
        max_file_size_mb=1,
    )
    voice_note = AttachmentInput(
        filename="voice.wav",
        mime_type="audio/wav",
        content=b"audio",
    )

    result = processor.process(
        user_id="u1",
        session_id="s1",
        message_seq=4,
        attachments=[voice_note],
    )

    assert len(result.attachments) == 1
    metadata = result.attachments[0]
    expected_path = (
        "agent-chat/u1/s1/4/"
        "6ed8919ce20490a5e3ad8630a4fab69475297abd07db73918dd5f36fcfaeb11b.wav"
    )
    assert metadata["object_path"] == expected_path
    assert metadata["mime_type"] == "audio/wav"
    assert result.preview_texts == ["hello world"]
|
||||
|
||||
|
||||
def test_multimodal_rejects_unsupported_mime_type() -> None:
    """An ``application/octet-stream`` attachment is rejected with ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
    )
    executable = AttachmentInput(
        filename="malware.exe",
        mime_type="application/octet-stream",
        content=b"bad",
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[executable],
        )
|
||||
|
||||
|
||||
def test_multimodal_rejects_attachment_over_max_size() -> None:
    """With a 1 MB limit, an attachment of 1 MiB + 1 byte is rejected with ``ValueError``."""
    processor = MultimodalProcessor(
        storage=StorageAdapter(bucket="agent-chat-attachments"),
        asr_tool=FunASRTool(lambda **_: {}),
        max_file_size_mb=1,
    )

    # Exactly one byte past 1 MiB.
    oversized = b"x" * (1024 * 1024 + 1)
    too_big = AttachmentInput(
        filename="big.wav",
        mime_type="audio/wav",
        content=oversized,
    )

    with pytest.raises(ValueError):
        processor.process(
            user_id="u1",
            session_id="s1",
            message_seq=1,
            attachments=[too_big],
        )
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent_chat.orchestrator import AgentChatOrchestrator
|
||||
|
||||
|
||||
async def _intent_stage(
    *, message: str, context: dict[str, object]
) -> dict[str, object]:
    """Fake intent stage: record ``"intent"`` in ``context["sequence"]`` and echo the message.

    Returns a dict with ``content`` set to ``"intent:<message>"`` and a fixed
    ``usage`` payload (2 input tokens, 1 output token, cost "0.001000").
    """
    trail = context.setdefault("sequence", [])
    # Only append when the slot really is a list; any other value is left alone.
    if isinstance(trail, list):
        trail.append("intent")

    usage = {"input_tokens": 2, "output_tokens": 1, "cost": "0.001000"}
    return {"content": f"intent:{message}", "usage": usage}
|
||||
|
||||
|
||||
async def _execution_stage(
    *, message: str, context: dict[str, object]
) -> dict[str, object]:
    """Fake execution stage: record ``"execution"`` in ``context["sequence"]`` and echo the message.

    Returns a dict with ``content`` set to ``"execution:<message>"`` and a fixed
    ``usage`` payload (3 input tokens, 2 output tokens, cost "0.002000").
    """
    trail = context.setdefault("sequence", [])
    # Only append when the slot really is a list; any other value is left alone.
    if isinstance(trail, list):
        trail.append("execution")

    usage = {"input_tokens": 3, "output_tokens": 2, "cost": "0.002000"}
    return {"content": f"execution:{message}", "usage": usage}
|
||||
|
||||
|
||||
async def _organization_stage(
    *, message: str, context: dict[str, object]
) -> dict[str, object]:
    """Fake organization stage: record ``"organization"`` and return a fixed final answer.

    ``message`` is accepted but not used; ``content`` is always ``"final answer"``
    and ``usage`` is fixed (4 input tokens, 1 output token, cost "0.001500").
    """
    trail = context.setdefault("sequence", [])
    # Only append when the slot really is a list; any other value is left alone.
    if isinstance(trail, list):
        trail.append("organization")

    usage = {"input_tokens": 4, "output_tokens": 1, "cost": "0.001500"}
    return {"content": "final answer", "usage": usage}
|
||||
|
||||
|
||||
def test_orchestrator_runs_three_stages_in_order() -> None:
    """A successful run visits intent, execution, organization and sums usage to 13 tokens."""
    orchestrator = AgentChatOrchestrator(
        intent_stage=_intent_stage,
        execution_stage=_execution_stage,
        organization_stage=_organization_stage,
    )

    result = orchestrator.run_sync(run_id="run-1", user_message="hello")

    assert result.context["sequence"] == ["intent", "execution", "organization"]
    assert result.output == "final answer"
    # 3 + 5 + 5 tokens across the three fake stages.
    assert result.usage["total_tokens"] == 13

    first_event, last_event = result.events[0], result.events[-1]
    assert first_event["type"] == "run.started"
    assert last_event["type"] == "run.completed"
|
||||
|
||||
|
||||
async def _failing_execution_stage(
    *, message: str, context: dict[str, object]
) -> dict[str, object]:
    """Fake execution stage that records ``"execution"`` and then always raises.

    Raises:
        RuntimeError: always, with the message ``"boom"``.
    """
    trail = context.setdefault("sequence", [])
    # The stage is recorded before failing so callers can see it was reached.
    if isinstance(trail, list):
        trail.append("execution")

    raise RuntimeError("boom")
|
||||
|
||||
|
||||
def test_orchestrator_stops_and_marks_failed_when_middle_stage_raises() -> None:
    """When execution raises, organization never runs and the run ends with ``run.failed``."""
    orchestrator = AgentChatOrchestrator(
        intent_stage=_intent_stage,
        execution_stage=_failing_execution_stage,
        organization_stage=_organization_stage,
    )

    result = orchestrator.run_sync(run_id="run-2", user_message="hello")

    assert result.context["sequence"] == ["intent", "execution"]

    final_event = result.events[-1]
    assert final_event["type"] == "run.failed"
    assert final_event["run_id"] == "run-2"
    assert "boom" in (final_event.get("error") or "")

    assert result.failed is True
    assert "boom" in (result.error or "")
|
||||
|
||||
|
||||
def test_orchestrator_emits_stage_event_payload_shape() -> None:
    """Every event has a ``type`` and the run id; stage completions arrive in stage order."""
    orchestrator = AgentChatOrchestrator(
        intent_stage=_intent_stage,
        execution_stage=_execution_stage,
        organization_stage=_organization_stage,
    )

    result = orchestrator.run_sync(run_id="run-3", user_message="hello")

    for event in result.events:
        assert "type" in event
        assert event.get("run_id") == "run-3"

    completed_stages = [
        event["stage"]
        for event in result.events
        if event["type"] == "stage.completed"
    ]
    assert completed_stages == ["intent", "execution", "organization"]
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from v1.agent_chat.service import build_session_title
|
||||
|
||||
|
||||
def test_build_session_title_truncates_first_message() -> None:
    """A long first message produces a title of exactly 24 characters."""
    long_message = "这是一个非常长的标题会被截断到二十四个可见字符用于会话摘要"
    fixed_now = datetime(2026, 2, 25, 10, 30)

    title = build_session_title(long_message, now=fixed_now)

    assert len(title) == 24
|
||||
|
||||
|
||||
def test_build_session_title_falls_back_when_message_empty() -> None:
    """A whitespace-only message falls back to the timestamped default title."""
    fixed_now = datetime(2026, 2, 25, 10, 30)
    blank_message = "\n "

    title = build_session_title(blank_message, now=fixed_now)

    assert title == "新对话 2026-02-25 10:30"
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agent_chat.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
def test_build_object_path_uses_expected_pattern() -> None:
    """The object path is ``agent-chat/<user>/<session>/<seq>/<checksum>.<ext>``."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    path_parts = {
        "user_id": "u1",
        "session_id": "s1",
        "message_seq": 3,
        "checksum_sha256": "abc123",
        "extension": "wav",
    }

    object_path = adapter.build_object_path(**path_parts)

    assert object_path == "agent-chat/u1/s1/3/abc123.wav"
|
||||
|
||||
|
||||
def test_build_attachment_metadata_contains_required_fields() -> None:
    """Each keyword passed to ``build_attachment_metadata`` appears under the same key in the result."""
    adapter = StorageAdapter(bucket="agent-chat-attachments")
    supplied = {
        "object_path": "agent-chat/u1/s1/3/abc123.wav",
        "mime_type": "audio/wav",
        "size": 1024,
        "checksum_sha256": "abc123",
        "origin": "user_upload",
        "preview_text": "hello",
    }

    metadata = adapter.build_attachment_metadata(**supplied)

    for field, value in supplied.items():
        assert metadata[field] == value
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent_chat.crewai.template_loader import (
|
||||
load_crewai_template,
|
||||
load_tools_whitelist,
|
||||
validate_workflow_stages,
|
||||
)
|
||||
|
||||
|
||||
def _write(path: Path, content: str) -> None:
    """Write ``content`` to ``path`` as UTF-8, creating missing parent directories first."""
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(content)
|
||||
|
||||
|
||||
def _prepare_static_root(root: Path) -> Path:
    """Populate ``root`` with a minimal agent-chat static tree and return ``root``.

    Files written, in order: ``crewai/agents.yaml``, ``crewai/tasks.yaml``,
    ``crewai/workflow.yaml``, one ``crewai/prompts/<stage>.md`` per stage, and
    ``tools.yaml``.

    Args:
        root: Directory that receives the tree; created if missing.

    Returns:
        The same ``root`` path, for call chaining.
    """
    stages = ("intent", "execution", "organization")
    crewai_dir = root / "crewai"

    # The nested mappings are spelled with explicit "\n" and two-space indents:
    # each stage must be a top-level key owning its own sub-mapping. Written
    # flat, ``role`` / ``description`` would become repeated top-level keys.
    agents_yaml = (
        "intent:\n"
        "  role: Intent Agent\n"
        "execution:\n"
        "  role: Execution Agent\n"
        "organization:\n"
        "  role: Organization Agent"
    )
    tasks_yaml = (
        "intent:\n"
        "  description: classify\n"
        "execution:\n"
        "  description: run task\n"
        "organization:\n"
        "  description: summarize"
    )
    workflow_yaml = "stages:\n- intent\n- execution\n- organization"
    tools_yaml = "tools:\n- asr_fun_asr\n- doc_extract"

    _write(crewai_dir / "agents.yaml", agents_yaml)
    _write(crewai_dir / "tasks.yaml", tasks_yaml)
    _write(crewai_dir / "workflow.yaml", workflow_yaml)
    for stage in stages:
        # Each prompt file holds "<stage> prompt", e.g. "intent prompt".
        _write(crewai_dir / "prompts" / f"{stage}.md", f"{stage} prompt")
    _write(root / "tools.yaml", tools_yaml)
    return root
|
||||
|
||||
|
||||
def test_load_crewai_template_success_when_all_files_valid(tmp_path: Path) -> None:
    """A complete static tree loads into agents, tasks, workflow, prompts and tool whitelist."""
    stage_names = {"intent", "execution", "organization"}
    static_root = _prepare_static_root(tmp_path / "agent_chat")

    template = load_crewai_template(static_root)

    assert set(template.agents.keys()) == stage_names
    assert set(template.tasks.keys()) == stage_names
    assert template.workflow["stages"] == ["intent", "execution", "organization"]
    assert template.prompts["intent"] == "intent prompt"
    assert template.prompts["execution"] == "execution prompt"
    assert template.prompts["organization"] == "organization prompt"
    assert template.tools_whitelist == {"asr_fun_asr", "doc_extract"}
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_file_not_found_when_required_file_missing(
    tmp_path: Path,
) -> None:
    """Removing ``crewai/tasks.yaml`` makes the loader raise ``FileNotFoundError``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")
    tasks_file = static_root / "crewai" / "tasks.yaml"
    tasks_file.unlink()

    with pytest.raises(FileNotFoundError):
        load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_crewai_template_raises_value_error_when_workflow_stages_invalid(
    tmp_path: Path,
) -> None:
    """A workflow listing the stages out of order makes the loader raise ``ValueError``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")
    # Same three stages as the valid fixture, but execution precedes intent.
    reordered_workflow = """
stages:
- execution
- intent
- organization
""".strip()
    _write(static_root / "crewai" / "workflow.yaml", reordered_workflow)

    with pytest.raises(ValueError):
        load_crewai_template(static_root)
|
||||
|
||||
|
||||
def test_load_tools_whitelist_from_tools_yaml(tmp_path: Path) -> None:
    """The whitelist is the set of tool names listed in ``tools.yaml``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")
    expected_tools = {"asr_fun_asr", "doc_extract"}

    assert load_tools_whitelist(static_root) == expected_tools
|
||||
|
||||
|
||||
def test_validate_workflow_stages_accepts_exact_intent_execution_organization() -> None:
    """The exact three-stage sequence validates without raising."""
    canonical_stages = ["intent", "execution", "organization"]

    validate_workflow_stages(canonical_stages)
|
||||
|
||||
|
||||
def test_validate_workflow_stages_rejects_extra_or_missing_stage() -> None:
    """A stage list that is one short or one too long raises ``ValueError``."""
    invalid_stage_lists = (
        ["intent", "execution"],
        ["intent", "execution", "organization", "extra"],
    )

    for stages in invalid_stage_lists:
        with pytest.raises(ValueError):
            validate_workflow_stages(stages)
|
||||
|
||||
|
||||
def test_load_tools_whitelist_rejects_non_string_item(tmp_path: Path) -> None:
    """A numeric entry in the ``tools`` list makes the loader raise ``ValueError``."""
    static_root = _prepare_static_root(tmp_path / "agent_chat")
    # The bare 123 is a YAML integer, not a string.
    tools_with_number = """
tools:
- asr_fun_asr
- 123
""".strip()
    _write(static_root / "tools.yaml", tools_with_number)

    with pytest.raises(ValueError):
        load_tools_whitelist(static_root)
|
||||
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from core.initialization import init_data
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
|
||||
def test_llm_catalog_file_exists_and_has_required_fields() -> None:
    """The bundled catalog has 6 factories and 2 LLMs with exactly the expected keys."""
    # Three directory levels above this file's own directory.
    base_dir = Path(__file__).resolve().parents[3]
    catalog_path = base_dir.joinpath(
        "src", "core", "config", "static", "agent_chat", "llm_catalog.yaml"
    )

    catalog = init_data.load_llm_catalog(catalog_path)
    factories, llms = catalog["factories"], catalog["llms"]

    assert len(factories) == 6
    assert len(llms) == 2
    assert set(factories[0].keys()) == {"name", "request_url", "avatar"}
    assert set(llms[0].keys()) == {"model_code", "factory_id"}
|
||||
|
||||
|
||||
def test_load_llm_catalog_raises_on_invalid_structure(tmp_path: Path) -> None:
    """A catalog whose entries lack required keys makes the loader raise ``ValueError``."""
    # The factory has only ``name`` and the LLM only ``model_code``.
    incomplete_catalog = """
factories:
- name: qwen
llms:
- model_code: qwen3.5-flash
""".strip()
    catalog_path = tmp_path / "llm_catalog.yaml"
    catalog_path.write_text(incomplete_catalog, encoding="utf-8")

    with pytest.raises(ValueError):
        init_data.load_llm_catalog(catalog_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_initialize_data_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
    """Running ``initialize_data`` twice returns ``True`` both times without duplicating rows.

    After both runs the database holds 6 ``LlmFactory`` rows and 2 ``Llm`` rows.
    """
    # Stub ``auth.users`` table registered on the shared metadata for this test.
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        session_maker = async_sessionmaker(
            bind=engine, class_=AsyncSession, expire_on_commit=False
        )

        async with engine.begin() as conn:
            # Attach a database named "auth" so the schema-qualified table resolves.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)

        monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)

        first = await init_data.initialize_data()
        second = await init_data.initialize_data()

        assert first is True
        assert second is True

        async with session_maker() as session:
            factory_count = await session.scalar(
                select(func.count()).select_from(LlmFactory)
            )
            llm_count = await session.scalar(select(func.count()).select_from(Llm))

        assert factory_count == 6
        assert llm_count == 2
    finally:
        # Clean up even when an assertion fails: otherwise the stub table stays
        # on the shared Base.metadata and the engine is never disposed.
        Base.metadata.remove(users_table)
        await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_initialize_data_rolls_back_on_invalid_factory_mapping(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An LLM pointing at an unknown factory raises ``RuntimeError`` and persists nothing.

    The catalog loader is patched to return one valid factory and one LLM whose
    ``factory_id`` is ``"missing_factory"``. After the failure both tables must
    be empty.
    """
    # Stub ``auth.users`` table registered on the shared metadata for this test.
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        session_maker = async_sessionmaker(
            bind=engine, class_=AsyncSession, expire_on_commit=False
        )

        async with engine.begin() as conn:
            # Attach a database named "auth" so the schema-qualified table resolves.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)

        monkeypatch.setattr(init_data, "AsyncSessionLocal", session_maker)
        monkeypatch.setattr(
            init_data,
            "load_llm_catalog",
            lambda *_: {
                "factories": [
                    {
                        "name": "qwen",
                        "request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
                        "avatar": "https://cdn.example.com/qwen.png",
                    }
                ],
                "llms": [
                    {
                        "model_code": "qwen3.5-flash",
                        "factory_id": "missing_factory",
                    }
                ],
            },
        )

        with pytest.raises(RuntimeError):
            await init_data.initialize_data()

        async with session_maker() as session:
            factory_count = await session.scalar(
                select(func.count()).select_from(LlmFactory)
            )
            llm_count = await session.scalar(select(func.count()).select_from(Llm))

        assert factory_count == 0
        assert llm_count == 0
    finally:
        # Clean up even when an assertion fails: otherwise the stub table stays
        # on the shared Base.metadata and the engine is never disposed.
        Base.metadata.remove(users_table)
        await engine.dispose()
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_agent_chat_migration_exists_and_creates_expected_tables() -> None:
    """The agent-chat migration file exists and creates the four core tables.

    It must contain a ``create_table(`` call for ``llm_factory``, ``llms``,
    ``sessions`` and ``messages``, and must not mention ``tool_calls``.
    """
    # Imported here because this module's import block does not include re.
    import re

    versions_dir = Path(__file__).resolve().parents[3] / "alembic" / "versions"
    migration = versions_dir / "20260226_create_agent_chat_core_tables.py"

    assert migration.exists()

    content = migration.read_text(encoding="utf-8")
    for table_name in ("llm_factory", "llms", "sessions", "messages"):
        # \s* tolerates any whitespace between "(" and the quoted name, so the
        # check does not depend on how the migration file is formatted.
        pattern = r'create_table\(\s*"' + re.escape(table_name) + r'"'
        assert re.search(pattern, content), f"no create_table for {table_name!r}"

    assert "tool_calls" not in content
|
||||
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from core.db.base import Base
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
||||
from models.llm import Llm
|
||||
from models.llm_factory import LlmFactory
|
||||
|
||||
|
||||
@pytest.fixture
async def db_engine():
    """Yield an in-memory SQLite async engine with all ``Base.metadata`` tables created.

    A stub ``auth.users`` table is registered on the shared metadata for the
    lifetime of the fixture and removed again on teardown.
    """
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    try:
        async with engine.begin() as conn:
            # Attach a database named "auth" so the schema-qualified table resolves.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Also runs when schema creation fails before the yield, so the stub
        # table never stays on the shared metadata and the engine is disposed.
        Base.metadata.remove(users_table)
        await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine`` and roll it back afterwards."""
    session_factory = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with session_factory() as active_session:
        yield active_session
        # Roll back whatever is still pending once the test is done with the session.
        await active_session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_llm_factory_and_llm_relationship(db_session: AsyncSession) -> None:
    """An ``Llm`` row keeps the id of the ``LlmFactory`` it was created for."""
    provider = LlmFactory(
        name="qwen",
        request_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        avatar="https://cdn.example.com/qwen.png",
    )
    db_session.add(provider)
    # Flush before building the Llm so that provider.id can be read.
    await db_session.flush()

    model_row = Llm(factory_id=provider.id, model_code="qwen3.5-flash")
    db_session.add(model_row)
    await db_session.commit()

    reloaded = await db_session.get(Llm, model_row.id)
    assert reloaded is not None
    assert reloaded.factory_id == provider.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_status_supports_required_values(
    db_session: AsyncSession,
) -> None:
    """Each required session status survives a commit/refresh round trip."""
    chat_session = AgentChatSession(
        user_id=uuid4(),
        title="test",
        status="pending",
    )
    db_session.add(chat_session)
    await db_session.commit()

    required_statuses = (
        AgentChatSessionStatus.PENDING,
        AgentChatSessionStatus.RUNNING,
        AgentChatSessionStatus.COMPLETED,
        AgentChatSessionStatus.FAILED,
    )
    for expected in required_statuses:
        chat_session.status = expected
        await db_session.commit()
        # Reload from the database so the assertion checks the stored value.
        await db_session.refresh(chat_session)
        assert chat_session.status == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_messages_role_supports_tool(db_session: AsyncSession) -> None:
    """A message with role ``"tool"`` can be stored and read back."""
    chat_session = AgentChatSession(
        user_id=uuid4(),
        title="tool test",
        status="pending",
    )
    db_session.add(chat_session)
    # Flush so chat_session.id is available for the message row.
    await db_session.flush()

    tool_message = AgentChatMessage(
        session_id=chat_session.id,
        seq=1,
        role="tool",
        content="tool output",
        cost=0,
    )
    db_session.add(tool_message)
    await db_session.commit()

    query = select(AgentChatMessage).where(
        AgentChatMessage.session_id == chat_session.id
    )
    stored = (await db_session.execute(query)).scalar_one()
    assert stored.role == "tool"
|
||||
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import ValidationError
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from core.config.settings import Settings
|
||||
|
||||
|
||||
def test_social_prefixed_storage_env_populates_settings(
    monkeypatch: MonkeyPatch,
) -> None:
    """``SOCIAL_STORAGE__*`` environment variables populate ``Settings.storage``."""
    env_overrides = {
        "SOCIAL_STORAGE__PROVIDER": "supabase",
        "SOCIAL_STORAGE__BUCKET": "agent-chat-attachments",
        "SOCIAL_STORAGE__SIGNED_URL_TTL_SECONDS": "900",
        "SOCIAL_STORAGE__MAX_FILE_SIZE_MB": "25",
        "SOCIAL_STORAGE__RETENTION_DAYS": "45",
    }
    for env_name, env_value in env_overrides.items():
        monkeypatch.setenv(env_name, env_value)

    storage = Settings().storage

    assert storage.provider == "supabase"
    assert storage.bucket == "agent-chat-attachments"
    # Numeric variables are set as strings and compared as integers.
    assert storage.signed_url_ttl_seconds == 900
    assert storage.max_file_size_mb == 25
    assert storage.retention_days == 45
|
||||
|
||||
|
||||
def test_storage_settings_validation_rejects_invalid_provider(
    monkeypatch: MonkeyPatch,
) -> None:
    """Building ``Settings`` with provider ``"s3"`` raises ``ValidationError``."""
    unsupported_provider = "s3"
    monkeypatch.setenv("SOCIAL_STORAGE__PROVIDER", unsupported_provider)

    with pytest.raises(ValidationError):
        Settings()
|
||||
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, String, Table, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.agent_chat.orchestrator import OrchestratorResult
|
||||
from core.db.base import Base
|
||||
from models.agent_chat_message import AgentChatMessage
|
||||
from models.agent_chat_session import AgentChatSession
|
||||
from v1.agent_chat.schemas import AgentChatRunRequest
|
||||
from v1.agent_chat.service import AgentChatService
|
||||
|
||||
|
||||
@pytest.fixture
async def db_engine():
    """Yield an in-memory async SQLite engine with all model tables created.

    A stub ``auth.users`` table is registered on ``Base.metadata`` and an
    in-memory database is attached under the name ``auth`` so that the
    schema-qualified table can be created on SQLite.
    NOTE(review): presumably the models reference ``auth.users`` via a
    foreign key -- confirm against the model definitions.

    Cleanup runs in ``finally`` blocks so that a failure during setup (before
    the ``yield``) still disposes the engine and removes the stub table from
    the shared metadata instead of leaking it into later tests.
    """
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    try:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
        try:
            async with engine.begin() as conn:
                await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
                await conn.run_sync(Base.metadata.create_all)
            yield engine
        finally:
            await engine.dispose()
    finally:
        Base.metadata.remove(users_table)
|
||||
|
||||
|
||||
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine``.

    Objects are not expired on commit, and the session is rolled back once
    the test has finished with it.
    """
    session_factory = async_sessionmaker(
        expire_on_commit=False,
        class_=AsyncSession,
        bind=db_engine,
    )
    async with session_factory() as active_session:
        yield active_session
        # Discard anything the test left uncommitted.
        await active_session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_creates_session_and_persists_messages(
    db_session: AsyncSession,
) -> None:
    """A run without a session id creates a session and stores two messages."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    outcome = await chat_service.run(AgentChatRunRequest(message="hello"))

    assert outcome.session_id is not None
    assert outcome.output == "hello"
    emitted_types = [evt.type for evt in outcome.events]
    assert emitted_types == ["run.started", "message.delta", "run.completed"]

    stored_session = await db_session.get(AgentChatSession, outcome.session_id)
    assert stored_session is not None
    assert stored_session.message_count == 2
    assert stored_session.status.value == "completed"

    # Read the messages back ordered by seq to check the persisted order.
    ordered_query = (
        select(AgentChatMessage)
        .where(AgentChatMessage.session_id == outcome.session_id)
        .order_by(AgentChatMessage.seq.asc())
    )
    stored_messages = (await db_session.execute(ordered_query)).scalars().all()
    assert len(stored_messages) == 2
    first_message, second_message = stored_messages
    assert first_message.role.value == "user"
    assert second_message.role.value == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_appends_to_existing_session(db_session: AsyncSession) -> None:
    """A second run with the same session id reuses that session."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    initial = await chat_service.run(AgentChatRunRequest(message="first"))
    follow_up_request = AgentChatRunRequest(
        message="second",
        session_id=initial.session_id,
    )
    follow_up = await chat_service.run(follow_up_request)

    assert follow_up.session_id == initial.session_id

    stored_session = await db_session.get(AgentChatSession, initial.session_id)
    assert stored_session is not None
    # Two runs are expected to leave four messages on the session.
    assert stored_session.message_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_raises_502_and_marks_session_failed_when_orchestrator_fails(
    db_session: AsyncSession,
) -> None:
    """A failed orchestrator result yields HTTP 502 and a ``failed`` session."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    class _FailingOrchestrator:
        """Stand-in orchestrator that always reports a failed run."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            empty_usage = {
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
                "cost": Decimal("0"),
                "currency": "USD",
            }
            return OrchestratorResult(
                output="",
                usage=empty_usage,
                events=[],
                context={},
                failed=True,
                error="stage failed",
            )

    # Swap the service's orchestrator for the failing stand-in.
    chat_service._orchestrator = _FailingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as raised:
        await chat_service.run(AgentChatRunRequest(message="hello"))

    assert raised.value.status_code == 502

    user_sessions_query = select(AgentChatSession).where(
        AgentChatSession.user_id == current_user.id
    )
    persisted = (await db_session.execute(user_sessions_query)).scalars().one()
    assert persisted.status.value == "failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_returns_422_when_message_is_blank(db_session: AsyncSession) -> None:
    """A whitespace-only message is rejected with HTTP 422."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)
    blank_request = AgentChatRunRequest(message=" ")

    with pytest.raises(HTTPException) as raised:
        await chat_service.run(blank_request)

    assert raised.value.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_returns_404_when_session_not_found(db_session: AsyncSession) -> None:
    """A run that names a session id not in the database gets HTTP 404."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)
    # A freshly generated id that was never stored.
    unknown_request = AgentChatRunRequest(message="hello", session_id=uuid4())

    with pytest.raises(HTTPException) as raised:
        await chat_service.run(unknown_request)

    assert raised.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_returns_503_when_commit_raises_sqlalchemy_error(
    db_session: AsyncSession,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``SQLAlchemyError`` raised by ``commit`` surfaces as HTTP 503."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    async def _fail_commit() -> None:
        raise SQLAlchemyError("db down")

    # Every commit on this session now raises.
    monkeypatch.setattr(db_session, "commit", _fail_commit)

    with pytest.raises(HTTPException) as raised:
        await chat_service.run(AgentChatRunRequest(message="hello"))

    assert raised.value.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_returns_502_for_unexpected_exception(
    db_session: AsyncSession,
) -> None:
    """A ``RuntimeError`` raised by the orchestrator surfaces as HTTP 502."""
    current_user = CurrentUser(id=uuid4())
    chat_service = AgentChatService(session=db_session, current_user=current_user)

    class _CrashingOrchestrator:
        """Stand-in orchestrator whose ``run`` always raises."""

        async def run(self, *, run_id: str, user_message: str) -> OrchestratorResult:
            raise RuntimeError("unexpected")

    # Swap the service's orchestrator for the crashing stand-in.
    chat_service._orchestrator = _CrashingOrchestrator()  # type: ignore[assignment]

    with pytest.raises(HTTPException) as raised:
        await chat_service.run(AgentChatRunRequest(message="hello"))

    assert raised.value.status_code == 502
|
||||
@@ -0,0 +1,36 @@
|
||||
# Agent Chat CrewAI + AG-UI Spike Notes
|
||||
|
||||
## Scope
|
||||
|
||||
- 验证 CrewAI 依赖可用性与版本探测方式。
|
||||
- 验证 AG-UI 官方 CrewAI 集成在当前仓库中的落地路径。
|
||||
- 验证 DashScope FunASR 响应中的 usage 字段可得性与兜底策略。
|
||||
|
||||
## Findings
|
||||
|
||||
### CrewAI
|
||||
|
||||
- `uv run python -m pip show crewai` 在当前虚拟环境不可用(无 pip 模块)。
|
||||
- `uv pip show crewai` 返回未安装,说明当前工作树尚未安装 CrewAI 依赖。
|
||||
- 若需启用真实编排,需在 `pyproject.toml` 中声明依赖并执行 `uv sync --extra dev`。
|
||||
|
||||
### AG-UI 官方 CrewAI 集成
|
||||
|
||||
- 目标对齐官方标准事件语义(如 `message.delta`、`tool.started`、`tool.completed`、`run.completed`、`run.failed`)。
|
||||
- 当前仓库采取“适配层隔离”策略:由 `agui_adapter.py` 进行请求与事件映射,避免协议细节扩散到业务层。
|
||||
|
||||
### DashScope FunASR
|
||||
|
||||
- 优先读取上游响应 usage 字段用于成本统计。
|
||||
- 若 usage 缺失,落库时保留原始 `raw_usage`,标准化 usage 字段置空,并标记 `metadata.usage_missing=true` 以便审计。
|
||||
|
||||
## Fallback Strategy
|
||||
|
||||
- 当官方集成能力或版本存在不确定性时,启用最小兜底事件映射:
|
||||
- 仅输出标准 AG-UI 事件。
|
||||
- 不扩展私有协议字段。
|
||||
- 在 `event_bridge.py` 中统一做字段校验与错误转换。
|
||||
|
||||
## Decision
|
||||
|
||||
- 继续按计划推进:先补齐编排与成本核心,再完善 AG-UI 适配、多模态与 E2E 闭环。
|
||||
@@ -0,0 +1,49 @@
|
||||
# Agent Chat Gap Closure Design
|
||||
|
||||
**Goal:** 在不重做已完成任务的前提下,按既定 Task 顺序补齐 Agent Chat Core 的缺口,实现可验证、可审计的端到端闭环。
|
||||
|
||||
## Current State
|
||||
|
||||
- 已完成:Task 2/3/4 的核心数据层、静态配置、模板加载;Task 6/7 的部分骨架(`event_bridge`、`v1/agent_chat`、`storage_adapter`、`asr_fun_asr`)。
|
||||
- 未完成或缺口:Task 1 的 spike 结论文档;Task 5 编排与成本追踪;Task 6 `agui_adapter` 与缺失测试;Task 7 `multimodal`;Task 8 会话审计与 recent 规则;Task 9 E2E 与运行文档闭环。
|
||||
|
||||
## Design Decisions
|
||||
|
||||
- 以“缺口优先”方式执行:仅新增/修改缺失能力,已稳定模块不重构。
|
||||
- 严格遵循顺序:Task 1 -> 5 -> 6 -> 7 -> 8 -> 9。
|
||||
- 每个 Task 均采用 TDD:先写失败测试,再做最小实现,通过后再小步重构。
|
||||
- 统一事件与持久化顺序:以 `session.id + seq` 为唯一顺序锚点,避免流式输出与落库顺序漂移。
|
||||
- 工具调用成本仍归集到 `messages(role=tool)`,会话总成本由增量聚合维护。
|
||||
|
||||
## Component Plan
|
||||
|
||||
- Task 1: 新增 spike notes,记录 CrewAI/AG-UI/FunASR 依赖可用性与兜底策略。
|
||||
- Task 5: 新增 `orchestrator.py`、`cost_tracker.py`、`events.py`,完成三阶段执行与 usage/cost 归一。
|
||||
- Task 6: 新增 `agui_adapter.py`,对接现有 `event_bridge.py` 与 `v1/agent_chat/service.py`。
|
||||
- Task 7: 新增 `multimodal.py`,衔接附件校验、存储元数据、ASR 文本提取。
|
||||
- Task 8: 增强会话标题策略、recent session 查询、审计字段与限流保护。
|
||||
- Task 9: 补齐 E2E 与 runbook,执行 bootstrap gate + 分层测试验证。
|
||||
|
||||
## Data Flow
|
||||
|
||||
1. 路由接收 AG-UI 请求并解析输入文本/附件。
|
||||
2. `agui_adapter` 生成内部命令并触发编排器三阶段执行。
|
||||
3. 每阶段产出内部事件,经 `event_bridge` 映射为 AG-UI 标准事件。
|
||||
4. `service` 在事务内写入 `messages` 并更新 `sessions` 汇总字段。
|
||||
5. 流式事件向外输出,顺序与 `messages.seq` 保持一致。
|
||||
|
||||
## Error Handling
|
||||
|
||||
- 配置/模板错误:启动前校验并快速失败,返回可追踪错误码。
|
||||
- 第三方调用错误(LLM/ASR/Storage):记录标准化失败事件与审计元数据,不泄露敏感信息。
|
||||
- 持久化冲突:对 `session_id + seq` 冲突执行有限重试并记录告警。
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
- Unit:`cost_tracker`、`orchestrator`、`agui_adapter`、`multimodal`、`title strategy`。
|
||||
- Integration:`agent_chat` 路由、事件落库、recent session 选择、会话成本聚合。
|
||||
- E2E:文本、图片+文本、音频+ASR、文档问答、首页最近会话默认选中。
|
||||
|
||||
## Approval Note
|
||||
|
||||
该设计基于用户确认的“仅按未完成 Task 顺序推进”执行策略。
|
||||
@@ -0,0 +1,230 @@
|
||||
# Agent Chat Gap Closure Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 按未完成 Task 顺序补齐 Agent Chat Core 缺口,形成可运行、可测试、可审计的后端链路。
|
||||
|
||||
**Architecture:** 复用已完成的数据层与路由骨架,在 `core/agent_chat` 补齐编排、成本与多模态能力,并通过 `v1/agent_chat/service.py` 统一持久化与事件顺序。全流程以 `session.id + messages.seq` 作为一致性锚点,保证事件输出与落库一致。
|
||||
|
||||
**Tech Stack:** FastAPI, SQLAlchemy, Pydantic, pytest, CrewAI, AG-UI adapter, DashScope SDK, Supabase Storage。
|
||||
|
||||
---
|
||||
|
||||
### Task 1: 补齐 Spike 结论文档
|
||||
|
||||
**Files:**
|
||||
- Create: `docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md`
|
||||
|
||||
**Step 1: 写失败校验(文档存在性)**
|
||||
|
||||
```bash
|
||||
test -f docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md
|
||||
```
|
||||
|
||||
**Step 2: 运行并确认失败**
|
||||
|
||||
Run: `test -f docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md`
|
||||
Expected: non-zero exit code。
|
||||
|
||||
**Step 3: 写最小文档实现**
|
||||
|
||||
```markdown
|
||||
- CrewAI 版本探测结论
|
||||
- AG-UI 官方 CrewAI 集成可用性结论
|
||||
- DashScope FunASR usage 字段策略
|
||||
- 不可用时的最小兜底映射策略
|
||||
```
|
||||
|
||||
**Step 4: 运行并确认通过**
|
||||
|
||||
Run: `test -f docs/plans/2026-02-25-agent-chat-crewai-ag-ui-spike-notes.md`
|
||||
Expected: zero exit code。
|
||||
|
||||
### Task 5: 补齐编排与成本追踪
|
||||
|
||||
**Files:**
|
||||
- Create: `backend/src/core/agent_chat/events.py`
|
||||
- Create: `backend/src/core/agent_chat/cost_tracker.py`
|
||||
- Create: `backend/src/core/agent_chat/orchestrator.py`
|
||||
- Test: `backend/tests/unit/core/agent_chat/test_cost_tracker.py`
|
||||
- Test: `backend/tests/unit/core/agent_chat/test_orchestrator_pipeline.py`
|
||||
|
||||
**Step 1: 写失败测试**
|
||||
|
||||
```python
|
||||
def test_normalize_usage_and_cost_aggregation():
|
||||
assert False
|
||||
|
||||
|
||||
def test_orchestrator_runs_three_stages_in_order():
|
||||
assert False
|
||||
```
|
||||
|
||||
**Step 2: 运行并确认失败**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_cost_tracker.py backend/tests/unit/core/agent_chat/test_orchestrator_pipeline.py -v`
|
||||
Expected: FAIL。
|
||||
|
||||
**Step 3: 写最小实现**
|
||||
|
||||
```python
|
||||
class CostTracker:
|
||||
def add_usage(self, usage: dict) -> None: ...
|
||||
def total(self) -> dict: ...
|
||||
|
||||
|
||||
class AgentChatOrchestrator:
|
||||
async def run(self, command):
|
||||
# intent -> execution -> organization
|
||||
...
|
||||
```
|
||||
|
||||
**Step 4: 运行并确认通过**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_cost_tracker.py backend/tests/unit/core/agent_chat/test_orchestrator_pipeline.py -v`
|
||||
Expected: PASS。
|
||||
|
||||
### Task 6: 补齐 AG-UI 适配层缺口
|
||||
|
||||
**Files:**
|
||||
- Create: `backend/src/core/agent_chat/agui_adapter.py`
|
||||
- Modify: `backend/src/core/agent_chat/event_bridge.py`
|
||||
- Modify: `backend/src/v1/agent_chat/service.py`
|
||||
- Test: `backend/tests/unit/core/agent_chat/test_agui_adapter.py`
|
||||
- Test: `backend/tests/integration/test_agent_chat_event_persistence.py`
|
||||
|
||||
**Step 1: 写失败测试**
|
||||
|
||||
```python
|
||||
def test_agui_adapter_maps_internal_events_to_protocol_events():
|
||||
assert False
|
||||
```
|
||||
|
||||
**Step 2: 运行并确认失败**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_agui_adapter.py backend/tests/integration/test_agent_chat_event_persistence.py -v`
|
||||
Expected: FAIL。
|
||||
|
||||
**Step 3: 写最小实现**
|
||||
|
||||
```python
|
||||
class AguiAdapter:
|
||||
def to_command(self, request): ...
|
||||
def to_protocol_event(self, event): ...
|
||||
```
|
||||
|
||||
**Step 4: 运行并确认通过**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_agui_adapter.py backend/tests/integration/test_agent_chat_event_persistence.py -v`
|
||||
Expected: PASS。
|
||||
|
||||
### Task 7: 补齐多模态输入编排
|
||||
|
||||
**Files:**
|
||||
- Create: `backend/src/core/agent_chat/multimodal.py`
|
||||
- Modify: `backend/src/core/agent_chat/storage_adapter.py`
|
||||
- Modify: `backend/src/core/agent_chat/tools/asr_fun_asr.py`
|
||||
- Test: `backend/tests/unit/core/agent_chat/test_multimodal.py`
|
||||
|
||||
**Step 1: 写失败测试**
|
||||
|
||||
```python
|
||||
def test_multimodal_validates_and_builds_attachment_context():
|
||||
assert False
|
||||
```
|
||||
|
||||
**Step 2: 运行并确认失败**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_multimodal.py -v`
|
||||
Expected: FAIL。
|
||||
|
||||
**Step 3: 写最小实现**
|
||||
|
||||
```python
|
||||
class MultimodalProcessor:
|
||||
async def build_context(self, attachments): ...
|
||||
```
|
||||
|
||||
**Step 4: 运行并确认通过**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_multimodal.py backend/tests/unit/core/agent_chat/test_storage_adapter.py backend/tests/unit/core/agent_chat/test_asr_fun_asr_tool.py -v`
|
||||
Expected: PASS。
|
||||
|
||||
### Task 8: 补齐会话审计与 recent 规则
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/src/v1/agent_chat/service.py`
|
||||
- Modify: `backend/src/v1/agent_chat/router.py`
|
||||
- Test: `backend/tests/unit/core/agent_chat/test_session_title_strategy.py`
|
||||
- Test: `backend/tests/integration/test_agent_chat_session_recent_selection.py`
|
||||
- Test: `backend/tests/integration/test_agent_chat_session_persistence.py`
|
||||
|
||||
**Step 1: 写失败测试**
|
||||
|
||||
```python
|
||||
def test_title_generated_from_first_user_message():
|
||||
assert False
|
||||
|
||||
|
||||
def test_recent_session_selected_by_last_activity_at_desc():
|
||||
assert False
|
||||
```
|
||||
|
||||
**Step 2: 运行并确认失败**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_session_title_strategy.py backend/tests/integration/test_agent_chat_session_recent_selection.py backend/tests/integration/test_agent_chat_session_persistence.py -v`
|
||||
Expected: FAIL。
|
||||
|
||||
**Step 3: 写最小实现**
|
||||
|
||||
```python
|
||||
def build_session_title(first_message: str) -> str: ...
|
||||
```
|
||||
|
||||
**Step 4: 运行并确认通过**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat/test_session_title_strategy.py backend/tests/integration/test_agent_chat_session_recent_selection.py backend/tests/integration/test_agent_chat_session_persistence.py -v`
|
||||
Expected: PASS。
|
||||
|
||||
### Task 9: 补齐 E2E 与运行文档闭环
|
||||
|
||||
**Files:**
|
||||
- Create: `backend/tests/e2e/test_agent_chat_flow.py`
|
||||
- Create: `backend/tests/e2e/test_agent_chat_recent_session_home.py`
|
||||
- Modify: `docs/runtime/runtime-runbook.md`
|
||||
|
||||
**Step 1: 写失败 E2E 用例**
|
||||
|
||||
```python
|
||||
def test_agent_chat_text_image_audio_document_flow():
|
||||
assert False
|
||||
```
|
||||
|
||||
**Step 2: 运行并确认失败**
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/e2e/test_agent_chat_flow.py -v`
|
||||
Expected: FAIL。
|
||||
|
||||
**Step 3: 写最小实现与文档补充**
|
||||
|
||||
```markdown
|
||||
- bootstrap gate 执行顺序
|
||||
- agent_chat 验证命令
|
||||
```
|
||||
|
||||
**Step 4: 全量验证**
|
||||
|
||||
Run: `make runtime-bootstrap-gate`
|
||||
Expected: bootstrap 通过。
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat -v`
|
||||
Expected: PASS。
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/integration -k agent_chat -v`
|
||||
Expected: PASS。
|
||||
|
||||
Run: `PYTHONPATH=backend/src uv run pytest backend/tests/e2e/test_agent_chat_flow.py backend/tests/e2e/test_agent_chat_recent_session_home.py -v`
|
||||
Expected: PASS。
|
||||
|
||||
Run: `uv pip check`(当前虚拟环境不含 pip 模块,`uv run pip check` 不可用)
|
||||
Expected: no broken requirements。
|
||||
@@ -175,6 +175,23 @@ curl -sS -X PATCH http://127.0.0.1:8000/api/v1/profile/me \
|
||||
-d '{"username":"demo2","bio":"hello"}'
|
||||
```
|
||||
|
||||
## Agent Chat 验证
|
||||
|
||||
```bash
|
||||
# 1) 基础门禁(迁移 + init-data)
|
||||
make runtime-bootstrap-gate
|
||||
|
||||
# 2) 运行 agent_chat 相关单测/集成/E2E
|
||||
PYTHONPATH=backend/src uv run pytest backend/tests/unit/core/agent_chat -v
|
||||
PYTHONPATH=backend/src uv run pytest backend/tests/integration -k agent_chat -v
|
||||
PYTHONPATH=backend/src uv run pytest backend/tests/e2e/test_agent_chat_flow.py backend/tests/e2e/test_agent_chat_recent_session_home.py -v
|
||||
|
||||
# 3) 核心接口 smoke
|
||||
curl -sS -X POST http://127.0.0.1:8000/api/v1/agent-chat/run \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"message":"hello"}'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 变更日志
|
||||
@@ -188,3 +205,4 @@ curl -sS -X PATCH http://127.0.0.1:8000/api/v1/profile/me \
|
||||
| 2026-02-25 | 补充迁移防遗漏规则:容器迁移命令统一追加 --build;开发调试优先使用本地 CLI 一次性迁移脚本 |
|
||||
| 2026-02-25 | Auth 注册切换为 OTP 三段式:signup/start、signup/verify、signup/resend;邮件模板改为纯验证码展示 |
|
||||
| 2026-02-25 | 清理未使用配置类:删除 WebSettings/GunicornSettings/WorkerSettings/WorkerGroupSettings(脚本仍使用环境变量启动服务) |
|
||||
| 2026-02-25 | 新增 Agent Chat 验证章节:bootstrap gate、分层测试命令与 run 接口 smoke 示例 |
|
||||
|
||||
Reference in New Issue
Block a user