feat(agent): 增强多模态链路与工具调用能力
This commit is contained in:
@@ -23,10 +23,12 @@ class MessageRepository:
|
||||
role: AgentChatMessageRole,
|
||||
content: str,
|
||||
model_code: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
latency_ms: int | None = None,
|
||||
) -> AgentChatMessage:
|
||||
message = AgentChatMessage(
|
||||
session_id=session_id,
|
||||
@@ -34,10 +36,12 @@ class MessageRepository:
|
||||
role=role,
|
||||
content=content,
|
||||
model_code=model_code,
|
||||
tool_name=tool_name,
|
||||
metadata_json=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
self._session.add(message)
|
||||
await self._session.flush()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Callable, Protocol
|
||||
from uuid import UUID
|
||||
@@ -24,6 +25,7 @@ class SqlAlchemyEventStore:
|
||||
def __init__(self, *, session_factory: Any) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._message_buffers: dict[tuple[str, str], str] = {}
|
||||
self._message_contexts: dict[tuple[str, str], dict[str, object]] = {}
|
||||
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
event_type = str(event.get("type", "")).strip().upper()
|
||||
@@ -48,6 +50,10 @@ class SqlAlchemyEventStore:
|
||||
self._buffer_text_delta(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "TEXT_MESSAGE_START":
|
||||
self._buffer_text_context(session_key=session_key, event=event)
|
||||
return
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
@@ -72,7 +78,15 @@ class SqlAlchemyEventStore:
|
||||
)
|
||||
self._clear_session_buffers(session_key=session_key)
|
||||
elif event_type == "TEXT_MESSAGE_END":
|
||||
await self._persist_assistant_message(
|
||||
await self._persist_text_message(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
@@ -97,8 +111,28 @@ class SqlAlchemyEventStore:
|
||||
stale_keys = [k for k in self._message_buffers if k[0] == session_key]
|
||||
for key in stale_keys:
|
||||
self._message_buffers.pop(key, None)
|
||||
stale_context_keys = [k for k in self._message_contexts if k[0] == session_key]
|
||||
for key in stale_context_keys:
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
async def _persist_assistant_message(
|
||||
def _buffer_text_context(self, *, session_key: str, event: dict[str, Any]) -> None:
|
||||
message_id = event.get("messageId")
|
||||
if not isinstance(message_id, str) or not message_id:
|
||||
return
|
||||
key = (session_key, message_id)
|
||||
role = event.get("role")
|
||||
stage = event.get("stage")
|
||||
tool_name = event.get("toolName")
|
||||
context: dict[str, object] = {}
|
||||
if isinstance(role, str) and role:
|
||||
context["role"] = role
|
||||
if isinstance(stage, str) and stage:
|
||||
context["stage"] = stage
|
||||
if isinstance(tool_name, str) and tool_name:
|
||||
context["tool_name"] = tool_name
|
||||
self._message_contexts[key] = context
|
||||
|
||||
async def _persist_text_message(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
@@ -114,6 +148,8 @@ class SqlAlchemyEventStore:
|
||||
if not content:
|
||||
return
|
||||
|
||||
context = self._message_contexts.get(key, {})
|
||||
|
||||
input_tokens = self._to_int(event.get("inputTokens"))
|
||||
output_tokens = self._to_int(event.get("outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
@@ -127,6 +163,20 @@ class SqlAlchemyEventStore:
|
||||
metadata["run_id"] = run_id
|
||||
if latency_ms is not None:
|
||||
metadata["latency_ms"] = latency_ms
|
||||
stage = event.get("stage")
|
||||
if not isinstance(stage, str):
|
||||
stage = context.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
|
||||
role_value = context.get("role")
|
||||
if not isinstance(role_value, str):
|
||||
role_value = "assistant"
|
||||
role = self._resolve_role(role_value)
|
||||
tool_name = context.get("tool_name")
|
||||
tool_name_value = (
|
||||
tool_name if isinstance(tool_name, str) and tool_name else None
|
||||
)
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
@@ -137,13 +187,15 @@ class SqlAlchemyEventStore:
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
role=role,
|
||||
content=content,
|
||||
model_code=model_code if isinstance(model_code, str) else None,
|
||||
tool_name=tool_name_value,
|
||||
metadata=metadata,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
@@ -161,6 +213,74 @@ class SqlAlchemyEventStore:
|
||||
cost_delta=cost,
|
||||
)
|
||||
self._message_buffers.pop(key, None)
|
||||
self._message_contexts.pop(key, None)
|
||||
|
||||
async def _persist_tool_call_result(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
session_id: UUID,
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
tool_name = event.get("toolName")
|
||||
if not isinstance(tool_name, str) or not tool_name:
|
||||
return
|
||||
|
||||
payload = {
|
||||
"args": event.get("args"),
|
||||
"result": event.get("result"),
|
||||
"error": event.get("error"),
|
||||
"call_id": event.get("callId"),
|
||||
}
|
||||
content = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||
metadata: dict[str, object] = {"tool_name": tool_name}
|
||||
run_id = event.get("runId")
|
||||
if isinstance(run_id, str) and run_id:
|
||||
metadata["run_id"] = run_id
|
||||
stage = event.get("stage")
|
||||
if isinstance(stage, str) and stage:
|
||||
metadata["stage"] = stage
|
||||
task_id = event.get("taskId")
|
||||
if isinstance(task_id, str) and task_id:
|
||||
metadata["task_id"] = task_id
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
return
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.TOOL,
|
||||
content=content,
|
||||
tool_name=tool_name,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
status = (
|
||||
current_status
|
||||
if isinstance(current_status, AgentChatSessionStatus)
|
||||
else AgentChatSessionStatus.RUNNING
|
||||
)
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=status,
|
||||
message_delta=1,
|
||||
)
|
||||
|
||||
def _resolve_role(self, value: str) -> AgentChatMessageRole:
|
||||
normalized = value.strip().lower()
|
||||
if normalized == AgentChatMessageRole.SYSTEM.value:
|
||||
return AgentChatMessageRole.SYSTEM
|
||||
if normalized == AgentChatMessageRole.TOOL.value:
|
||||
return AgentChatMessageRole.TOOL
|
||||
return AgentChatMessageRole.ASSISTANT
|
||||
|
||||
async def _update_session_state(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user