feat(agent): 增强多模态链路与工具调用能力

This commit is contained in:
zl-q
2026-03-12 00:18:45 +08:00
parent 18db6c50e7
commit 21ba8e4a44
35 changed files with 2057 additions and 829 deletions
@@ -23,10 +23,12 @@ class MessageRepository:
role: AgentChatMessageRole,
content: str,
model_code: str | None = None,
tool_name: str | None = None,
metadata: dict[str, object] | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
cost: Decimal = Decimal("0"),
latency_ms: int | None = None,
) -> AgentChatMessage:
message = AgentChatMessage(
session_id=session_id,
@@ -34,10 +36,12 @@ class MessageRepository:
role=role,
content=content,
model_code=model_code,
tool_name=tool_name,
metadata_json=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
)
self._session.add(message)
await self._session.flush()
+123 -3
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from decimal import Decimal, InvalidOperation
from typing import Any, Callable, Protocol
from uuid import UUID
@@ -24,6 +25,7 @@ class SqlAlchemyEventStore:
def __init__(self, *, session_factory: Any) -> None:
self._session_factory = session_factory
self._message_buffers: dict[tuple[str, str], str] = {}
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
async def persist(self, event: dict[str, Any]) -> None:
event_type = str(event.get("type", "")).strip().upper()
@@ -48,6 +50,10 @@ class SqlAlchemyEventStore:
self._buffer_text_delta(session_key=session_key, event=event)
return
if event_type == "TEXT_MESSAGE_START":
self._buffer_text_context(session_key=session_key, event=event)
return
if event_type == "RUN_STARTED":
await self._update_session_state(
session_repo=session_repo,
@@ -72,7 +78,15 @@ class SqlAlchemyEventStore:
)
self._clear_session_buffers(session_key=session_key)
elif event_type == "TEXT_MESSAGE_END":
await self._persist_assistant_message(
await self._persist_text_message(
event=event,
session_id=session_id,
chat_session=chat_session,
session_repo=session_repo,
message_repo=message_repo,
)
elif event_type == "TOOL_CALL_RESULT":
await self._persist_tool_call_result(
event=event,
session_id=session_id,
chat_session=chat_session,
@@ -97,8 +111,28 @@ class SqlAlchemyEventStore:
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
for key in stale_keys:
self._message_buffers.pop(key, None)
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
for key in stale_context_keys:
self._message_contexts.pop(key, None)
async def _persist_assistant_message(
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
message_id = event.get("messageId")
if not isinstance(message_id, str) or not message_id:
return
key = (session_key, message_id)
role = event.get("role")
stage = event.get("stage")
tool_name = event.get("toolName")
context: dict[str, object] = {}
if isinstance(role, str) and role:
context["role"] = role
if isinstance(stage, str) and stage:
context["stage"] = stage
if isinstance(tool_name, str) and tool_name:
context["tool_name"] = tool_name
self._message_contexts[key] = context
async def _persist_text_message(
self,
*,
event: dict[str, Any],
@@ -114,6 +148,8 @@ class SqlAlchemyEventStore:
if not content:
return
context = self._message_contexts.get(key, {})
input_tokens = self._to_int(event.get("inputTokens"))
output_tokens = self._to_int(event.get("outputTokens"))
token_delta = input_tokens + output_tokens
@@ -127,6 +163,20 @@ class SqlAlchemyEventStore:
metadata["run_id"] = run_id
if latency_ms is not None:
metadata["latency_ms"] = latency_ms
stage = event.get("stage")
if not isinstance(stage, str):
stage = context.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
role_value = context.get("role")
if not isinstance(role_value, str):
role_value = "assistant"
role = self._resolve_role(role_value)
tool_name = context.get("tool_name")
tool_name_value = (
tool_name if isinstance(tool_name, str) and tool_name else None
)
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
@@ -137,13 +187,15 @@ class SqlAlchemyEventStore:
await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.ASSISTANT,
role=role,
content=content,
model_code=model_code if isinstance(model_code, str) else None,
tool_name=tool_name_value,
metadata=metadata,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
latency_ms=latency_ms,
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
@@ -161,6 +213,74 @@ class SqlAlchemyEventStore:
cost_delta=cost,
)
self._message_buffers.pop(key, None)
self._message_contexts.pop(key, None)
async def _persist_tool_call_result(
self,
*,
event: dict[str, Any],
session_id: UUID,
chat_session: Any,
session_repo: SessionRepository,
message_repo: MessageRepository,
) -> None:
tool_name = event.get("toolName")
if not isinstance(tool_name, str) or not tool_name:
return
payload = {
"args": event.get("args"),
"result": event.get("result"),
"error": event.get("error"),
"call_id": event.get("callId"),
}
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
metadata: dict[str, object] = {"tool_name": tool_name}
run_id = event.get("runId")
if isinstance(run_id, str) and run_id:
metadata["run_id"] = run_id
stage = event.get("stage")
if isinstance(stage, str) and stage:
metadata["stage"] = stage
task_id = event.get("taskId")
if isinstance(task_id, str) and task_id:
metadata["task_id"] = task_id
locked_session = await session_repo.lock_session_for_update(
session_id=session_id
)
if locked_session is None:
return
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
await message_repo.append_message(
session_id=session_id,
seq=seq,
role=AgentChatMessageRole.TOOL,
content=content,
tool_name=tool_name,
metadata=metadata,
)
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
status = (
current_status
if isinstance(current_status, AgentChatSessionStatus)
else AgentChatSessionStatus.RUNNING
)
await self._update_session_state(
session_repo=session_repo,
chat_session=chat_session,
status=status,
message_delta=1,
)
def _resolve_role(self, value: str) -> AgentChatMessageRole:
normalized = value.strip().lower()
if normalized == AgentChatMessageRole.SYSTEM.value:
return AgentChatMessageRole.SYSTEM
if normalized == AgentChatMessageRole.TOOL.value:
return AgentChatMessageRole.TOOL
return AgentChatMessageRole.ASSISTANT
async def _update_session_state(
self,