refactor: 重命名 agent_chat 模块为 agent

This commit is contained in:
qzl
2026-03-02 11:13:20 +08:00
parent 2ac56e5084
commit 99d540a18d
57 changed files with 11175 additions and 74 deletions
+1
View File
@@ -0,0 +1 @@
from __future__ import annotations
+18
View File
@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from v1.agent.service import AgentChatService
from v1.profile.dependencies import get_current_user
def get_agent_service(
session: Annotated[AsyncSession, Depends(get_db)],
user: Annotated[CurrentUser, Depends(get_current_user)],
) -> AgentChatService:
return AgentChatService(session=session, current_user=user)
+19
View File
@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import AgentChatRunRequest, AgentChatRunResponse
from v1.agent.service import AgentChatService
router = APIRouter(prefix="/agent", tags=["agent"])
@router.post("", response_model=AgentChatRunResponse)
async def run_agent_chat(
payload: AgentChatRunRequest,
service: Annotated[AgentChatService, Depends(get_agent_service)],
) -> AgentChatRunResponse:
return await service.run(payload)
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from uuid import UUID
from pydantic import BaseModel, Field
class AgentChatRunRequest(BaseModel):
message: str = Field(min_length=1, max_length=8000)
session_id: UUID | None = None
class AgentChatEvent(BaseModel):
type: str
run_id: str | None = None
message_id: str | None = None
delta: str | None = None
tool_name: str | None = None
result: str | None = None
output: str | None = None
error: str | None = None
class AgentChatRunResponse(BaseModel):
session_id: UUID
output: str
events: list[AgentChatEvent]
+286
View File
@@ -0,0 +1,286 @@
from __future__ import annotations
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import func, select
from sqlalchemy.exc import SQLAlchemyError
from core.agent.agui_adapter import AguiAdapter
from core.agent.orchestrator import AgentChatOrchestrator
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.logging import get_logger
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
from v1.auth.rate_limit import enforce_rate_limit
from v1.agent.schemas import (
AgentChatEvent,
AgentChatRunRequest,
AgentChatRunResponse,
)
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = get_logger("v1.agent.service")
def build_session_title(first_message: str, *, now: datetime) -> str:
title = first_message.strip().replace("\n", " ")[:24]
if not title:
return now.strftime("新对话 %Y-%m-%d %H:%M")
return title
def aggregate_session_cost(costs: list[Decimal]) -> Decimal:
total = Decimal("0")
for cost in costs:
if cost < 0:
raise ValueError("cost must be non-negative")
total += cost
return total
def select_recent_session(
sessions: list[AgentChatSession],
) -> AgentChatSession | None:
if not sessions:
return None
return max(sessions, key=lambda item: item.last_activity_at)
class AgentChatService(BaseService):
_session: AsyncSession
def __init__(self, session: AsyncSession, current_user: CurrentUser | None) -> None:
super().__init__(current_user=current_user)
self._session = session
self._adapter = AguiAdapter()
self._orchestrator = AgentChatOrchestrator(
intent_stage=self._intent_stage,
execution_stage=self._execution_stage,
organization_stage=self._organization_stage,
)
async def run(self, payload: AgentChatRunRequest) -> AgentChatRunResponse:
try:
command = self._adapter.to_command(payload.model_dump(mode="python"))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
user_id = self.require_user_id()
await enforce_rate_limit(
scope="agent_run",
identifier=str(user_id),
limit=60,
window_seconds=60,
)
now = datetime.now(timezone.utc)
try:
chat_session = await self._resolve_session(
session_id=command["session_id"],
user_id=user_id,
first_message=command["message"],
now=now,
)
base_seq = await self._next_seq_base(chat_session.id)
user_message = AgentChatMessage(
session_id=chat_session.id,
seq=base_seq + 1,
role=AgentChatMessageRole.USER,
content=command["message"],
cost=Decimal("0"),
)
orchestrator_result = await self._orchestrator.run(
run_id=str(chat_session.id),
user_message=command["message"],
)
assistant_message = AgentChatMessage(
session_id=chat_session.id,
seq=base_seq + 2,
role=AgentChatMessageRole.ASSISTANT,
content=orchestrator_result.output,
input_tokens=int(orchestrator_result.usage["input_tokens"]),
output_tokens=int(orchestrator_result.usage["output_tokens"]),
cost=Decimal(orchestrator_result.usage["cost"]),
)
self._session.add(user_message)
self._session.add(assistant_message)
chat_session.status = (
AgentChatSessionStatus.FAILED
if orchestrator_result.failed
else AgentChatSessionStatus.COMPLETED
)
chat_session.last_activity_at = now
chat_session.message_count = chat_session.message_count + 2
chat_session.total_tokens = chat_session.total_tokens + int(
orchestrator_result.usage["total_tokens"]
)
chat_session.total_cost = aggregate_session_cost(
[
Decimal(chat_session.total_cost),
Decimal(orchestrator_result.usage["cost"]),
]
)
await self._session.commit()
await self._session.refresh(chat_session)
await self._session.refresh(user_message)
mapped_events = self._build_mapped_events(
session_id=str(chat_session.id),
message_id=str(user_message.id),
user_message=command["message"],
assistant_output=assistant_message.content,
failed=orchestrator_result.failed,
error=orchestrator_result.error,
)
events = [AgentChatEvent.model_validate(item) for item in mapped_events]
if orchestrator_result.failed:
raise HTTPException(
status_code=502, detail="Agent orchestration failed"
)
return AgentChatRunResponse(
session_id=chat_session.id,
output=assistant_message.content,
events=events,
)
except HTTPException:
await self._session.rollback()
raise
except SQLAlchemyError:
await self._session.rollback()
logger.exception("Agent chat run failed")
raise HTTPException(status_code=503, detail="Agent chat store unavailable")
except Exception as exc: # noqa: BLE001
await self._session.rollback()
logger.exception(
"Agent chat unexpected failure", error_type=type(exc).__name__
)
raise HTTPException(
status_code=502, detail="Agent orchestration failed"
) from exc
def _build_mapped_events(
self,
*,
session_id: str,
message_id: str,
user_message: str,
assistant_output: str,
failed: bool,
error: str | None,
) -> list[dict[str, object]]:
mapped_events = [
self._adapter.to_protocol_event(
{
"kind": "run_started",
"session_id": session_id,
}
),
self._adapter.to_protocol_event(
{
"kind": "message_delta",
"message_id": message_id,
"delta": user_message,
}
),
]
if failed:
mapped_events.append(
self._adapter.to_protocol_event(
{
"kind": "run_failed",
"session_id": session_id,
"error": error or "orchestration failed",
}
)
)
return mapped_events
mapped_events.append(
self._adapter.to_protocol_event(
{
"kind": "run_completed",
"session_id": session_id,
"output": assistant_output,
}
)
)
return mapped_events
async def _resolve_session(
self,
*,
session_id: object | None,
user_id: UUID,
first_message: str,
now: datetime,
) -> AgentChatSession:
if session_id is not None:
stmt = (
select(AgentChatSession)
.where(AgentChatSession.id == session_id)
.where(AgentChatSession.user_id == user_id)
.where(AgentChatSession.deleted_at.is_(None))
.with_for_update()
.limit(1)
)
result = await self._session.execute(stmt)
existing = result.scalar_one_or_none()
if existing is None:
raise HTTPException(status_code=404, detail="Session not found")
existing.status = AgentChatSessionStatus.RUNNING
return existing
title = build_session_title(first_message, now=now)
created = AgentChatSession(
user_id=user_id,
title=title,
status=AgentChatSessionStatus.RUNNING,
last_activity_at=now,
)
self._session.add(created)
await self._session.flush()
return created
async def _next_seq_base(self, session_id: object) -> int:
stmt = select(func.max(AgentChatMessage.seq)).where(
AgentChatMessage.session_id == session_id
)
result = await self._session.scalar(stmt)
if result is None:
return 0
return int(result)
async def _intent_stage(
self, *, message: str, context: dict[str, object]
) -> dict[str, object]:
context["intent"] = "default"
return {
"content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
}
async def _execution_stage(
self, *, message: str, context: dict[str, object]
) -> dict[str, object]:
return {
"content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
}
async def _organization_stage(
self, *, message: str, context: dict[str, object]
) -> dict[str, object]:
return {
"content": message,
"usage": {"input_tokens": 0, "output_tokens": 0, "cost": "0"},
}