refactor: 移除 crewai agent 架构相关代码并更新 LLM 配置

This commit is contained in:
qzl
2026-03-04 11:37:09 +08:00
parent 87399f74c8
commit b02a322bf3
71 changed files with 1045 additions and 7499 deletions
-1
View File
@@ -1 +0,0 @@
from __future__ import annotations
-38
View File
@@ -1,38 +0,0 @@
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field
class FlowState(BaseModel):
    """Progress record held by an ``AgentFlow`` instance."""

    # Names of the stages entered so far, in the order they were entered.
    stage_trace: list[str] = Field(default_factory=list)
    # Stage entered most recently; ``None`` until a stage sets it.
    current_stage: str | None = None
    message: str = ""
    context: dict[str, Any] = Field(default_factory=dict)
class AgentFlow:
    """Three-stage pipeline: intent recognition, task execution, result reporting.

    Every stage records itself on ``self.state`` and returns a fixed
    ``{"stage": ..., "result": ...}`` dictionary.
    """

    def __init__(self) -> None:
        self.state = FlowState()

    def _enter_stage(self, stage_name: str) -> None:
        """Mark ``stage_name`` as current and append it to the trace."""
        self.state.current_stage = stage_name
        self.state.stage_trace.append(stage_name)

    async def run(self) -> dict[str, Any]:
        """Run the three stages in order and return the reporting stage's result."""
        intent_outcome = await self.intent_recognition()
        execution_outcome = await self.task_execution(intent_outcome)
        return await self.result_reporting(execution_outcome)

    async def intent_recognition(self) -> dict[str, Any]:
        """Record the ``intent`` stage and return its fixed result."""
        self._enter_stage("intent")
        return {"stage": "intent", "result": "intent recognized"}

    async def task_execution(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
        """Record the ``execution`` stage; the previous result is not read."""
        self._enter_stage("execution")
        return {"stage": "execution", "result": "task executed"}

    async def result_reporting(self, _prev_result: dict[str, Any]) -> dict[str, Any]:
        """Record the ``reporting`` stage; the previous result is not read."""
        self._enter_stage("reporting")
        return {"stage": "reporting", "result": "reported"}
-18
View File
@@ -1,18 +0,0 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from v1.agent.service import AgentChatService
from v1.profile.dependencies import get_current_user
def get_agent_service(
    session: Annotated[AsyncSession, Depends(get_db)],
    user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AgentChatService:
    """FastAPI dependency that builds an ``AgentChatService`` for one request.

    Args:
        session: Database session resolved through ``get_db``.
        user: Caller resolved through ``get_current_user``.

    Returns:
        A service bound to ``session`` and ``user``.
    """
    service = AgentChatService(session=session, current_user=user)
    return service
-41
View File
@@ -1,41 +0,0 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import RunAgentInput
from v1.agent.service import AgentChatService
# Router for the agent endpoints; its routes are served under the "/agent" prefix.
router = APIRouter(prefix="/agent", tags=["agent"])
@router.post("/runs")
async def create_run(
    input_data: RunAgentInput,
    service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> StreamingResponse:
    """Start an agent run and stream its events back to the caller.

    Args:
        input_data: Request body, validated as ``RunAgentInput``.
        service: Request-scoped ``AgentChatService``.

    Returns:
        A ``text/event-stream`` response fed by ``service.stream_run``.
    """
    event_stream = service.stream_run(input_data)
    return StreamingResponse(event_stream, media_type="text/event-stream")
@router.post("/runs/{run_id}/resume")
async def resume_run(
    run_id: Annotated[str, Path(min_length=1, max_length=255)],
    input_data: RunAgentInput,
    service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> StreamingResponse:
    """Resume an interrupted run and stream the follow-up events.

    Args:
        run_id: Run identifier taken from the URL path.
        input_data: Request body; its ``runId`` must equal ``run_id``.
        service: Request-scoped ``AgentChatService``.

    Returns:
        A ``text/event-stream`` response fed by ``service.stream_resume``.

    Raises:
        HTTPException: 409 when the path and body run ids differ; anything
            raised by ``service.prepare_resume`` propagates unchanged.
    """
    body_run_id = input_data.runId
    if body_run_id != run_id:
        mismatch = f"run_id mismatch: path={run_id}, body={body_run_id}"
        raise HTTPException(status_code=409, detail=mismatch)

    # Awaited before the streaming response exists, so an HTTPException from
    # prepare_resume is raised before any event is streamed.
    await service.prepare_resume(run_id, input_data)

    event_stream = service.stream_resume(run_id, input_data)
    return StreamingResponse(event_stream, media_type="text/event-stream")
-88
View File
@@ -1,88 +0,0 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Literal
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
class RunAgentInput(BaseModel):
    """Request body shared by the run and resume endpoints.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Identifiers: 1-255 characters; ``parentRunId`` may be omitted.
    threadId: str = Field(min_length=1, max_length=255)
    runId: str = Field(min_length=1, max_length=255)
    parentRunId: str | None = Field(default=None, max_length=255)

    # Free-form payload sections; each defaults to an empty container.
    state: dict[str, Any] = Field(default_factory=dict)
    messages: list[dict[str, Any]] = Field(default_factory=list)
    tools: list[dict[str, Any]] = Field(default_factory=list)
    context: list[dict[str, Any]] = Field(default_factory=list)
    forwardedProps: dict[str, Any] = Field(default_factory=dict)

    # Resume payload; ``None`` when absent.
    resume: dict[str, Any] | None = None
class AgentChatRunRequest(BaseModel):
    """Input for a non-streaming chat run."""

    # User text, 1 to 8000 characters.
    message: str = Field(min_length=1, max_length=8000)
    # Optional session identifier; defaults to ``None``.
    session_id: UUID | None = None
class AgentChatEvent(BaseModel):
    """One protocol event reported for a chat run.

    Only ``type`` is mandatory; every other field defaults to ``None``.
    """

    type: str
    run_id: str | None = None
    message_id: str | None = None
    delta: str | None = None
    tool_name: str | None = None
    result: str | None = None
    output: str | None = None
    error: str | None = None
class AgentChatRunResponse(BaseModel):
    """Result of a non-streaming chat run."""

    # Session the run was recorded under.
    session_id: UUID
    # Final assistant text.
    output: str
    # Protocol events reported for the run.
    events: list[AgentChatEvent]
class PendingToolStatus(str, Enum):
    """Status values of a pending tool call.

    Members are also ``str`` instances, and each value equals its name.
    """

    PENDING_APPROVAL = "PENDING_APPROVAL"
    APPROVED_EXECUTING = "APPROVED_EXECUTING"
    EXECUTED = "EXECUTED"
    REJECTED = "REJECTED"
    EXPIRED = "EXPIRED"
class PendingToolCall(BaseModel):
    """A tool invocation awaiting, or already past, a decision.

    Unknown fields are rejected, and both timestamps must be timezone-aware.
    """

    model_config = ConfigDict(extra="forbid")

    interrupt_id: str = Field(min_length=1, max_length=255)
    tool_name: str = Field(min_length=1, max_length=255)
    tool_args: dict[str, Any] = Field(default_factory=dict)
    status: PendingToolStatus
    expires_at: datetime
    decision: dict[str, Any] | None = None
    result: dict[str, Any] | None = None
    updated_at: datetime

    @field_validator("expires_at", "updated_at")
    @classmethod
    def _validate_timezone_aware(cls, value: datetime) -> datetime:
        """Return ``value`` unchanged, or raise ``ValueError`` if it is naive."""
        is_aware = value.tzinfo is not None and value.utcoffset() is not None
        if not is_aware:
            raise ValueError("datetime must be timezone-aware")
        return value
class SnapshotRunContext(BaseModel):
    """Thread and run identifiers stored inside a session snapshot."""

    model_config = ConfigDict(extra="forbid")

    # Both identifiers are 1-255 characters long.
    thread_id: str = Field(min_length=1, max_length=255)
    run_id: str = Field(min_length=1, max_length=255)
class AgentSessionSnapshot(BaseModel):
    """Version-2 session snapshot; unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Only the literal value 2 validates.
    version: Literal[2]
    # Required field, but ``None`` is an accepted value.
    pending_tool_call: PendingToolCall | None
    run_context: SnapshotRunContext
-545
View File
@@ -1,545 +0,0 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any
from uuid import UUID
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import SQLAlchemyError
from core.agent.agui_adapter import AguiAdapter
from core.agent.orchestrator import AgentChatOrchestrator
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent.schemas import (
AgentSessionSnapshot,
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
PendingToolCall,
PendingToolStatus,
RunAgentInput,
SnapshotRunContext,
)
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
# Module logger, registered under the name "v1.agent.service".
logger = get_logger("v1.agent.service")
# Passed to enforce_rate_limit as BOTH `limit` and `window_seconds`
# (see AgentChatService.run and prepare_resume).
DEFAULT_RATE_LIMIT = 60
# Usage record returned by the service's three stage callbacks. All three
# return this same dict object, so it must not be mutated.
# NOTE(review): run() also reads usage["total_tokens"], which is not a key
# here - presumably the orchestrator adds it; confirm.
EMPTY_USAGE = {"input_tokens": 0, "output_tokens": 0, "cost": "0"}
class ResumeDecisionResult(BaseModel):
    """Outcome of ``AgentChatService.apply_resume_decision``."""

    # True only when the decision was written into the session snapshot.
    applied: bool
def build_session_title(first_message: str, *, now: datetime) -> str:
    """Derive a session title from the first user message.

    The message is stripped, its newlines become spaces, and the result is
    cut to 24 characters. If nothing is left, a timestamped placeholder
    built from ``now`` is returned instead.
    """
    flattened = first_message.strip().replace("\n", " ")
    candidate = flattened[:24]
    return candidate if candidate else now.strftime("新对话 %Y-%m-%d %H:%M")
def aggregate_session_cost(costs: list[Decimal]) -> Decimal:
    """Add up ``costs``, rejecting negative entries.

    Returns:
        The sum, or ``Decimal("0")`` for an empty list.

    Raises:
        ValueError: If any entry is below zero.
    """
    if any(entry < 0 for entry in costs):
        raise ValueError("cost must be non-negative")
    return sum(costs, Decimal("0"))
def select_recent_session(
    sessions: list[AgentChatSession],
) -> AgentChatSession | None:
    """Return the session with the greatest ``last_activity_at``.

    Returns ``None`` when ``sessions`` is empty. On a tie, the earliest
    entry in the list is returned.
    """
    return (
        max(sessions, key=lambda candidate: candidate.last_activity_at)
        if sessions
        else None
    )
class AgentChatService(BaseService):
    """Chat runs, pending-tool-call snapshots and SSE streams for the agent API.

    The service is bound to one database session and, optionally, one
    authenticated user.
    """

    _session: AsyncSession

    def __init__(self, session: AsyncSession, current_user: CurrentUser | None) -> None:
        """Bind the service to ``session`` and ``current_user`` and wire the orchestrator."""
        super().__init__(current_user=current_user)
        self._session = session
        self._adapter = AguiAdapter()
        self._orchestrator = AgentChatOrchestrator(
            intent_stage=self._intent_stage,
            execution_stage=self._execution_stage,
            organization_stage=self._organization_stage,
        )

    async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
        """Execute one chat turn and persist the user and assistant messages.

        Raises:
            HTTPException: 422 when the adapter rejects the payload; 404 when
                the referenced session is not found; 503 on a database error;
                502 when orchestration fails or anything unexpected happens.
                On an orchestration failure the messages and the FAILED
                status are committed before the 502 is raised.
        """
        try:
            command = self._adapter.to_command(payload.model_dump(mode="python"))
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

        user_id = self.require_user_id()
        await enforce_rate_limit(
            scope="agent_run",
            identifier=str(user_id),
            limit=DEFAULT_RATE_LIMIT,
            window_seconds=DEFAULT_RATE_LIMIT,
        )

        now = datetime.now(timezone.utc)
        try:
            chat_session = await self._resolve_session(
                session_id=command["session_id"],
                user_id=user_id,
                first_message=command["message"],
                now=now,
            )
            base_seq = await self._next_seq_base(chat_session.id)

            user_message = AgentChatMessage(
                session_id=chat_session.id,
                seq=base_seq + 1,
                role=AgentChatMessageRole.USER,
                content=command["message"],
                cost=Decimal("0"),
            )
            orchestrator_result = await self._orchestrator.run(
                run_id=str(chat_session.id),
                user_message=command["message"],
            )
            # Kept in a local: the ORM instance may be expired by commit(), and
            # reading an expired attribute on an AsyncSession needs implicit IO.
            assistant_output = orchestrator_result.output
            assistant_message = AgentChatMessage(
                session_id=chat_session.id,
                seq=base_seq + 2,
                role=AgentChatMessageRole.ASSISTANT,
                content=assistant_output,
                input_tokens=int(orchestrator_result.usage["input_tokens"]),
                output_tokens=int(orchestrator_result.usage["output_tokens"]),
                cost=Decimal(orchestrator_result.usage["cost"]),
            )
            self._session.add(user_message)
            self._session.add(assistant_message)

            chat_session.status = (
                AgentChatSessionStatus.FAILED
                if orchestrator_result.failed
                else AgentChatSessionStatus.COMPLETED
            )
            chat_session.last_activity_at = now
            chat_session.message_count = chat_session.message_count + 2
            # NOTE(review): EMPTY_USAGE has no "total_tokens" key; presumably the
            # orchestrator adds it - confirm, otherwise this raises KeyError.
            chat_session.total_tokens = chat_session.total_tokens + int(
                orchestrator_result.usage["total_tokens"]
            )
            chat_session.total_cost = aggregate_session_cost(
                [
                    Decimal(chat_session.total_cost),
                    Decimal(orchestrator_result.usage["cost"]),
                ]
            )

            await self._session.commit()
            await self._session.refresh(chat_session)
            await self._session.refresh(user_message)

            mapped_events = self._build_mapped_events(
                session_id=str(chat_session.id),
                message_id=str(user_message.id),
                user_message=command["message"],
                assistant_output=assistant_output,
                failed=orchestrator_result.failed,
                error=orchestrator_result.error,
            )
            events = [AgentChatEvent.model_validate(item) for item in mapped_events]

            if orchestrator_result.failed:
                raise HTTPException(
                    status_code=502, detail="Agent orchestration failed"
                )
            return AgentChatRunResponse(
                session_id=chat_session.id,
                output=assistant_output,
                events=events,
            )
        except HTTPException:
            await self._session.rollback()
            raise
        except SQLAlchemyError as exc:
            await self._session.rollback()
            logger.exception("Agent chat run failed")
            raise HTTPException(
                status_code=503, detail="Agent chat store unavailable"
            ) from exc
        except Exception as exc:  # noqa: BLE001
            await self._session.rollback()
            logger.exception(
                "Agent chat unexpected failure", error_type=type(exc).__name__
            )
            raise HTTPException(
                status_code=502, detail="Agent orchestration failed"
            ) from exc

    def _build_mapped_events(
        self,
        *,
        session_id: str,
        message_id: str,
        user_message: str,
        assistant_output: str,
        failed: bool,
        error: str | None,
    ) -> list[dict[str, object]]:
        """Build the protocol events for a finished run.

        The list always starts with ``run_started`` and a ``message_delta``
        carrying the user message, and ends with ``run_failed`` or
        ``run_completed`` depending on ``failed``.
        """
        mapped_events = [
            self._adapter.to_protocol_event(
                {
                    "kind": "run_started",
                    "session_id": session_id,
                }
            ),
            self._adapter.to_protocol_event(
                {
                    "kind": "message_delta",
                    "message_id": message_id,
                    "delta": user_message,
                }
            ),
        ]
        if failed:
            mapped_events.append(
                self._adapter.to_protocol_event(
                    {
                        "kind": "run_failed",
                        "session_id": session_id,
                        "error": error or "orchestration failed",
                    }
                )
            )
            return mapped_events

        mapped_events.append(
            self._adapter.to_protocol_event(
                {
                    "kind": "run_completed",
                    "session_id": session_id,
                    "output": assistant_output,
                }
            )
        )
        return mapped_events

    async def _resolve_session(
        self,
        *,
        session_id: object | None,
        user_id: UUID,
        first_message: str,
        now: datetime,
    ) -> AgentChatSession:
        """Load the caller's session with a row lock, or create a new one.

        Either way the returned session has status RUNNING.

        Raises:
            HTTPException: 404 when ``session_id`` is given but no
                non-deleted session of ``user_id`` matches it.
        """
        if session_id is not None:
            stmt = (
                select(AgentChatSession)
                .where(AgentChatSession.id == session_id)
                .where(AgentChatSession.user_id == user_id)
                .where(AgentChatSession.deleted_at.is_(None))
                .with_for_update()
                .limit(1)
            )
            result = await self._session.execute(stmt)
            existing = result.scalar_one_or_none()
            if existing is None:
                raise HTTPException(status_code=404, detail="Session not found")
            existing.status = AgentChatSessionStatus.RUNNING
            return existing

        title = build_session_title(first_message, now=now)
        created = AgentChatSession(
            user_id=user_id,
            title=title,
            status=AgentChatSessionStatus.RUNNING,
            last_activity_at=now,
        )
        self._session.add(created)
        # Flushed so that created.id is populated for the caller.
        await self._session.flush()
        return created

    async def _next_seq_base(self, session_id: object) -> int:
        """Return the highest message ``seq`` in the session, or 0 if it has none."""
        stmt = select(func.max(AgentChatMessage.seq)).where(
            AgentChatMessage.session_id == session_id
        )
        result = await self._session.scalar(stmt)
        if result is None:
            return 0
        return int(result)

    async def _intent_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Stage callback: set ``context["intent"]`` to "default" and echo the message."""
        context["intent"] = "default"
        return {"content": message, "usage": EMPTY_USAGE}

    async def _execution_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Stage callback: echo the message with empty usage."""
        return {"content": message, "usage": EMPTY_USAGE}

    async def _organization_stage(
        self, *, message: str, context: dict[str, object]
    ) -> dict[str, object]:
        """Stage callback: echo the message with empty usage."""
        return {"content": message, "usage": EMPTY_USAGE}

    async def get_state_snapshot(self, session_id: UUID) -> dict | None:
        """Return the stored ``state_snapshot`` of a session, or ``None`` if it is missing.

        NOTE(review): unlike the other snapshot methods, this one does not
        call ``_assert_session_owner`` - confirm that this is intended.
        """
        stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
        session = await self._session.scalar(stmt)
        if session is None:
            return None
        return session.state_snapshot

    @staticmethod
    def _load_snapshot_v2(raw_snapshot: dict[str, Any]) -> AgentSessionSnapshot:
        """Validate a raw snapshot; any validation failure becomes ``ValueError``."""
        try:
            return AgentSessionSnapshot.model_validate(raw_snapshot)
        except Exception as exc:  # noqa: BLE001
            raise ValueError("Invalid state_snapshot format") from exc

    async def _get_session_for_update(
        self, session_id: UUID
    ) -> AgentChatSession | None:
        """Load a session by id with ``SELECT ... FOR UPDATE``; ``None`` if absent."""
        stmt = (
            select(AgentChatSession)
            .where(AgentChatSession.id == session_id)
            .with_for_update()
            .limit(1)
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    def _assert_session_owner(self, session: AgentChatSession) -> None:
        """Raise a 404 when a bound user does not own ``session``.

        The check is skipped entirely when the service has no current user.
        """
        if self._current_user is None:
            return
        if session.user_id != self.require_user_id():
            raise HTTPException(status_code=404, detail="Session not found")

    @staticmethod
    def _validate_no_newlines(value: str, *, field_name: str) -> None:
        """Raise ``ValueError`` if ``value`` contains a CR or LF character."""
        if "\n" in value or "\r" in value:
            raise ValueError(f"{field_name} must not contain newlines")

    @staticmethod
    def _sse_data(payload: dict[str, Any]) -> str:
        """Frame ``payload`` as one server-sent event.

        The result is ``data: <json>`` followed by a blank line. The two
        trailing characters must be real line feeds: an SSE client only
        dispatches an event when it reads the empty line, so an escaped
        backslash-n sequence would never terminate the event.
        """
        return f"data: {json.dumps(payload)}\n\n"

    async def set_pending_tool_call(
        self,
        *,
        session_id: UUID,
        interrupt_id: str,
        tool_name: str,
        tool_args: dict,
        expires_at: datetime,
        thread_id: str,
        run_id: str,
    ) -> None:
        """Replace the session snapshot with a new PENDING_APPROVAL tool call.

        The change is made on the loaded session object only; this method
        does not commit.

        Raises:
            ValueError: If the session does not exist.
            HTTPException: 404 when a bound user does not own the session.
        """
        stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
        session = await self._session.scalar(stmt)
        if session is None:
            raise ValueError(f"Session {session_id} not found")
        self._assert_session_owner(session)

        snapshot = AgentSessionSnapshot(
            version=2,
            run_context=SnapshotRunContext(thread_id=thread_id, run_id=run_id),
            pending_tool_call=PendingToolCall(
                interrupt_id=interrupt_id,
                tool_name=tool_name,
                tool_args=tool_args,
                status=PendingToolStatus.PENDING_APPROVAL,
                expires_at=expires_at,
                decision=None,
                result=None,
                updated_at=datetime.now(timezone.utc),
            ),
        )
        session.state_snapshot = snapshot.model_dump(mode="json")

    async def update_pending_tool_call_status(
        self,
        *,
        session_id: UUID,
        interrupt_id: str,
        status: str,
    ) -> None:
        """Set the status of the pending tool call identified by ``interrupt_id``.

        This method does not commit.

        Raises:
            ValueError: If the session or pending call is missing, the
                snapshot is malformed, ``interrupt_id`` does not match, or
                ``status`` is not a ``PendingToolStatus`` value.
            HTTPException: 404 when a bound user does not own the session.
        """
        stmt = select(AgentChatSession).where(AgentChatSession.id == session_id)
        session = await self._session.scalar(stmt)
        if session is None:
            raise ValueError(f"Session {session_id} not found")
        self._assert_session_owner(session)
        if session.state_snapshot is None:
            raise ValueError("No pending tool call found")

        snapshot = self._load_snapshot_v2(session.state_snapshot)
        pending = snapshot.pending_tool_call
        if pending is None:
            raise ValueError("No pending tool call found")
        if pending.interrupt_id != interrupt_id:
            raise ValueError("Interrupt ID mismatch")

        updated_pending = pending.model_copy(
            update={
                "status": PendingToolStatus(status),
                "updated_at": datetime.now(timezone.utc),
            }
        )
        updated_snapshot = snapshot.model_copy(
            update={"pending_tool_call": updated_pending}
        )
        session.state_snapshot = updated_snapshot.model_dump(mode="json")

    async def apply_resume_decision(
        self,
        *,
        session_id: UUID,
        interrupt_id: str,
        decision: dict[str, Any],
    ) -> ResumeDecisionResult:
        """Record an approve/reject decision on the pending tool call.

        The session row is locked. ``applied`` is False when there is no
        snapshot or pending call, the interrupt id differs, the call is not
        PENDING_APPROVAL, or it has expired. In the expired case the
        snapshot is updated to EXPIRED. This method does not commit.

        Raises:
            ValueError: If the session is missing or the snapshot is malformed.
            HTTPException: 404 when a bound user does not own the session.
        """
        session = await self._get_session_for_update(session_id)
        if session is None:
            raise ValueError(f"Session {session_id} not found")
        self._assert_session_owner(session)
        if session.state_snapshot is None:
            return ResumeDecisionResult(applied=False)

        snapshot = self._load_snapshot_v2(session.state_snapshot)
        pending = snapshot.pending_tool_call
        if pending is None:
            return ResumeDecisionResult(applied=False)
        if pending.interrupt_id != interrupt_id:
            return ResumeDecisionResult(applied=False)
        if pending.status != PendingToolStatus.PENDING_APPROVAL:
            return ResumeDecisionResult(applied=False)

        now = datetime.now(timezone.utc)
        if pending.expires_at <= now:
            expired_pending = pending.model_copy(
                update={
                    "status": PendingToolStatus.EXPIRED,
                    "updated_at": now,
                }
            )
            expired_snapshot = snapshot.model_copy(
                update={"pending_tool_call": expired_pending}
            )
            session.state_snapshot = expired_snapshot.model_dump(mode="json")
            return ResumeDecisionResult(applied=False)

        # NOTE(review): a payload without a "decision" key counts as approval
        # (fail-open); any value other than "approved" rejects. Confirm that
        # the approving default is intended.
        decision_value = decision.get("decision", "approved")
        next_status = (
            PendingToolStatus.APPROVED_EXECUTING
            if decision_value == "approved"
            else PendingToolStatus.REJECTED
        )
        updated_pending = pending.model_copy(
            update={
                "status": next_status,
                "decision": decision,
                "updated_at": now,
            }
        )
        updated_snapshot = snapshot.model_copy(
            update={"pending_tool_call": updated_pending}
        )
        session.state_snapshot = updated_snapshot.model_dump(mode="json")
        return ResumeDecisionResult(applied=True)

    async def stream_run(self, input_data: RunAgentInput) -> AsyncGenerator[str, None]:
        """Yield a fixed five-event SSE sequence for a new run.

        The text content is the constant "Hello"; ``runId`` is checked for
        newlines before the first event is produced.
        """
        self._validate_no_newlines(input_data.runId, field_name="runId")
        yield self._sse_data({"type": "RUN_STARTED", "runId": input_data.runId})
        yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m1"})
        yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Hello"})
        yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m1"})
        yield self._sse_data({"type": "RUN_FINISHED", "runId": input_data.runId})

    async def prepare_resume(self, run_id: str, input_data: RunAgentInput) -> None:
        """Validate a resume request and persist the decision it carries.

        Raises:
            ValueError: If ``run_id`` contains a newline.
            HTTPException: 422 for a non-UUID ``run_id``, a missing or
                malformed resume payload, or a snapshot error; 404 when the
                session is missing or not owned by the caller; 410 when the
                pending call has expired (the EXPIRED status is committed
                first); 409 when the decision cannot be applied.
        """
        self._validate_no_newlines(run_id, field_name="runId")
        user_id = self.require_user_id()
        await enforce_rate_limit(
            scope="agent_resume",
            identifier=str(user_id),
            limit=DEFAULT_RATE_LIMIT,
            window_seconds=DEFAULT_RATE_LIMIT,
        )

        try:
            session_id = UUID(run_id)
        except ValueError as exc:
            raise HTTPException(
                status_code=422, detail="run_id must be a valid UUID"
            ) from exc

        session = await self._get_session_for_update(session_id)
        if session is None or session.user_id != user_id:
            raise HTTPException(status_code=404, detail="Session not found")

        if input_data.resume is None:
            raise HTTPException(status_code=422, detail="resume payload is required")
        interrupt_id = input_data.resume.get("interruptId")
        if not isinstance(interrupt_id, str) or not interrupt_id:
            raise HTTPException(
                status_code=422, detail="resume.interruptId is required"
            )
        decision_payload = input_data.resume.get("payload", {})
        if not isinstance(decision_payload, dict):
            raise HTTPException(
                status_code=422,
                detail="resume.payload must be an object",
            )

        try:
            decision_result = await self.apply_resume_decision(
                session_id=session_id,
                interrupt_id=interrupt_id,
                decision=decision_payload,
            )
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

        if not decision_result.applied:
            if session.state_snapshot is not None:
                snapshot = self._load_snapshot_v2(session.state_snapshot)
                pending = snapshot.pending_tool_call
                if pending is not None and pending.status == PendingToolStatus.EXPIRED:
                    # Persist the EXPIRED status before reporting it.
                    await self._session.commit()
                    raise HTTPException(
                        status_code=410,
                        detail="Pending tool call expired",
                    )
            raise HTTPException(
                status_code=409,
                detail="Resume decision not applicable",
            )
        await self._session.commit()

    async def stream_resume(
        self, run_id: str, input_data: RunAgentInput
    ) -> AsyncGenerator[str, None]:
        """Yield a fixed five-event SSE sequence for a resumed run.

        The text content is the constant "Resumed"; ``input_data`` is not read.
        """
        self._validate_no_newlines(run_id, field_name="runId")
        yield self._sse_data({"type": "RUN_STARTED", "runId": run_id})
        yield self._sse_data({"type": "TEXT_MESSAGE_START", "messageId": "m2"})
        yield self._sse_data({"type": "TEXT_MESSAGE_CONTENT", "delta": "Resumed"})
        yield self._sse_data({"type": "TEXT_MESSAGE_END", "messageId": "m2"})
        yield self._sse_data({"type": "RUN_FINISHED", "runId": run_id})
-80
View File
@@ -1,80 +0,0 @@
from __future__ import annotations
from typing import Any
from pydantic import BaseModel
from v1.agent.tool_registry import validate_tool_spec
# Tool names accepted for the "backend" execution target.
ALLOWED_BACKEND_TOOLS = frozenset(
    (
        "srv.search_docs",
        "srv.get_user_info",
        "srv.send_message",
        "srv.transfer_funds",
        "srv.delete_file",
    )
)

# Tool names accepted for the "frontend" execution target.
ALLOWED_FRONTEND_TOOLS = frozenset(("ui.navigate_to",))
class InterruptResult(BaseModel):
    """Dispatch outcome that pauses instead of executing on the backend.

    ``dispatch_tool_call`` sets ``interrupt_type`` to "tool_execution" for
    frontend tools and "approval_required" for backend tools flagged with
    ``requires_approval``.
    """

    interrupt_type: str
    tool_name: str
    tool_args: dict[str, Any]
class BackendExecutionResult(BaseModel):
    """Dispatch outcome for a backend tool that needs no approval."""

    tool_name: str
    tool_args: dict[str, Any]
    # Defaults to ``None``; ``dispatch_tool_call`` does not set it.
    result: Any | None = None
class ToolDispatcher:
    """Object-style entry point that forwards to ``dispatch_tool_call``."""

    def dispatch(
        self, tool: dict[str, Any]
    ) -> InterruptResult | BackendExecutionResult:
        """Dispatch ``tool`` and return whatever ``dispatch_tool_call`` returns."""
        outcome = dispatch_tool_call(tool)
        return outcome
def dispatch_tool_call(
    tool: dict[str, Any],
) -> InterruptResult | BackendExecutionResult:
    """Validate a tool spec and decide how the call proceeds.

    Frontend tools always yield a "tool_execution" interrupt. Backend tools
    yield an "approval_required" interrupt when ``requires_approval`` is
    truthy, and a ``BackendExecutionResult`` otherwise.

    Raises:
        ValueError: If the spec is invalid, the tool is not in the
            allowlist for its target, or the target is unknown.
    """
    validate_tool_spec(tool)

    tool_name = tool["name"]
    destination = tool["execution_target"]
    call_args = tool.get("args", {})

    if destination == "frontend":
        if tool_name not in ALLOWED_FRONTEND_TOOLS:
            raise ValueError(f"Frontend tool '{tool_name}' not in allowlist")
        return InterruptResult(
            interrupt_type="tool_execution",
            tool_name=tool_name,
            tool_args=call_args,
        )

    if destination != "backend":
        raise ValueError(f"Unknown execution_target: {destination}")

    if tool_name not in ALLOWED_BACKEND_TOOLS:
        raise ValueError(f"Backend tool '{tool_name}' not in allowlist")

    if not tool.get("requires_approval", False):
        return BackendExecutionResult(tool_name=tool_name, tool_args=call_args)

    return InterruptResult(
        interrupt_type="approval_required",
        tool_name=tool_name,
        tool_args=call_args,
    )
-24
View File
@@ -1,24 +0,0 @@
from __future__ import annotations
from typing import Any
def validate_tool_spec(spec: dict[str, Any]) -> None:
    """Check that a tool spec names a tool and a matching execution target.

    ``name`` and ``execution_target`` must be non-empty strings. Names in
    the ``ui.`` namespace require the "frontend" target, and names in the
    ``srv.`` namespace require "backend".

    Raises:
        ValueError: On a missing field, a wrong type or empty value, an
            unknown namespace, or a namespace/target mismatch.
    """
    try:
        tool_name, exec_target = spec["name"], spec["execution_target"]
    except KeyError as missing:
        raise ValueError(f"Missing required field: {missing.args[0]}") from missing

    if not (tool_name and isinstance(tool_name, str)):
        raise ValueError("Tool name must be a non-empty string")
    if not (exec_target and isinstance(exec_target, str)):
        raise ValueError("execution_target must be a non-empty string")

    # (namespace prefix, required target, message raised on a mismatch)
    namespace_rules = (
        ("ui.", "frontend", "ui.* must use frontend target"),
        ("srv.", "backend", "srv.* must use backend target"),
    )
    for prefix, required_target, mismatch_message in namespace_rules:
        if tool_name.startswith(prefix):
            if exec_target != required_target:
                raise ValueError(mismatch_message)
            return
    raise ValueError("Tool name must be in ui.* or srv.* namespace")
-2
View File
@@ -3,7 +3,6 @@ from __future__ import annotations
from fastapi import APIRouter
from core.http.models import HealthResponse
from v1.agent.router import router as agent_router
from v1.auth.router import router as auth_router
from v1.friendships.router import router as friendships_router
from v1.inbox_messages.router import router as inbox_messages_router
@@ -17,7 +16,6 @@ router.include_router(auth_router)
router.include_router(friendships_router)
router.include_router(infra_router)
router.include_router(users_router)
router.include_router(agent_router)
router.include_router(schedule_items_router)
router.include_router(inbox_messages_router)