refactor: 梳理规则体系并统一记忆与部署流程
This commit is contained in:
@@ -9,7 +9,7 @@ from core.logging import get_logger
|
||||
from models.agent_chat_message import AgentChatMessageRole
|
||||
from models.agent_chat_session import AgentChatSessionStatus
|
||||
from schemas.agent.system_agent import AgentType
|
||||
from schemas.agent.runtime_models import AgentOutput, ToolAgentOutput
|
||||
from schemas.agent.runtime_models import AgentOutput, RouterAgentOutput, ToolAgentOutput
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.messages.chat_message import AgentChatMessageMetadata
|
||||
|
||||
@@ -79,6 +79,14 @@ class SqlAlchemyEventStore:
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
)
|
||||
elif event_type == "STEP_FINISHED":
|
||||
await self._persist_router_step_output(
|
||||
event=event,
|
||||
session_id=session_id,
|
||||
chat_session=chat_session,
|
||||
session_repo=session_repo,
|
||||
message_repo=message_repo,
|
||||
)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
await self._persist_tool_call_result(
|
||||
event=event,
|
||||
@@ -199,6 +207,95 @@ class SqlAlchemyEventStore:
|
||||
cost_delta=cost,
|
||||
)
|
||||
|
||||
async def _persist_router_step_output(
|
||||
self,
|
||||
*,
|
||||
event: dict[str, Any],
|
||||
session_id: UUID,
|
||||
chat_session: Any,
|
||||
session_repo: SessionRepository,
|
||||
message_repo: MessageRepository,
|
||||
) -> None:
|
||||
step_name = self._event_value(event, "stepName")
|
||||
if not isinstance(step_name, str) or step_name.strip().lower() != "router":
|
||||
return
|
||||
|
||||
run_id = self._event_value(event, "runId")
|
||||
run_id_value = run_id if isinstance(run_id, str) and run_id else None
|
||||
if run_id_value is None:
|
||||
return
|
||||
|
||||
persist_payload = event.get("_router_persist")
|
||||
if not isinstance(persist_payload, dict):
|
||||
return
|
||||
|
||||
router_output_raw = persist_payload.get("router_output")
|
||||
response_metadata_raw = persist_payload.get("response_metadata")
|
||||
if not isinstance(router_output_raw, dict):
|
||||
return
|
||||
|
||||
response_metadata = (
|
||||
response_metadata_raw if isinstance(response_metadata_raw, dict) else {}
|
||||
)
|
||||
model_code_raw = response_metadata.get("model")
|
||||
model_code = model_code_raw if isinstance(model_code_raw, str) else None
|
||||
input_tokens = self._to_int(response_metadata.get("inputTokens"))
|
||||
output_tokens = self._to_int(response_metadata.get("outputTokens"))
|
||||
token_delta = input_tokens + output_tokens
|
||||
cost = self._to_decimal(response_metadata.get("cost"))
|
||||
latency_ms = self._to_int_or_none(response_metadata.get("latencyMs"))
|
||||
|
||||
try:
|
||||
router_output = RouterAgentOutput.model_validate(router_output_raw)
|
||||
metadata_model = AgentChatMessageMetadata(
|
||||
run_id=run_id_value,
|
||||
agent_type=AgentType.ROUTER,
|
||||
router_agent_output=router_output,
|
||||
)
|
||||
except Exception:
|
||||
self._logger.warning(
|
||||
"invalid router metadata payload",
|
||||
run_id=run_id_value,
|
||||
)
|
||||
return
|
||||
|
||||
content = ""
|
||||
|
||||
locked_session = await session_repo.lock_session_for_update(
|
||||
session_id=session_id
|
||||
)
|
||||
if locked_session is None:
|
||||
return
|
||||
seq = int(getattr(locked_session, "message_count", 0) or 0) + 1
|
||||
await message_repo.append_message(
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=AgentChatMessageRole.ASSISTANT,
|
||||
content=content,
|
||||
model_code=model_code,
|
||||
metadata=metadata_model.model_dump(mode="json", exclude_none=True),
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
visibility_mask=0,
|
||||
)
|
||||
|
||||
current_status = getattr(chat_session, "status", AgentChatSessionStatus.RUNNING)
|
||||
status = (
|
||||
current_status
|
||||
if isinstance(current_status, AgentChatSessionStatus)
|
||||
else AgentChatSessionStatus.RUNNING
|
||||
)
|
||||
await self._update_session_state(
|
||||
session_repo=session_repo,
|
||||
chat_session=chat_session,
|
||||
status=status,
|
||||
message_delta=1,
|
||||
token_delta=token_delta,
|
||||
cost_delta=cost,
|
||||
)
|
||||
|
||||
async def _persist_tool_call_result(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -16,7 +16,7 @@ from core.agentscope.schemas.agui_input import extract_latest_user_payload
|
||||
from core.agentscope.runtime.json_react_agent import JsonReActAgent
|
||||
from core.agentscope.runtime.model_tracking import TrackingChatModel
|
||||
from core.agentscope.runtime.stage_emitter import PipelineStageEmitter
|
||||
from core.agentscope.tools.tool_config import AgentTool
|
||||
from core.agentscope.tools.tool_config import AgentTool, resolve_tool_function_names
|
||||
from core.agentscope.tools.toolkit import build_toolkit
|
||||
from core.agentscope.utils import (
|
||||
finalize_json_response,
|
||||
@@ -123,7 +123,11 @@ class AgentScopeRunner:
|
||||
owner_id: UUID,
|
||||
enabled_tools: list[AgentTool],
|
||||
) -> Any:
|
||||
tool_names = [t.value for t in enabled_tools] if enabled_tools else []
|
||||
tool_names = (
|
||||
sorted(resolve_tool_function_names(set(enabled_tools)))
|
||||
if enabled_tools
|
||||
else []
|
||||
)
|
||||
return build_toolkit(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
@@ -189,6 +193,14 @@ class AgentScopeRunner:
|
||||
run_input=run_input,
|
||||
step_name=AgentType.ROUTER.value,
|
||||
event_type="STEP_FINISHED",
|
||||
extra_event={
|
||||
"_router_persist": {
|
||||
"router_output": router_output.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
),
|
||||
"response_metadata": router_result.response_metadata,
|
||||
}
|
||||
},
|
||||
)
|
||||
return router_output
|
||||
|
||||
@@ -382,11 +394,13 @@ class AgentScopeRunner:
|
||||
self, *, stage_config: SystemAgentRuntimeConfig
|
||||
) -> TrackingChatModel:
|
||||
generate_kwargs: dict[str, Any] = {
|
||||
"temperature": stage_config.llm_config.temperature,
|
||||
"max_tokens": stage_config.llm_config.max_tokens,
|
||||
"timeout": stage_config.llm_config.timeout_seconds,
|
||||
"extra_body": {"enable_thinking": False},
|
||||
}
|
||||
if stage_config.llm_config.temperature is not None:
|
||||
generate_kwargs["temperature"] = stage_config.llm_config.temperature
|
||||
if stage_config.llm_config.max_tokens is not None:
|
||||
generate_kwargs["max_tokens"] = stage_config.llm_config.max_tokens
|
||||
|
||||
model = OpenAIChatModel(
|
||||
model_name=stage_config.model_code,
|
||||
@@ -423,15 +437,19 @@ class AgentScopeRunner:
|
||||
run_input: RunAgentInput,
|
||||
step_name: str,
|
||||
event_type: str,
|
||||
extra_event: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"type": event_type,
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"stepName": step_name,
|
||||
}
|
||||
if extra_event:
|
||||
payload.update(extra_event)
|
||||
await pipeline.emit(
|
||||
session_id=run_input.thread_id,
|
||||
event={
|
||||
"type": event_type,
|
||||
"threadId": run_input.thread_id,
|
||||
"runId": run_input.run_id,
|
||||
"stepName": step_name,
|
||||
},
|
||||
event=payload,
|
||||
)
|
||||
|
||||
def _resolve_runtime_client_time(
|
||||
|
||||
@@ -52,6 +52,50 @@ class CalendarShareInvitee(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CalendarWriteOperation(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: Literal["create", "update", "delete"] = Field(
|
||||
description="Action type for this operation item."
|
||||
)
|
||||
event_id: str | None = Field(
|
||||
default=None,
|
||||
description="Event id required for update/delete.",
|
||||
)
|
||||
title: str | None = Field(default=None, description="Event title.")
|
||||
description: str | None = Field(default=None, description="Event description.")
|
||||
start_at: str | None = Field(
|
||||
default=None,
|
||||
description="Start time in ISO 8601 with timezone offset.",
|
||||
)
|
||||
end_at: str | None = Field(
|
||||
default=None,
|
||||
description="End time in ISO 8601 with timezone offset.",
|
||||
)
|
||||
event_timezone: str | None = Field(
|
||||
default=None,
|
||||
description="IANA timezone for the event.",
|
||||
)
|
||||
location: str | None = Field(default=None, description="Event location.")
|
||||
color: str | None = Field(default=None, description="Event color.")
|
||||
reminder_minutes: int | None = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
le=10080,
|
||||
description="Reminder minutes before event start.",
|
||||
)
|
||||
status: Literal["active", "completed", "canceled", "archived"] | None = Field(
|
||||
default=None,
|
||||
description="Optional status for update action.",
|
||||
)
|
||||
|
||||
|
||||
class CalendarWriteBatchArgs(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
operations: list[CalendarWriteOperation] = Field(min_length=1, max_length=20)
|
||||
|
||||
|
||||
def _validate_runtime_context(
|
||||
*,
|
||||
tool_name: str,
|
||||
@@ -178,125 +222,48 @@ async def calendar_read(
|
||||
|
||||
async def calendar_write(
|
||||
operations: Annotated[
|
||||
list[Literal["create", "update", "delete"]],
|
||||
list[CalendarWriteOperation],
|
||||
Field(
|
||||
description=(
|
||||
"Batch operations list. Each item must be create, update, or delete."
|
||||
"Batch operation objects. Each item includes action and its fields. "
|
||||
"Use create/update/delete in a single call."
|
||||
),
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
),
|
||||
],
|
||||
event_ids: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(
|
||||
description=(
|
||||
"Optional event id list aligned with operations. "
|
||||
"Required for update/delete item."
|
||||
)
|
||||
),
|
||||
] = None,
|
||||
titles: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(description="Optional title list aligned with operations."),
|
||||
] = None,
|
||||
descriptions: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(description="Optional description list aligned with operations."),
|
||||
] = None,
|
||||
start_ats: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(
|
||||
description=(
|
||||
"Optional start time list aligned with operations, ISO 8601 with timezone."
|
||||
)
|
||||
),
|
||||
] = None,
|
||||
end_ats: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(
|
||||
description=(
|
||||
"Optional end time list aligned with operations, ISO 8601 with timezone."
|
||||
)
|
||||
),
|
||||
] = None,
|
||||
event_timezones: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(
|
||||
description=(
|
||||
"Optional event timezone list aligned with operations, IANA timezone."
|
||||
)
|
||||
),
|
||||
] = None,
|
||||
locations: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(description="Optional location list aligned with operations."),
|
||||
] = None,
|
||||
colors: Annotated[
|
||||
list[str | None] | None,
|
||||
Field(description="Optional color list aligned with operations."),
|
||||
] = None,
|
||||
reminder_minutes_list: Annotated[
|
||||
list[int | None] | None,
|
||||
Field(
|
||||
description=(
|
||||
"Optional reminder minutes list aligned with operations, value range 0-10080."
|
||||
)
|
||||
),
|
||||
] = None,
|
||||
statuses: Annotated[
|
||||
list[Literal["active", "completed", "canceled", "archived"] | None] | None,
|
||||
Field(description="Optional status list aligned with operations."),
|
||||
] = None,
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
"""Batch create/update/delete calendar events using aligned list parameters.
|
||||
"""Batch create/update/delete calendar events using operation objects.
|
||||
|
||||
Args:
|
||||
operations: Operation list. Length defines batch size.
|
||||
event_ids: Optional event id list aligned with operations.
|
||||
titles: Optional title list aligned with operations.
|
||||
descriptions: Optional description list aligned with operations.
|
||||
start_ats: Optional start time list aligned with operations.
|
||||
end_ats: Optional end time list aligned with operations.
|
||||
event_timezones: Optional event timezone list aligned with operations.
|
||||
locations: Optional location list aligned with operations.
|
||||
colors: Optional color list aligned with operations.
|
||||
reminder_minutes_list: Optional reminder minute list aligned with operations.
|
||||
statuses: Optional status list aligned with operations.
|
||||
|
||||
Constraints:
|
||||
- All provided list parameters must have the same length as operations.
|
||||
- create item requires start_ats[i] and event_timezones[i].
|
||||
- update/delete item requires event_ids[i].
|
||||
- start/end datetime must include timezone offset.
|
||||
operations: Batch operation objects.
|
||||
- create requires start_at and event_timezone.
|
||||
- update/delete requires event_id.
|
||||
- datetime fields must include timezone offset.
|
||||
|
||||
Returns:
|
||||
ToolResponse with serialized ToolAgentOutput payload.
|
||||
"""
|
||||
tool_name = "calendar_write"
|
||||
try:
|
||||
parsed_batch = CalendarWriteBatchArgs.model_validate({"operations": operations})
|
||||
except Exception as exc: # noqa: BLE001
|
||||
code, message, retryable = map_calendar_exception(exc)
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args={"operations": operations},
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
|
||||
def _align_list(name: str, values: list[Any] | None, size: int) -> list[Any | None]:
|
||||
if values is None:
|
||||
return [None] * size
|
||||
if len(values) != size:
|
||||
raise ValueError(f"{name} 长度必须与 operations 一致")
|
||||
return list(values)
|
||||
|
||||
batch_size = len(operations)
|
||||
tool_call_args = {
|
||||
"operations": operations,
|
||||
"event_ids": event_ids,
|
||||
"titles": titles,
|
||||
"descriptions": descriptions,
|
||||
"start_ats": start_ats,
|
||||
"end_ats": end_ats,
|
||||
"event_timezones": event_timezones,
|
||||
"locations": locations,
|
||||
"colors": colors,
|
||||
"reminder_minutes_list": reminder_minutes_list,
|
||||
"statuses": statuses,
|
||||
"operations": [
|
||||
operation.model_dump(mode="json", exclude_none=True)
|
||||
for operation in parsed_batch.operations
|
||||
]
|
||||
}
|
||||
runtime_error = _validate_runtime_context(
|
||||
tool_name=tool_name,
|
||||
@@ -311,40 +278,26 @@ async def calendar_write(
|
||||
service = create_schedule_service(
|
||||
cast(AsyncSession, session), cast(UUID, owner_id)
|
||||
)
|
||||
aligned_event_ids = _align_list("event_ids", event_ids, batch_size)
|
||||
aligned_titles = _align_list("titles", titles, batch_size)
|
||||
aligned_descriptions = _align_list("descriptions", descriptions, batch_size)
|
||||
aligned_start_ats = _align_list("start_ats", start_ats, batch_size)
|
||||
aligned_end_ats = _align_list("end_ats", end_ats, batch_size)
|
||||
aligned_event_timezones = _align_list(
|
||||
"event_timezones", event_timezones, batch_size
|
||||
)
|
||||
aligned_locations = _align_list("locations", locations, batch_size)
|
||||
aligned_colors = _align_list("colors", colors, batch_size)
|
||||
aligned_reminders = _align_list(
|
||||
"reminder_minutes_list", reminder_minutes_list, batch_size
|
||||
)
|
||||
aligned_statuses = _align_list("statuses", statuses, batch_size)
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
success_event_ids: list[str] = []
|
||||
result_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, operation in enumerate(operations):
|
||||
event_id = aligned_event_ids[idx]
|
||||
title = aligned_titles[idx]
|
||||
description = aligned_descriptions[idx]
|
||||
start_at = aligned_start_ats[idx]
|
||||
end_at = aligned_end_ats[idx]
|
||||
event_timezone = aligned_event_timezones[idx]
|
||||
location = aligned_locations[idx]
|
||||
color = aligned_colors[idx]
|
||||
reminder_minutes = aligned_reminders[idx]
|
||||
status = aligned_statuses[idx]
|
||||
for operation in parsed_batch.operations:
|
||||
event_id = operation.event_id
|
||||
title = operation.title
|
||||
description = operation.description
|
||||
start_at = operation.start_at
|
||||
end_at = operation.end_at
|
||||
event_timezone = operation.event_timezone
|
||||
location = operation.location
|
||||
color = operation.color
|
||||
reminder_minutes = operation.reminder_minutes
|
||||
status = operation.status
|
||||
|
||||
try:
|
||||
if operation == "create":
|
||||
if operation.action == "create":
|
||||
if start_at is None or not start_at.strip():
|
||||
raise ValueError(
|
||||
"创建日程需要提供 start_at,且必须包含时区偏移"
|
||||
@@ -385,7 +338,7 @@ async def calendar_write(
|
||||
success_event_ids.append(str(created.id))
|
||||
continue
|
||||
|
||||
if operation == "update":
|
||||
if operation.action == "update":
|
||||
if event_id is None or not event_id.strip():
|
||||
raise ValueError("更新日程需要提供 event_id")
|
||||
parsed_event_id = UUID(event_id)
|
||||
@@ -429,7 +382,7 @@ async def calendar_write(
|
||||
success_event_ids.append(str(updated.id))
|
||||
continue
|
||||
|
||||
if operation == "delete":
|
||||
if operation.action == "delete":
|
||||
if event_id is None or not event_id.strip():
|
||||
raise ValueError("删除日程需要提供 event_id")
|
||||
await service.delete(UUID(event_id))
|
||||
|
||||
@@ -16,7 +16,7 @@ from core.agentscope.tools.utils.tool_response_builder import (
|
||||
build_tool_response,
|
||||
)
|
||||
from models.memories import MemoryType
|
||||
from schemas.agent.runtime_models import ToolAgentOutput, ToolStatus
|
||||
from schemas.agent.runtime_models import ErrorInfo, ToolAgentOutput, ToolStatus
|
||||
from schemas.memories.memory_content import UserMemoryContent, WorkProfileContent
|
||||
|
||||
|
||||
@@ -38,6 +38,12 @@ class MemoryWriteArgs(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class MemoryWriteBatchArgs(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
operations: list[MemoryWriteArgs] = Field(min_length=1, max_length=20)
|
||||
|
||||
|
||||
class MemoryForgetArgs(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@@ -70,6 +76,12 @@ class MemoryForgetArgs(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class MemoryForgetBatchArgs(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
operations: list[MemoryForgetArgs] = Field(min_length=1, max_length=20)
|
||||
|
||||
|
||||
def _memory_error_output(
|
||||
*,
|
||||
tool_name: str,
|
||||
@@ -149,28 +161,45 @@ def _delete_nested_path(payload: dict[str, Any], keys: list[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _compact_result_items(items: list[dict[str, object]]) -> str:
|
||||
return ",".join(
|
||||
"{" + ",".join(f"{key}={value}" for key, value in item.items()) + "}"
|
||||
for item in items
|
||||
)
|
||||
|
||||
|
||||
async def memory_write(
|
||||
memory_type: Annotated[
|
||||
str,
|
||||
Field(description="Memory type: user or work."),
|
||||
] = "user",
|
||||
user_content: Annotated[
|
||||
UserMemoryContent | None,
|
||||
Field(description="Patch payload for user memory content."),
|
||||
] = None,
|
||||
work_content: Annotated[
|
||||
WorkProfileContent | None,
|
||||
Field(description="Patch payload for work memory content."),
|
||||
] = None,
|
||||
operations: Annotated[
|
||||
list[MemoryWriteArgs],
|
||||
Field(
|
||||
description=(
|
||||
"Batch memory write operations. Each item must include memory_type and "
|
||||
"the matching content object (user_content or work_content)."
|
||||
),
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
),
|
||||
],
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
"""Merge structured facts into user/work memory.
|
||||
|
||||
Args:
|
||||
memory_type: Target memory domain, either ``user`` or ``work``.
|
||||
user_content: Partial user-memory payload when ``memory_type='user'``.
|
||||
work_content: Partial work-memory payload when ``memory_type='work'``.
|
||||
|
||||
Runtime:
|
||||
``session`` and ``owner_id`` are injected by toolkit preset kwargs.
|
||||
|
||||
Returns:
|
||||
ToolResponse wrapping ToolAgentOutput.
|
||||
- success: ``result`` contains a compact status summary.
|
||||
- failure: ``error`` contains structured code/message/retryable metadata.
|
||||
"""
|
||||
tool_name = "memory_write"
|
||||
tool_call_args: dict[str, Any] = {
|
||||
"memory_type": memory_type,
|
||||
"user_content": user_content,
|
||||
"work_content": work_content,
|
||||
}
|
||||
tool_call_args: dict[str, Any] = {"operations": operations}
|
||||
runtime_error = _validate_runtime_context(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
@@ -181,52 +210,117 @@ async def memory_write(
|
||||
return runtime_error
|
||||
|
||||
try:
|
||||
parsed_args = MemoryWriteArgs.model_validate(tool_call_args)
|
||||
parsed_batch = MemoryWriteBatchArgs.model_validate(tool_call_args)
|
||||
service = create_memories_service(
|
||||
session=cast(AsyncSession, session),
|
||||
owner_id=cast(UUID, owner_id),
|
||||
)
|
||||
existing = await service.get_memory_model(memory_type=parsed_args.memory_type)
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
updated_types: list[str] = []
|
||||
failed_operations: list[dict[str, object]] = []
|
||||
result_items: list[dict[str, object]] = []
|
||||
for idx, op in enumerate(parsed_batch.operations):
|
||||
try:
|
||||
existing = await service.get_memory_model(memory_type=op.memory_type)
|
||||
if op.memory_type == MemoryType.USER:
|
||||
base_model = (
|
||||
UserMemoryContent.model_validate(existing.content)
|
||||
if existing is not None
|
||||
else UserMemoryContent()
|
||||
)
|
||||
patch_model = cast(UserMemoryContent, op.user_content)
|
||||
merged = _deep_merge_dict(
|
||||
base_model.model_dump(),
|
||||
patch_model.model_dump(exclude_unset=True),
|
||||
)
|
||||
validated = UserMemoryContent.model_validate(merged)
|
||||
updated = await service.update_user_memory(content=validated)
|
||||
else:
|
||||
base_model = (
|
||||
WorkProfileContent.model_validate(existing.content)
|
||||
if existing is not None
|
||||
else WorkProfileContent()
|
||||
)
|
||||
patch_model = cast(WorkProfileContent, op.work_content)
|
||||
merged = _deep_merge_dict(
|
||||
base_model.model_dump(),
|
||||
patch_model.model_dump(exclude_unset=True),
|
||||
)
|
||||
validated = WorkProfileContent.model_validate(merged)
|
||||
updated = await service.update_work_memory(content=validated)
|
||||
|
||||
if parsed_args.memory_type == MemoryType.USER:
|
||||
base_model = (
|
||||
UserMemoryContent.model_validate(existing.content)
|
||||
if existing is not None
|
||||
else UserMemoryContent()
|
||||
)
|
||||
patch_model = cast(UserMemoryContent, parsed_args.user_content)
|
||||
merged = _deep_merge_dict(
|
||||
base_model.model_dump(),
|
||||
patch_model.model_dump(exclude_unset=True),
|
||||
)
|
||||
validated = UserMemoryContent.model_validate(merged)
|
||||
await service.update_user_memory(
|
||||
content=validated,
|
||||
)
|
||||
else:
|
||||
base_model = (
|
||||
WorkProfileContent.model_validate(existing.content)
|
||||
if existing is not None
|
||||
else WorkProfileContent()
|
||||
)
|
||||
patch_model = cast(WorkProfileContent, parsed_args.work_content)
|
||||
merged = _deep_merge_dict(
|
||||
base_model.model_dump(),
|
||||
patch_model.model_dump(exclude_unset=True),
|
||||
)
|
||||
validated = WorkProfileContent.model_validate(merged)
|
||||
await service.update_work_memory(
|
||||
content=validated,
|
||||
)
|
||||
success_count += 1
|
||||
updated_types.append(op.memory_type.value)
|
||||
memory_id = str(
|
||||
getattr(updated, "id", None)
|
||||
or (getattr(existing, "id", None) if existing is not None else "")
|
||||
or ""
|
||||
)
|
||||
result_items.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"memoryType": op.memory_type.value,
|
||||
"status": "success",
|
||||
"memoryId": memory_id,
|
||||
}
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
failed_count += 1
|
||||
code, message, retryable = map_memory_exception(exc)
|
||||
failed_operations.append(
|
||||
{
|
||||
"memory_type": op.memory_type.value,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"retryable": retryable,
|
||||
}
|
||||
)
|
||||
result_items.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"memoryType": op.memory_type.value,
|
||||
"status": "failure",
|
||||
"code": code,
|
||||
}
|
||||
)
|
||||
|
||||
summary = f"status=success memory_type={parsed_args.memory_type.value}"
|
||||
status = (
|
||||
ToolStatus.SUCCESS
|
||||
if failed_count == 0
|
||||
else (ToolStatus.FAILURE if success_count == 0 else ToolStatus.PARTIAL)
|
||||
)
|
||||
status_text = (
|
||||
"success"
|
||||
if status == ToolStatus.SUCCESS
|
||||
else ("failure" if status == ToolStatus.FAILURE else "partial")
|
||||
)
|
||||
|
||||
summary = (
|
||||
f"status={status_text} "
|
||||
f"success={success_count} failed={failed_count} "
|
||||
f"updated_types=[{','.join(updated_types)}]"
|
||||
)
|
||||
compact_items = _compact_result_items(result_items)
|
||||
if compact_items:
|
||||
summary = f"{summary} items=[{compact_items}]"
|
||||
error_info: ErrorInfo | None = None
|
||||
if failed_operations:
|
||||
first = failed_operations[0]
|
||||
error_info = ErrorInfo(
|
||||
code=str(first.get("code") or "MEMORY_BATCH_FAILED"),
|
||||
message=str(first.get("message") or "memory batch write failed"),
|
||||
retryable=bool(first.get("retryable") is True),
|
||||
details={"failed_operations": failed_operations},
|
||||
)
|
||||
return build_tool_response(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
status=status,
|
||||
result=summary,
|
||||
error=error_info,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@@ -241,22 +335,38 @@ async def memory_write(
|
||||
|
||||
|
||||
async def memory_forget(
|
||||
memory_type: Annotated[
|
||||
str,
|
||||
Field(description="Memory type: user or work."),
|
||||
] = "user",
|
||||
forget_paths: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Dot paths to remove from content."),
|
||||
] = None,
|
||||
operations: Annotated[
|
||||
list[MemoryForgetArgs],
|
||||
Field(
|
||||
description=(
|
||||
"Batch memory forget operations. Each item must include memory_type and "
|
||||
"forget_paths."
|
||||
),
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
),
|
||||
],
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
"""Forget selected paths from user/work memory content.
|
||||
|
||||
Args:
|
||||
memory_type: Target memory domain, either ``user`` or ``work``.
|
||||
forget_paths: Dot-path list to remove from memory content.
|
||||
|
||||
Notes:
|
||||
- Path root must belong to the target memory schema.
|
||||
- The tool is idempotent; missing paths are skipped safely.
|
||||
|
||||
Runtime:
|
||||
``session`` and ``owner_id`` are injected by toolkit preset kwargs.
|
||||
|
||||
Returns:
|
||||
ToolResponse wrapping ToolAgentOutput with compact execution summary.
|
||||
"""
|
||||
tool_name = "memory_forget"
|
||||
tool_call_args: dict[str, Any] = {
|
||||
"memory_type": memory_type,
|
||||
"forget_paths": forget_paths or [],
|
||||
}
|
||||
tool_call_args: dict[str, Any] = {"operations": operations}
|
||||
runtime_error = _validate_runtime_context(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
@@ -267,56 +377,120 @@ async def memory_forget(
|
||||
return runtime_error
|
||||
|
||||
try:
|
||||
parsed_args = MemoryForgetArgs.model_validate(tool_call_args)
|
||||
parsed_batch = MemoryForgetBatchArgs.model_validate(tool_call_args)
|
||||
service = create_memories_service(
|
||||
session=cast(AsyncSession, session),
|
||||
owner_id=cast(UUID, owner_id),
|
||||
)
|
||||
existing = await service.get_memory_model(memory_type=parsed_args.memory_type)
|
||||
if existing is None:
|
||||
summary = f"status=success memory_type={parsed_args.memory_type.value} forgotten=0"
|
||||
return build_tool_response(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
)
|
||||
)
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
forgotten_total = 0
|
||||
processed_types: list[str] = []
|
||||
failed_operations: list[dict[str, object]] = []
|
||||
result_items: list[dict[str, object]] = []
|
||||
for idx, op in enumerate(parsed_batch.operations):
|
||||
try:
|
||||
existing = await service.get_memory_model(memory_type=op.memory_type)
|
||||
if existing is None:
|
||||
success_count += 1
|
||||
processed_types.append(op.memory_type.value)
|
||||
result_items.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"memoryType": op.memory_type.value,
|
||||
"status": "success",
|
||||
"forgotten": 0,
|
||||
"memoryId": "",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if parsed_args.memory_type == MemoryType.USER:
|
||||
base_model = UserMemoryContent.model_validate(existing.content)
|
||||
updated_dict, removed_paths = _remove_content_paths(
|
||||
base_model.model_dump(),
|
||||
parsed_args.forget_paths,
|
||||
)
|
||||
validated = UserMemoryContent.model_validate(updated_dict)
|
||||
await service.update_user_memory(
|
||||
content=validated,
|
||||
)
|
||||
else:
|
||||
base_model = WorkProfileContent.model_validate(existing.content)
|
||||
updated_dict, removed_paths = _remove_content_paths(
|
||||
base_model.model_dump(),
|
||||
parsed_args.forget_paths,
|
||||
)
|
||||
validated = WorkProfileContent.model_validate(updated_dict)
|
||||
await service.update_work_memory(
|
||||
content=validated,
|
||||
)
|
||||
if op.memory_type == MemoryType.USER:
|
||||
base_model = UserMemoryContent.model_validate(existing.content)
|
||||
updated_dict, removed_paths = _remove_content_paths(
|
||||
base_model.model_dump(),
|
||||
op.forget_paths,
|
||||
)
|
||||
validated = UserMemoryContent.model_validate(updated_dict)
|
||||
await service.update_user_memory(content=validated)
|
||||
else:
|
||||
base_model = WorkProfileContent.model_validate(existing.content)
|
||||
updated_dict, removed_paths = _remove_content_paths(
|
||||
base_model.model_dump(),
|
||||
op.forget_paths,
|
||||
)
|
||||
validated = WorkProfileContent.model_validate(updated_dict)
|
||||
await service.update_work_memory(content=validated)
|
||||
|
||||
forgotten_total += len(removed_paths)
|
||||
success_count += 1
|
||||
processed_types.append(op.memory_type.value)
|
||||
result_items.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"memoryType": op.memory_type.value,
|
||||
"status": "success",
|
||||
"forgotten": len(removed_paths),
|
||||
"memoryId": str(getattr(existing, "id", "") or ""),
|
||||
}
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
failed_count += 1
|
||||
code, message, retryable = map_memory_exception(exc)
|
||||
failed_operations.append(
|
||||
{
|
||||
"memory_type": op.memory_type.value,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"retryable": retryable,
|
||||
}
|
||||
)
|
||||
result_items.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"memoryType": op.memory_type.value,
|
||||
"status": "failure",
|
||||
"code": code,
|
||||
}
|
||||
)
|
||||
|
||||
status = (
|
||||
ToolStatus.SUCCESS
|
||||
if failed_count == 0
|
||||
else (ToolStatus.FAILURE if success_count == 0 else ToolStatus.PARTIAL)
|
||||
)
|
||||
status_text = (
|
||||
"success"
|
||||
if status == ToolStatus.SUCCESS
|
||||
else ("failure" if status == ToolStatus.FAILURE else "partial")
|
||||
)
|
||||
|
||||
summary = (
|
||||
f"status=success memory_type={parsed_args.memory_type.value} forgotten={len(removed_paths)} "
|
||||
f"skipped=0"
|
||||
f"status={status_text} "
|
||||
f"success={success_count} failed={failed_count} "
|
||||
f"forgotten={forgotten_total} "
|
||||
f"processed_types=[{','.join(processed_types)}]"
|
||||
)
|
||||
compact_items = _compact_result_items(result_items)
|
||||
if compact_items:
|
||||
summary = f"{summary} items=[{compact_items}]"
|
||||
error_info: ErrorInfo | None = None
|
||||
if failed_operations:
|
||||
first = failed_operations[0]
|
||||
error_info = ErrorInfo(
|
||||
code=str(first.get("code") or "MEMORY_BATCH_FAILED"),
|
||||
message=str(first.get("message") or "memory batch forget failed"),
|
||||
retryable=bool(first.get("retryable") is True),
|
||||
details={"failed_operations": failed_operations},
|
||||
)
|
||||
return build_tool_response(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
status=status,
|
||||
result=summary,
|
||||
error=error_info,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
|
||||
@@ -83,8 +83,10 @@ async def _dispatch_automation_run(
|
||||
"content": input_text,
|
||||
}
|
||||
],
|
||||
"tools": [],
|
||||
"context": [],
|
||||
"forwardedProps": {
|
||||
"runtimeMode": RuntimeMode.AUTOMATION.value,
|
||||
"runtime_mode": RuntimeMode.AUTOMATION.value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
input_template: 请基于最近两天用户聊天上下文提取用户记忆;如果已有记忆内容变化请更新;如果记忆已失效请执行遗忘。
|
||||
input_template: |
|
||||
你正在执行自动化记忆提取任务。必须只使用 memory_forget 与 memory_write,不要执行任何 calendar 或 user_lookup 工具。
|
||||
步骤1:基于最近两天聊天上下文,抽取“有证据支持”的用户长期偏好变化,禁止编造。
|
||||
步骤2:对已失效或被用户明确否定的信息,调用 memory_forget 执行遗忘。
|
||||
步骤3:对新增或变化的信息,调用 memory_write 执行写入。
|
||||
步骤4:两类工具都必须使用批量参数 operations(对象数组),并保证参数是结构化 JSON,不要把数组或对象写成字符串。
|
||||
步骤5:只写入被证据覆盖的最小字段集;无证据字段不要写。
|
||||
输出要求:仅基于工具结果给出一句执行摘要(包含 success/failed 计数)。
|
||||
enabled_tools:
|
||||
- memory.write
|
||||
- memory.forget
|
||||
|
||||
@@ -61,3 +61,14 @@ llms:
|
||||
input_cost_per_token: 0.000002
|
||||
output_cost_per_token: 0.000003
|
||||
cache_hit_cost_per_token: 0.0000002
|
||||
|
||||
- model_code: qwen3.5-27b
|
||||
factory_name: dashscope
|
||||
litellm_model: dashscope/qwen3.5-27b
|
||||
pricing_tiers:
|
||||
- max_prompt_tokens: 128000
|
||||
input_cost_per_token: 0.0000006
|
||||
output_cost_per_token: 0.0000048
|
||||
- max_prompt_tokens: 256000
|
||||
input_cost_per_token: 0.0000018
|
||||
output_cost_per_token: 0.0000144
|
||||
|
||||
@@ -32,10 +32,6 @@ class Memory(TimestampMixin, Base):
|
||||
UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
agent_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
)
|
||||
memory_type: Mapped[MemoryType] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
|
||||
@@ -59,6 +59,10 @@ class AutomationJob(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def is_system(self) -> bool:
|
||||
return self.bootstrap_key is not None
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj: OrmAutomationJob) -> "AutomationJob":
|
||||
return cls(
|
||||
|
||||
@@ -34,7 +34,6 @@ class MemoryModel(BaseModel):
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
agent_id: UUID | None = None
|
||||
memory_type: Literal["user", "work"]
|
||||
content: UserMemoryContent | WorkProfileContent
|
||||
status: MemoryStatus
|
||||
|
||||
@@ -16,6 +16,7 @@ from core.agentscope.schemas.agui_input import (
|
||||
)
|
||||
from core.auth.models import CurrentUser
|
||||
from core.logging import get_logger
|
||||
from redis.exceptions import TimeoutError as RedisTimeoutError
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
@@ -180,7 +181,7 @@ async def stream_events(
|
||||
last_event_id=cursor,
|
||||
current_user=current_user,
|
||||
)
|
||||
except TimeoutError:
|
||||
except (TimeoutError, RedisTimeoutError):
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, time
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from models.automation_jobs import AutomationJob as OrmAutomationJob
|
||||
from models.automation_jobs import AutomationJobStatus, ScheduleType
|
||||
from schemas.automation import (
|
||||
AutomationJobConfig,
|
||||
)
|
||||
|
||||
|
||||
class AutomationJobResponse(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID
|
||||
bootstrap_key: str | None = None
|
||||
title: str
|
||||
schedule_type: ScheduleType
|
||||
run_at: time
|
||||
timezone: str
|
||||
status: AutomationJobStatus
|
||||
is_system: bool
|
||||
config: AutomationJobConfig
|
||||
next_run_at: datetime
|
||||
last_run_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj: OrmAutomationJob) -> Self:
|
||||
return cls(
|
||||
id=obj.id,
|
||||
owner_id=obj.owner_id,
|
||||
bootstrap_key=obj.bootstrap_key,
|
||||
title=obj.title,
|
||||
schedule_type=obj.schedule_type,
|
||||
run_at=obj.run_at.time(),
|
||||
timezone=obj.timezone,
|
||||
status=obj.status,
|
||||
is_system=obj.bootstrap_key is not None,
|
||||
config=AutomationJobConfig.model_validate(obj.config or {}),
|
||||
next_run_at=obj.next_run_at,
|
||||
last_run_at=obj.last_run_at,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class AutomationJobCreateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
title: str = Field(..., min_length=1, max_length=255)
|
||||
schedule_type: ScheduleType
|
||||
run_at: time = Field(..., description="Local time in HH:MM:SS format")
|
||||
timezone: str = Field(..., min_length=1, max_length=50)
|
||||
status: AutomationJobStatus = Field(default=AutomationJobStatus.ACTIVE)
|
||||
config: AutomationJobConfig
|
||||
|
||||
|
||||
class AutomationJobUpdateRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
title: str | None = Field(None, min_length=1, max_length=255)
|
||||
schedule_type: ScheduleType | None = None
|
||||
run_at: time | None = None
|
||||
timezone: str | None = Field(None, min_length=1, max_length=50)
|
||||
status: AutomationJobStatus | None = None
|
||||
config: AutomationJobConfig | None = None
|
||||
|
||||
|
||||
class AutomationJobListResponse(BaseModel):
|
||||
items: list[AutomationJobResponse]
|
||||
Reference in New Issue
Block a user