feat: 实现 AgentScope tool call context,支持 runtime 上下文续接
This commit is contained in:
@@ -109,7 +109,7 @@ class PipelineStageEmitter:
|
||||
"role": "tool",
|
||||
"stage": self._stage,
|
||||
"tool_name": tool_output.tool_name,
|
||||
"tool_call_id": tool_output.tool_call_id,
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_args": tool_output.tool_call_args,
|
||||
"status": tool_output.status.value,
|
||||
"result": tool_output.result,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
@@ -18,10 +19,13 @@ from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
extract_user_message_attachments,
|
||||
)
|
||||
from schemas.user import UserContext
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from services.base.supabase import supabase_service
|
||||
from schemas.messages.chat_message import extract_user_message_attachments
|
||||
from v1.agent.dependencies import get_agent_service
|
||||
from v1.users.dependencies import get_user_service
|
||||
|
||||
@@ -29,6 +33,33 @@ logger = get_logger("core.agentscope.runtime.tasks")
|
||||
_MAX_CONTEXT_ATTACHMENTS = 3
|
||||
|
||||
|
||||
def _serialize_tool_agent_output(
|
||||
*,
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None,
|
||||
) -> str | None:
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
resolved_metadata = (
|
||||
metadata
|
||||
if isinstance(metadata, AgentChatMessageMetadata)
|
||||
else AgentChatMessageMetadata.model_validate(metadata)
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
tool_agent_output = resolved_metadata.tool_agent_output
|
||||
if tool_agent_output is None:
|
||||
return None
|
||||
|
||||
return json.dumps(
|
||||
tool_agent_output.model_dump(mode="json", exclude_none=True),
|
||||
ensure_ascii=True,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
|
||||
|
||||
def _load_runtime() -> type[Any]:
|
||||
return AgentScopeRuntimeOrchestrator
|
||||
|
||||
@@ -53,16 +84,25 @@ async def _build_recent_context_messages(
|
||||
if not result:
|
||||
return []
|
||||
|
||||
raw_messages: list[dict[str, Any]] = result.get("messages") or []
|
||||
raw_messages: list[dict[str, object]] = result.get("messages") or []
|
||||
if not raw_messages:
|
||||
return []
|
||||
|
||||
converted: list[Msg] = []
|
||||
|
||||
for msg in raw_messages:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", "")
|
||||
metadata = msg.get("metadata")
|
||||
role_raw = msg.get("role")
|
||||
role = role_raw if isinstance(role_raw, str) else "user"
|
||||
content_raw = msg.get("content", "")
|
||||
content: str = content_raw if isinstance(content_raw, str) else ""
|
||||
metadata_raw = msg.get("metadata")
|
||||
metadata: AgentChatMessageMetadata | dict[str, object] | None
|
||||
if isinstance(metadata_raw, AgentChatMessageMetadata):
|
||||
metadata = metadata_raw
|
||||
elif isinstance(metadata_raw, dict):
|
||||
metadata = metadata_raw
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
if role == "user" and metadata:
|
||||
image_blocks: list[dict[str, Any]] = []
|
||||
@@ -105,6 +145,10 @@ async def _build_recent_context_messages(
|
||||
|
||||
if role == "tool":
|
||||
role = "assistant"
|
||||
tool_content = _serialize_tool_agent_output(metadata=metadata)
|
||||
if not tool_content:
|
||||
continue
|
||||
content = tool_content
|
||||
|
||||
converted.append(
|
||||
Msg(
|
||||
|
||||
@@ -18,6 +18,7 @@ from core.agentscope.tools.utils.calendar_ui import (
|
||||
calendar_error_output,
|
||||
dump_tool_output,
|
||||
)
|
||||
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from schemas.agent.runtime_models import ErrorInfo, ToolAgentOutput, ToolStatus
|
||||
from v1.schedule_items.schemas import (
|
||||
ScheduleItemCreateRequest,
|
||||
@@ -75,9 +76,28 @@ def _format_event_brief(event_items: list[dict[str, Any]], limit: int = 3) -> st
|
||||
event_id = str(item.get("id") or "")
|
||||
title = str(item.get("title") or "")
|
||||
start_at = str(item.get("startAt") or "")
|
||||
end_at = str(item.get("endAt") or "")
|
||||
timezone = str(item.get("timezone") or "")
|
||||
status = str(item.get("status") or "")
|
||||
description = str(item.get("description") or "")
|
||||
location = str(item.get("location") or "")
|
||||
reminder_minutes = item.get("reminderMinutes")
|
||||
color = str(item.get("color") or "")
|
||||
source_type = str(item.get("sourceType") or "")
|
||||
updated_at = str(item.get("updatedAt") or "")
|
||||
permission = item.get("permission")
|
||||
is_owner = item.get("isOwner")
|
||||
if not event_id:
|
||||
continue
|
||||
briefs.append(f"{{id={event_id},title={title},startAt={start_at}}}")
|
||||
briefs.append(
|
||||
"{"
|
||||
f"id={event_id},title={title},startAt={start_at},endAt={end_at},"
|
||||
f"timezone={timezone},status={status},description={description},"
|
||||
f"location={location},reminderMinutes={reminder_minutes},color={color},"
|
||||
f"sourceType={source_type},updatedAt={updated_at},permission={permission},"
|
||||
f"isOwner={is_owner}"
|
||||
"}"
|
||||
)
|
||||
return ",".join(briefs)
|
||||
|
||||
|
||||
@@ -129,18 +149,17 @@ async def calendar_read(
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size if total else 0
|
||||
event_items = [schedule_event_to_dict(item) for item in items]
|
||||
query_value = (query or "").strip() or "*"
|
||||
event_brief = _format_event_brief(event_items)
|
||||
summary = (
|
||||
f"status=success query={query_value} total={total} page={page}/"
|
||||
f"{total_pages or 1} returned={len(event_items)}"
|
||||
f"status=success total={total} total_pages={total_pages or 1} "
|
||||
f"returned={len(event_items)} has_next={str(page < (total_pages or 1)).lower()}"
|
||||
)
|
||||
if event_brief:
|
||||
summary = f"{summary} items=[{event_brief}]"
|
||||
return dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=f"{tool_name}-call",
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
@@ -359,11 +378,8 @@ async def calendar_write(
|
||||
success_count += 1
|
||||
result_items.append(
|
||||
{
|
||||
"index": idx,
|
||||
"operation": operation,
|
||||
"status": "success",
|
||||
"eventId": str(created.id),
|
||||
"message": f"日程「{created.title}」已创建",
|
||||
}
|
||||
)
|
||||
success_event_ids.append(str(created.id))
|
||||
@@ -397,6 +413,7 @@ async def calendar_write(
|
||||
color=cast(str | None, color),
|
||||
reminder_minutes=cast(int | None, reminder_minutes),
|
||||
)
|
||||
changed_fields = sorted(update_data.keys())
|
||||
updated = await service.update(
|
||||
parsed_event_id,
|
||||
ScheduleItemUpdateRequest.model_validate(update_data),
|
||||
@@ -404,11 +421,9 @@ async def calendar_write(
|
||||
success_count += 1
|
||||
result_items.append(
|
||||
{
|
||||
"index": idx,
|
||||
"operation": operation,
|
||||
"status": "success",
|
||||
"eventId": str(updated.id),
|
||||
"message": f"日程「{updated.title}」已更新",
|
||||
"changedFields": changed_fields,
|
||||
}
|
||||
)
|
||||
success_event_ids.append(str(updated.id))
|
||||
@@ -421,11 +436,8 @@ async def calendar_write(
|
||||
success_count += 1
|
||||
result_items.append(
|
||||
{
|
||||
"index": idx,
|
||||
"operation": operation,
|
||||
"status": "success",
|
||||
"eventId": event_id,
|
||||
"message": f"日程 {event_id} 已删除",
|
||||
}
|
||||
)
|
||||
success_event_ids.append(event_id)
|
||||
@@ -435,8 +447,6 @@ async def calendar_write(
|
||||
failed_count += 1
|
||||
result_items.append(
|
||||
{
|
||||
"index": idx,
|
||||
"operation": operation,
|
||||
"status": "failure",
|
||||
"eventId": event_id,
|
||||
"code": code,
|
||||
@@ -447,21 +457,30 @@ async def calendar_write(
|
||||
if failed_count == 0:
|
||||
final_status = ToolStatus.SUCCESS
|
||||
summary = (
|
||||
f"status=success batch={batch_size} success={success_count} "
|
||||
f"failed={failed_count} ids=[{','.join(success_event_ids)}]"
|
||||
f"status=success success={success_count} failed={failed_count} "
|
||||
f"ids=[{','.join(success_event_ids)}]"
|
||||
)
|
||||
elif success_count == 0:
|
||||
final_status = ToolStatus.FAILURE
|
||||
summary = (
|
||||
f"status=failure batch={batch_size} success={success_count} "
|
||||
f"failed={failed_count}"
|
||||
)
|
||||
summary = f"status=failure success={success_count} failed={failed_count}"
|
||||
else:
|
||||
final_status = ToolStatus.PARTIAL
|
||||
summary = (
|
||||
f"status=partial batch={batch_size} success={success_count} "
|
||||
f"failed={failed_count} ids=[{','.join(success_event_ids)}]"
|
||||
f"status=partial success={success_count} failed={failed_count} "
|
||||
f"ids=[{','.join(success_event_ids)}]"
|
||||
)
|
||||
compact_items = ",".join(
|
||||
[
|
||||
"{"
|
||||
f"status={item.get('status')},"
|
||||
f"eventId={item.get('eventId')},code={item.get('code')},"
|
||||
f"changedFields={item.get('changedFields')}"
|
||||
"}"
|
||||
for item in result_items
|
||||
]
|
||||
)
|
||||
if compact_items:
|
||||
summary = f"{summary} items=[{compact_items}]"
|
||||
|
||||
error_info: ErrorInfo | None = None
|
||||
if final_status == ToolStatus.FAILURE:
|
||||
@@ -477,7 +496,11 @@ async def calendar_write(
|
||||
code=str(
|
||||
first_failure.get("code") if first_failure else "BATCH_FAILED"
|
||||
),
|
||||
message=str(first_failure.get("message") if first_failure else summary),
|
||||
message=str(
|
||||
first_failure.get("message")
|
||||
if first_failure and first_failure.get("message")
|
||||
else summary
|
||||
),
|
||||
retryable=False,
|
||||
details={"results": result_items},
|
||||
)
|
||||
@@ -489,7 +512,7 @@ async def calendar_write(
|
||||
return dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=f"{tool_name}-call",
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=final_status,
|
||||
result=summary,
|
||||
@@ -597,11 +620,13 @@ async def calendar_share(
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
summary = f"status=success event_id={event_id} invited_count={len(invited)}"
|
||||
summary = (
|
||||
f"status=success invited_count={len(invited)} invited=[{','.join(invited)}]"
|
||||
)
|
||||
return dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=f"{tool_name}-call",
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from agentscope.tool import ToolResponse
|
||||
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from core.agentscope.tools.utils import (
|
||||
find_auth_email_by_user_id,
|
||||
list_auth_users,
|
||||
@@ -33,7 +34,7 @@ def _lookup_error_output(
|
||||
) -> ToolResponse:
|
||||
output = build_error_output(
|
||||
tool_name="user_lookup",
|
||||
tool_call_id="user_lookup-call",
|
||||
tool_call_id=get_current_tool_call_id(tool_name="user_lookup"),
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
@@ -148,7 +149,7 @@ async def user_lookup(
|
||||
return _dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
tool_name="user_lookup",
|
||||
tool_call_id="user_lookup-call",
|
||||
tool_call_id=get_current_tool_call_id(tool_name="user_lookup"),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
from uuid import uuid4
|
||||
|
||||
_CURRENT_TOOL_CALL_ID: ContextVar[str | None] = ContextVar(
|
||||
"current_tool_call_id",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
def set_current_tool_call_id(tool_call_id: str | None) -> Token[str | None]:
|
||||
return _CURRENT_TOOL_CALL_ID.set(tool_call_id)
|
||||
|
||||
|
||||
def reset_current_tool_call_id(token: Token[str | None]) -> None:
|
||||
_CURRENT_TOOL_CALL_ID.reset(token)
|
||||
|
||||
|
||||
def get_current_tool_call_id(*, tool_name: str) -> str:
|
||||
current = _CURRENT_TOOL_CALL_ID.get()
|
||||
if isinstance(current, str) and current.strip():
|
||||
return current.strip()
|
||||
return f"{tool_name}-call-{uuid4().hex}"
|
||||
@@ -1,7 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, AsyncGenerator, Callable
|
||||
from uuid import uuid4
|
||||
|
||||
from core.agentscope.tools.tool_call_context import (
|
||||
reset_current_tool_call_id,
|
||||
set_current_tool_call_id,
|
||||
)
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
build_error_response,
|
||||
)
|
||||
@@ -17,6 +22,7 @@ def register_tool_middlewares(
|
||||
| None = None,
|
||||
) -> None:
|
||||
effective_config = config_by_name or meta_by_name or TOOL_CONFIGS
|
||||
toolkit.register_middleware(create_tool_call_context_middleware())
|
||||
toolkit.register_middleware(
|
||||
create_approval_middleware(
|
||||
config_by_name=effective_config,
|
||||
@@ -25,12 +31,40 @@ def register_tool_middlewares(
|
||||
)
|
||||
|
||||
|
||||
def create_tool_call_context_middleware() -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
async def tool_call_context_middleware(
|
||||
kwargs: dict[str, Any],
|
||||
next_handler: Callable[..., Any],
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
tool_call = kwargs.get("tool_call")
|
||||
tool_call_id: str | None = None
|
||||
if isinstance(tool_call, dict):
|
||||
raw_id = tool_call.get("id")
|
||||
if isinstance(raw_id, str) and raw_id.strip():
|
||||
tool_call_id = raw_id.strip()
|
||||
|
||||
token = set_current_tool_call_id(tool_call_id)
|
||||
try:
|
||||
async for response in await next_handler(**kwargs):
|
||||
yield response
|
||||
finally:
|
||||
reset_current_tool_call_id(token)
|
||||
|
||||
return tool_call_context_middleware
|
||||
|
||||
|
||||
def create_approval_middleware(
|
||||
*,
|
||||
config_by_name: dict[str, ToolConfig],
|
||||
approval_resolver: Callable[[str, dict[str, Any], ToolConfig], str | None]
|
||||
| None = None,
|
||||
) -> Callable[..., AsyncGenerator[Any, None]]:
|
||||
def _resolve_tool_call_id(tool_call: dict[str, Any]) -> str:
|
||||
raw_tool_call_id = tool_call.get("id")
|
||||
if isinstance(raw_tool_call_id, str) and raw_tool_call_id.strip():
|
||||
return raw_tool_call_id.strip()
|
||||
return f"tool-call-{uuid4().hex}"
|
||||
|
||||
async def approval_middleware(
|
||||
kwargs: dict[str, Any],
|
||||
next_handler: Callable[..., Any],
|
||||
@@ -74,7 +108,7 @@ def create_approval_middleware(
|
||||
if decision == "rejected":
|
||||
content = build_error_response(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=tool_call.get("id", "unknown"),
|
||||
tool_call_id=_resolve_tool_call_id(tool_call),
|
||||
code="TOOL_REJECTED",
|
||||
message=f"工具 {tool_name} 的调用已被审核拒绝",
|
||||
retryable=False,
|
||||
@@ -88,7 +122,7 @@ def create_approval_middleware(
|
||||
|
||||
pending_response = build_error_response(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=tool_call.get("id", "unknown"),
|
||||
tool_call_id=_resolve_tool_call_id(tool_call),
|
||||
code="TOOL_PENDING_APPROVAL",
|
||||
message=f"工具 {tool_name} 需要审核批准",
|
||||
retryable=True,
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from agentscope.tool import ToolResponse
|
||||
from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from core.agentscope.tools.utils.tool_response_builder import (
|
||||
build_error_output,
|
||||
build_tool_response,
|
||||
@@ -24,7 +25,7 @@ def calendar_error_output(
|
||||
) -> ToolResponse:
|
||||
output = build_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=f"{tool_name}-call",
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
|
||||
Reference in New Issue
Block a user