feat: 增强日历功能并集成 AgentScope 代理服务

This commit is contained in:
qzl
2026-03-11 17:16:11 +08:00
parent e20e7d2a02
commit 85b314cf64
53 changed files with 3642 additions and 297 deletions
@@ -83,7 +83,23 @@ def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
color = tool_args.get("color")
raw_color = color.strip() if isinstance(color, str) and color.strip() else "#4F46E5"
color_value = raw_color if _HEX_COLOR_PATTERN.match(raw_color) else "#4F46E5"
return ScheduleItemMetadata(location=location_value, color=color_value)
reminder_raw = tool_args.get("reminderMinutes")
reminder_value: int | None = None
if isinstance(reminder_raw, bool):
reminder_value = None
elif isinstance(reminder_raw, (int, float, str)):
try:
parsed = int(str(reminder_raw).strip())
if parsed < 0 or parsed > 10080:
raise ValueError("reminderMinutes must be 0..10080")
reminder_value = parsed
except ValueError as exc:
raise ValueError("reminderMinutes must be an integer in 0..10080") from exc
return ScheduleItemMetadata(
location=location_value,
color=color_value,
reminder_minutes=reminder_value,
)
def _event_payload(event: object) -> dict[str, object]:
@@ -91,6 +107,7 @@ def _event_payload(event: object) -> dict[str, object]:
metadata = getattr(event, "metadata", None)
location_value = getattr(metadata, "location", None)
color_value = getattr(metadata, "color", None) or "#4F46E5"
reminder_minutes_value = getattr(metadata, "reminder_minutes", None)
return {
"id": event_id,
"title": getattr(event, "title"),
@@ -104,6 +121,7 @@ def _event_payload(event: object) -> dict[str, object]:
"timezone": getattr(event, "timezone"),
"location": location_value,
"color": color_value,
"reminderMinutes": reminder_minutes_value,
}
@@ -221,7 +239,8 @@ async def _execute_update(
) from exc
has_location = isinstance(tool_args.get("location"), str)
has_color = isinstance(tool_args.get("color"), str)
if has_location or has_color:
has_reminder = "reminderMinutes" in tool_args
if has_location or has_color or has_reminder:
existing = await service.get_by_id(event_id)
metadata_dump = (
existing.metadata.model_dump() if existing.metadata is not None else {}
@@ -236,6 +255,22 @@ async def _execute_update(
metadata_dump["color"] = color
else:
raise ValueError("color must be a hex string like #RRGGBB")
if has_reminder:
reminder_raw = tool_args.get("reminderMinutes")
if reminder_raw is None:
metadata_dump["reminder_minutes"] = None
elif isinstance(reminder_raw, bool):
raise ValueError("reminderMinutes must be an integer in 0..10080")
else:
try:
reminder = int(str(reminder_raw).strip())
except ValueError as exc:
raise ValueError(
"reminderMinutes must be an integer in 0..10080"
) from exc
if reminder < 0 or reminder > 10080:
raise ValueError("reminderMinutes must be 0..10080")
metadata_dump["reminder_minutes"] = reminder
update_data["metadata"] = ScheduleItemMetadata.model_validate(metadata_dump)
updated = await service.update(
@@ -0,0 +1,14 @@
from core.agentscope.events.agui_codec import AgentScopeAgUiCodec, to_agui_wire_event
from core.agentscope.events.pipeline import AgentScopeEventPipeline
from core.agentscope.events.redis_bus import RedisStreamBus
from core.agentscope.events.sse import to_sse_event
from core.agentscope.events.store import NullEventStore
__all__ = [
"AgentScopeAgUiCodec",
"AgentScopeEventPipeline",
"RedisStreamBus",
"NullEventStore",
"to_agui_wire_event",
"to_sse_event",
]
@@ -0,0 +1,49 @@
from __future__ import annotations
from typing import Any, cast
_TYPE_MAP: dict[str, str] = {
"run.started": "RUN_STARTED",
"run.finished": "RUN_FINISHED",
"run.error": "RUN_ERROR",
"step.start": "STEP_STARTED",
"step.finish": "STEP_FINISHED",
"text.start": "TEXT_MESSAGE_START",
"text.delta": "TEXT_MESSAGE_CONTENT",
"text.end": "TEXT_MESSAGE_END",
"tool.start": "TOOL_CALL_START",
"tool.args": "TOOL_CALL_ARGS",
"tool.end": "TOOL_CALL_END",
"tool.result": "TOOL_CALL_RESULT",
"tool.error": "TOOL_CALL_ERROR",
"state.snapshot": "STATE_SNAPSHOT",
"messages.snapshot": "MESSAGES_SNAPSHOT",
}
def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
event_type = str(event.get("type", "")).strip()
wire_type = _TYPE_MAP.get(event_type, event_type.upper().replace(".", "_"))
payload: dict[str, Any] = {
"type": wire_type,
}
thread_id = event.get("threadId")
run_id = event.get("runId")
if isinstance(thread_id, str) and thread_id:
payload["threadId"] = thread_id
if isinstance(run_id, str) and run_id:
payload["runId"] = run_id
data = event.get("data")
if isinstance(data, dict):
reserved = {"type", "threadId", "runId"}
data_map = cast(dict[str, Any], data)
payload.update({k: v for k, v in data_map.items() if k not in reserved})
return payload
class AgentScopeAgUiCodec:
def to_wire(self, event: dict[str, Any]) -> dict[str, Any]:
return to_agui_wire_event(event)
@@ -0,0 +1,31 @@
from __future__ import annotations
from typing import Any, Protocol
class CodecLike(Protocol):
def to_wire(self, event: dict[str, Any]) -> dict[str, Any]: ...
class StoreLike(Protocol):
async def persist(self, event: dict[str, Any]) -> None: ...
class BusLike(Protocol):
async def publish(self, *, session_id: str, event: dict[str, Any]) -> str: ...
class AgentScopeEventPipeline:
_codec: CodecLike
_store: StoreLike
_bus: BusLike
def __init__(self, *, codec: CodecLike, store: StoreLike, bus: BusLike) -> None:
self._codec = codec
self._store = store
self._bus = bus
async def emit(self, *, session_id: str, event: dict[str, Any]) -> str:
wire_event = self._codec.to_wire(event)
await self._store.persist(wire_event)
return await self._bus.publish(session_id=session_id, event=wire_event)
@@ -0,0 +1,91 @@
from __future__ import annotations
import inspect
import json
from typing import Any, Protocol, cast
class RedisStreamClient(Protocol):
def xadd(self, *args: Any, **kwargs: Any) -> Any: ...
def xread(self, *args: Any, **kwargs: Any) -> Any: ...
class RedisStreamBus:
_client: RedisStreamClient
_stream_prefix: str
_read_count: int
_block_ms: int
def __init__(
self,
*,
client: RedisStreamClient,
stream_prefix: str,
read_count: int = 100,
block_ms: int = 5000,
) -> None:
self._client = client
self._stream_prefix = stream_prefix
self._read_count = read_count
self._block_ms = block_ms
async def publish(self, *, session_id: str, event: dict[str, Any]) -> str:
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
result = self._client.xadd(self._stream_name(session_id), {"event": payload})
if inspect.isawaitable(result):
return str(await result)
return str(result)
async def read(
self,
*,
session_id: str,
last_event_id: str | None,
) -> list[dict[str, Any]]:
stream = self._stream_name(session_id)
start_id = "0-0" if last_event_id is None else last_event_id
raw = self._client.xread(
{stream: start_id},
count=self._read_count,
block=self._block_ms,
)
response = await raw if inspect.isawaitable(raw) else raw
if not response:
return []
first = response[0]
if (
not isinstance(first, tuple)
or len(first) != 2
or not isinstance(first[1], list)
):
return []
entries = cast(list[tuple[str, dict[str, Any]]], first[1])
rows: list[dict[str, Any]] = []
for entry in entries:
if (
not isinstance(entry, tuple)
or len(entry) != 2
or not isinstance(entry[0], str)
or not isinstance(entry[1], dict)
):
continue
payload_map = cast(dict[str, Any], entry[1])
event_payload = payload_map.get("event")
if isinstance(event_payload, bytes):
event_payload = event_payload.decode("utf-8", errors="replace")
if not isinstance(event_payload, str):
continue
try:
decoded = json.loads(event_payload)
except (TypeError, ValueError):
continue
if not isinstance(decoded, dict):
continue
rows.append({"id": entry[0], "event": decoded})
return rows
def _stream_name(self, session_id: str) -> str:
return f"{self._stream_prefix}:{session_id}"
+29
View File
@@ -0,0 +1,29 @@
from __future__ import annotations
import json
import re
from typing import Any
from ag_ui.core.events import BaseEvent
from ag_ui.encoder.encoder import EventEncoder
_EVENT_TYPE_RE = re.compile(r"^[A-Z0-9_]+$")
_ENCODER = EventEncoder()
def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
safe_stream_id = str(stream_id).replace("\r", "").replace("\n", "")
try:
event_model = BaseEvent.model_validate(event)
event_type = event_model.type.value
encoded_data = _ENCODER.encode(event_model)
return f"id: {safe_stream_id}\nevent: {event_type}\n{encoded_data}"
except Exception: # noqa: BLE001
raw_event_type = (
str(event.get("type", "MESSAGE")).replace("\r", "").replace("\n", "")
)
event_type = (
raw_event_type if _EVENT_TYPE_RE.fullmatch(raw_event_type) else "MESSAGE"
)
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
return f"id: {safe_stream_id}\nevent: {event_type}\ndata: {payload}\n\n"
@@ -0,0 +1,12 @@
from __future__ import annotations
from typing import Any, Protocol
class EventStore(Protocol):
async def persist(self, event: dict[str, Any]) -> None: ...
class NullEventStore:
async def persist(self, event: dict[str, Any]) -> None:
del event
@@ -1,4 +1,9 @@
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
from core.agentscope.runtime.react_runner import AgentScopeReActRunner
__all__ = ["AgentScopeRuntimeOrchestrator", "AgentScopeReActRunner"]
__all__ = [
"AgentRouteRuntime",
"AgentScopeRuntimeOrchestrator",
"AgentScopeReActRunner",
]
@@ -0,0 +1,215 @@
from __future__ import annotations
from typing import Any, Protocol
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.domain.user_context import UserAgentContext
from core.logging import get_logger
from core.agentscope.schemas import RuntimeOutput
from core.agentscope.schemas.agent_runtime import ResumeCommand, RunCommand
class OrchestratorLike(Protocol):
async def run(
self,
*,
session: AsyncSession,
owner_id: UUID,
user_token: str,
user_context: UserAgentContext,
user_input: str | list[dict[str, Any]],
) -> RuntimeOutput: ...
class PipelineLike(Protocol):
async def emit(self, *, session_id: str, event: dict[str, Any]) -> str: ...
class AgentRouteRuntime:
_orchestrator: OrchestratorLike
_pipeline: PipelineLike
_logger = get_logger("core.agentscope.runtime.agent_route_runtime")
def __init__(
self, *, orchestrator: OrchestratorLike, pipeline: PipelineLike
) -> None:
self._orchestrator = orchestrator
self._pipeline = pipeline
async def run(
self,
*,
command: RunCommand,
owner_id: UUID,
user_token: str,
user_context: UserAgentContext,
session: AsyncSession,
) -> RuntimeOutput:
return await self._execute(
command=command,
owner_id=owner_id,
user_token=user_token,
user_context=user_context,
session=session,
)
async def resume(
self,
*,
command: ResumeCommand,
owner_id: UUID,
user_token: str,
user_context: UserAgentContext,
session: AsyncSession,
) -> RuntimeOutput:
return await self._execute(
command=command,
owner_id=owner_id,
user_token=user_token,
user_context=user_context,
session=session,
)
async def _execute(
self,
*,
command: RunCommand,
owner_id: UUID,
user_token: str,
user_context: UserAgentContext,
session: AsyncSession,
) -> RuntimeOutput:
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "run.started",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.start",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "intent"},
},
)
try:
result = await self._orchestrator.run(
session=session,
owner_id=owner_id,
user_token=user_token,
user_context=user_context,
user_input=command.messages,
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "intent"},
},
)
except Exception: # noqa: BLE001
self._logger.exception(
"agentscope runtime execution failed",
thread_id=command.thread_id,
run_id=command.run_id,
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "run.error",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"message": "runtime execution failed"},
},
)
raise
if result.execution is not None:
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.start",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "execution"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.start",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "report"},
},
)
report_message_id = f"assistant-{command.run_id}"
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.start",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"messageId": report_message_id, "role": "assistant"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.delta",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {
"messageId": report_message_id,
"delta": result.report.assistant_text,
},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "text.end",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"messageId": report_message_id},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "step.finish",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {"stepName": "report"},
},
)
await self._pipeline.emit(
session_id=command.thread_id,
event={
"type": "run.finished",
"threadId": command.thread_id,
"runId": command.run_id,
"data": {},
},
)
return result
@@ -0,0 +1,138 @@
from __future__ import annotations
from typing import Any
from uuid import UUID
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
from core.agentscope.events import (
AgentScopeAgUiCodec,
AgentScopeEventPipeline,
NullEventStore,
RedisStreamBus,
)
from core.agentscope.runtime import AgentRouteRuntime, AgentScopeRuntimeOrchestrator
from core.agentscope.schemas.agent_runtime import ResumeCommand, RunCommand
from core.config.settings import config
from core.db.session import AsyncSessionLocal
from core.logging import get_logger
from core.taskiq.app import bulk_broker, critical_broker, default_broker
from services.base.redis import get_or_init_redis_client
logger = get_logger("core.agentscope.runtime.tasks")
def _build_user_context(*, owner_id: UUID, run_input: RunCommand) -> UserAgentContext:
forwarded = (
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
)
username = str(forwarded.get("username", "user")).strip() or "user"
bio_value = forwarded.get("bio")
bio = str(bio_value).strip() if isinstance(bio_value, str) else None
profile_settings = forwarded.get("profileSettings")
settings_raw = profile_settings if isinstance(profile_settings, dict) else None
return UserAgentContext(
user_id=owner_id,
username=username,
bio=bio,
settings=parse_profile_settings(settings_raw),
)
def _extract_user_token(
*, command: dict[str, Any], run_input: RunCommand
) -> str | None:
raw_token = command.get("user_token")
if isinstance(raw_token, str) and raw_token.strip():
return raw_token.strip()
forwarded = (
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
)
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
command_type = str(command.get("command", "run")).strip().lower()
raw_run_input = command.get("run_input")
raw_owner_id = command.get("owner_id")
if not isinstance(raw_run_input, dict):
raise ValueError("run_input is required")
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
raise ValueError("owner_id is required")
owner_id = UUID(raw_owner_id)
parsed_run_input = (
ResumeCommand.model_validate(raw_run_input)
if command_type == "resume"
else RunCommand.model_validate(raw_run_input)
)
user_context = _build_user_context(owner_id=owner_id, run_input=parsed_run_input)
user_token = _extract_user_token(command=command, run_input=parsed_run_input) or ""
redis_client = await get_or_init_redis_client()
bus = RedisStreamBus(
client=redis_client,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
pipeline = AgentScopeEventPipeline(
codec=AgentScopeAgUiCodec(),
store=NullEventStore(),
bus=bus,
)
runtime = AgentRouteRuntime(
orchestrator=AgentScopeRuntimeOrchestrator(),
pipeline=pipeline,
)
async with AsyncSessionLocal() as session:
if command_type == "resume":
await runtime.resume(
command=ResumeCommand.model_validate(raw_run_input),
owner_id=owner_id,
user_token=user_token,
user_context=user_context,
session=session,
)
elif command_type == "run":
await runtime.run(
command=RunCommand.model_validate(raw_run_input),
owner_id=owner_id,
user_token=user_token,
user_context=user_context,
session=session,
)
else:
raise ValueError("invalid command type")
logger.info(
"agentscope runtime task completed",
command_type=command_type,
thread_id=parsed_run_input.thread_id,
run_id=parsed_run_input.run_id,
)
return {
"thread_id": parsed_run_input.thread_id,
"run_id": parsed_run_input.run_id,
"status": "completed",
}
@default_broker.task(task_name="tasks.agentscope.run_command")
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
return await run_agentscope_task(command)
@critical_broker.task(task_name="tasks.agentscope.run_command.critical")
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
return await run_agentscope_task(command)
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
return await run_agentscope_task(command)
@@ -1,13 +1,31 @@
from core.agentscope.schemas.agent_runtime import (
AcceptedTaskResponse,
AgUiWireEvent,
HistorySnapshotResponse,
InternalRuntimeEvent,
ResumeCommand,
RunCommand,
TaskAccepted,
TaskAcceptedResponse,
)
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
from core.agentscope.schemas.intent import IntentOutput, IntentTask
from core.agentscope.schemas.report import ReportOutput
from core.agentscope.schemas.runtime import RuntimeOutput
__all__ = [
"AgUiWireEvent",
"AcceptedTaskResponse",
"ExecutionBatchOutput",
"ExecutionTaskOutput",
"HistorySnapshotResponse",
"IntentOutput",
"IntentTask",
"InternalRuntimeEvent",
"ReportOutput",
"ResumeCommand",
"RuntimeOutput",
"RunCommand",
"TaskAccepted",
"TaskAcceptedResponse",
]
@@ -0,0 +1,68 @@
from __future__ import annotations
from typing import Any, ClassVar, Literal
from pydantic import BaseModel, ConfigDict, Field
class _AliasModel(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(
populate_by_name=True, serialize_by_alias=True, extra="forbid"
)
class AcceptedTaskResponse(_AliasModel):
task_id: str = Field(alias="taskId", min_length=1)
thread_id: str = Field(alias="threadId", min_length=1)
run_id: str = Field(alias="runId", min_length=1)
created: bool
class RunCommand(_AliasModel):
thread_id: str = Field(alias="threadId", min_length=1)
run_id: str = Field(alias="runId", min_length=1)
state: dict[str, Any] | None = None
messages: list[dict[str, Any]] = Field(default_factory=list)
tools: list[dict[str, Any]] = Field(default_factory=list)
context: dict[str, Any] = Field(default_factory=dict)
forwarded_props: dict[str, Any] = Field(
default_factory=dict, alias="forwardedProps"
)
class ResumeCommand(RunCommand):
pass
# Backward compatibility alias during migration.
TaskAcceptedResponse = AcceptedTaskResponse
TaskAccepted = AcceptedTaskResponse
class InternalRuntimeEvent(_AliasModel):
type: str = Field(min_length=1)
thread_id: str | None = Field(default=None, alias="threadId")
run_id: str | None = Field(default=None, alias="runId")
data: dict[str, Any] = Field(default_factory=dict)
class AgUiWireEvent(_AliasModel):
type: str = Field(min_length=1)
thread_id: str | None = Field(default=None, alias="threadId")
run_id: str | None = Field(default=None, alias="runId")
payload: Any = None
class HistorySnapshot(_AliasModel):
scope: Literal["history_day"] = "history_day"
thread_id: str | None = Field(default=None, alias="threadId")
day: str | None = None
has_more: bool = Field(default=False, alias="hasMore")
messages: list[dict[str, Any]] = Field(default_factory=list)
class HistorySnapshotResponse(_AliasModel):
type: Literal["STATE_SNAPSHOT"] = "STATE_SNAPSHOT"
thread_id: str | None = Field(default=None, alias="threadId")
run_id: str | None = Field(default=None, alias="runId")
snapshot: HistorySnapshot
@@ -134,6 +134,14 @@ async def calendar_write(
str | None,
Field(description="Event color value, for example #4F46E5."),
] = None,
reminder_minutes: Annotated[
int | None,
Field(
description="Minutes before start time to trigger reminder (0-10080).",
ge=0,
le=10080,
),
] = None,
status: Annotated[
Literal["active", "completed", "canceled", "archived"] | None,
Field(description="Event status: active, completed, canceled, or archived."),
@@ -158,6 +166,7 @@ async def calendar_write(
timezone: Event timezone.
location: Event location.
color: Event color.
reminder_minutes: Reminder minutes before event start.
status: Event lifecycle status.
replace: Replace-strategy flag for conflict handling.
session: Runtime-injected database session.
@@ -193,6 +202,12 @@ async def calendar_write(
return build_tool_response(
_invalid_argument_response(message="timezone length must be <= 50")
)
if reminder_minutes is not None and (
reminder_minutes < 0 or reminder_minutes > 10080
):
return build_tool_response(
_invalid_argument_response(message="reminder_minutes must be 0..10080")
)
if session is None or owner_id is None:
raise ValueError("calendar.write missing runtime preset arguments")
@@ -221,6 +236,8 @@ async def calendar_write(
tool_args["location"] = location
if color is not None:
tool_args["color"] = color
if reminder_minutes is not None:
tool_args["reminderMinutes"] = reminder_minutes
if status is not None:
tool_args["status"] = status
+13 -17
View File
@@ -2,21 +2,20 @@ from __future__ import annotations
import asyncio
from typing import Any
from uuid import UUID
from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from core.agent.infrastructure.queue.tasks import (
from core.agentscope.events import RedisStreamBus
from core.agentscope.runtime.tasks import (
run_command_task,
run_command_task_bulk,
run_command_task_critical,
)
from core.agent.infrastructure.storage.tool_result_storage import (
create_tool_result_storage,
)
from core.config.settings import config
from core.db import get_db
from services.base.redis import get_or_init_redis_client
@@ -84,18 +83,18 @@ class TaskiqQueueClient:
class RedisEventStream:
def __init__(self) -> None:
self._store: RedisStreamEventStore | None = None
self._bus: RedisStreamBus | None = None
async def _get_store(self) -> RedisStreamEventStore:
if self._store is None:
async def _get_bus(self) -> RedisStreamBus:
if self._bus is None:
client = await get_or_init_redis_client()
self._store = RedisStreamEventStore(
self._bus = RedisStreamBus(
client=client,
stream_prefix=config.agent_runtime.redis_stream_prefix,
read_count=config.agent_runtime.redis_stream_read_count,
block_ms=config.agent_runtime.redis_stream_block_ms,
)
return self._store
return self._bus
async def read(
self,
@@ -103,12 +102,9 @@ class RedisEventStream:
session_id: str,
last_event_id: str | None,
) -> list[dict[str, Any]]:
store = await self._get_store()
rows = await store.read_events(
session_id=UUID(session_id),
last_event_id=last_event_id,
)
return [{**row, "cursor": last_event_id} for row in rows]
bus = await self._get_bus()
rows = await bus.read(session_id=session_id, last_event_id=last_event_id)
return [{**row, "cursor": row.get("id")} for row in rows]
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
+1 -1
View File
@@ -14,7 +14,7 @@ from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFi
from fastapi import HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from core.agent.infrastructure.agui.stream import to_sse_event
from core.agentscope.events import to_sse_event
from core.agent.domain.agui_input import (
parse_run_input,
validate_run_request_messages_contract,
+19
View File
@@ -18,6 +18,17 @@ from core.logging import get_logger
logger = get_logger(__name__)
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
forwarded = run_input.forwarded_props
if not isinstance(forwarded, dict):
return None
for key in ("accessToken", "userToken", "token"):
value = forwarded.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return None
@dataclass(frozen=True)
class TaskAccepted:
task_id: str
@@ -65,6 +76,10 @@ def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
class AgentService:
_repository: AgentRepositoryLike
_queue: QueueClientLike
_stream: EventStreamLike
def __init__(
self,
*,
@@ -107,6 +122,8 @@ class AgentService:
task_id = await self._queue.enqueue(
command={
"command": "run",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=None,
@@ -132,6 +149,8 @@ class AgentService:
task_id = await self._queue.enqueue(
command={
"command": "resume",
"owner_id": str(current_user.id),
"user_token": _extract_user_token_from_run_input(run_input),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
+1
View File
@@ -32,6 +32,7 @@ class ScheduleItemMetadata(BaseModel):
location: str | None = None
notes: str | None = None
attachments: list[ScheduleItemMetadataAttachment] = Field(default_factory=list)
reminder_minutes: int | None = Field(default=None, ge=0, le=10080)
version: Literal[1] = 1
+2 -3
View File
@@ -135,14 +135,13 @@ class ScheduleItemService(BaseService):
update_data = request.model_dump(exclude_unset=True)
# Handle metadata separately (model_dump returns dict)
if "metadata" in update_data and update_data["metadata"] is not None:
metadata_value = update_data["metadata"]
if "metadata" in update_data:
metadata_value = update_data.pop("metadata")
update_data["extra_metadata"] = (
metadata_value.model_dump()
if hasattr(metadata_value, "model_dump")
else metadata_value
)
del update_data["metadata"]
# Validate time range
next_start = update_data.get("start_at", existing.start_at)
@@ -68,6 +68,8 @@ async def _invoke_tool(
else:
text = getattr(first, "text", None)
assert isinstance(text, str)
if text.startswith("Error:"):
raise AssertionError(f"tool {tool_name} failed: {text}")
payload = json.loads(text)
assert isinstance(payload, dict)
return payload
@@ -101,40 +103,45 @@ class _SmokeRunner:
if stage_config.stage == "execution":
assert toolkit is not None
created = await _invoke_tool(
toolkit,
tool_name="calendar.write",
tool_input={
"operation": "create",
"title": "agentscope smoke event",
"description": "agentscope runtime smoke",
"start_at": datetime.now(timezone.utc).isoformat(),
"timezone": "Asia/Shanghai",
},
)
created_data = created.get("data")
assert isinstance(created_data, dict)
created_id = created_data.get("id")
assert isinstance(created_id, str) and created_id
created_id: str | None = None
items: list[object] = []
try:
created = await _invoke_tool(
toolkit,
tool_name="calendar.write",
tool_input={
"operation": "create",
"title": "agentscope smoke event",
"description": "agentscope runtime smoke",
"start_at": datetime.now(timezone.utc).isoformat(),
"timezone": "Asia/Shanghai",
},
)
created_data = created.get("data")
assert isinstance(created_data, dict)
created_id = created_data.get("id")
assert isinstance(created_id, str) and created_id
read_payload = await _invoke_tool(
toolkit,
tool_name="calendar.read",
tool_input={"page": 1, "page_size": 10},
)
read_data = read_payload.get("data")
assert isinstance(read_data, dict)
items = read_data.get("items")
assert isinstance(items, list)
deleted = await _invoke_tool(
toolkit,
tool_name="calendar.write",
tool_input={"operation": "delete", "event_id": created_id},
)
deleted_data = deleted.get("data")
assert isinstance(deleted_data, dict)
assert deleted_data.get("ok") is True
read_payload = await _invoke_tool(
toolkit,
tool_name="calendar.read",
tool_input={"page": 1, "page_size": 10},
)
read_data = read_payload.get("data")
assert isinstance(read_data, dict)
parsed_items = read_data.get("items")
assert isinstance(parsed_items, list)
items = parsed_items
finally:
if created_id:
deleted = await _invoke_tool(
toolkit,
tool_name="calendar.write",
tool_input={"operation": "delete", "event_id": created_id},
)
deleted_data = deleted.get("data")
assert isinstance(deleted_data, dict)
assert deleted_data.get("ok") is True
return {
"task_id": "smoke-task-1",
@@ -25,6 +25,8 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
async def create_agent_generated(self, payload):
assert payload.title == "晨会"
assert payload.metadata is not None
assert payload.metadata.reminder_minutes == 15
return SimpleNamespace(
id=created_id,
title="晨会",
@@ -32,7 +34,11 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
start_at=datetime(2026, 3, 8, 1, 0, tzinfo=timezone.utc),
end_at=datetime(2026, 3, 8, 2, 0, tzinfo=timezone.utc),
timezone="Asia/Shanghai",
metadata=SimpleNamespace(location="会议室A", color="#4F46E5"),
metadata=SimpleNamespace(
location="会议室A",
color="#4F46E5",
reminder_minutes=15,
),
)
class _FakeRepository:
@@ -61,6 +67,7 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
"endAt": "2026-03-08T10:00:00+08:00",
"timezone": "Asia/Shanghai",
"location": "会议室A",
"reminderMinutes": 15,
},
),
)
@@ -69,6 +76,77 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
data = cast(dict[str, object], result["data"])
assert data["id"] == str(created_id)
assert data["ok"] is True
assert data["reminderMinutes"] == 15
@pytest.mark.asyncio
async def test_mutate_calendar_event_update_maps_reminder_minutes(
monkeypatch: pytest.MonkeyPatch,
) -> None:
event_id = uuid4()
class _FakeService:
def __init__(self, **kwargs) -> None:
del kwargs
async def get_by_id(self, item_id):
assert item_id == event_id
return SimpleNamespace(
metadata=SimpleNamespace(
model_dump=lambda: {
"color": "#4F46E5",
"location": "会议室A",
"version": 1,
}
)
)
async def update(self, item_id, payload):
assert item_id == event_id
assert payload.metadata is not None
assert payload.metadata.reminder_minutes == 30
return SimpleNamespace(
id=event_id,
title="更新后",
description=None,
start_at=datetime(2026, 3, 8, 1, 0, tzinfo=timezone.utc),
end_at=None,
timezone="Asia/Shanghai",
metadata=SimpleNamespace(
location="会议室A",
color="#4F46E5",
reminder_minutes=30,
),
)
class _FakeRepository:
def __init__(self, session) -> None:
del session
monkeypatch.setattr(
"core.agent.infrastructure.crewai.tools.create_calendar_event_tool.ScheduleItemService",
_FakeService,
)
monkeypatch.setattr(
"core.agent.infrastructure.crewai.tools.create_calendar_event_tool.SQLAlchemyScheduleItemRepository",
_FakeRepository,
)
result = cast(
dict[str, object],
await _execute_mutate_calendar_event(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
tool_args={
"operation": "update",
"eventId": str(event_id),
"reminderMinutes": 30,
},
),
)
data = cast(dict[str, object], result["data"])
assert data["reminderMinutes"] == 30
@pytest.mark.asyncio
@@ -0,0 +1,42 @@
from __future__ import annotations
from core.agentscope.events.agui_codec import to_agui_wire_event
def test_maps_internal_text_delta_to_agui_wire_event() -> None:
internal = {
"id": "e1",
"type": "text.delta",
"threadId": "t1",
"runId": "r1",
"data": {"delta": "hel"},
}
result = to_agui_wire_event(internal)
assert result["type"] == "TEXT_MESSAGE_CONTENT"
assert result["threadId"] == "t1"
assert result["runId"] == "r1"
assert result["delta"] == "hel"
def test_reserved_keys_in_data_cannot_override_wire_fields() -> None:
internal = {
"id": "e2",
"type": "run.started",
"threadId": "thread-1",
"runId": "run-1",
"data": {
"type": "RUN_ERROR",
"threadId": "thread-override",
"runId": "run-override",
"message": "ok",
},
}
result = to_agui_wire_event(internal)
assert result["type"] == "RUN_STARTED"
assert result["threadId"] == "thread-1"
assert result["runId"] == "run-1"
assert result["message"] == "ok"
@@ -0,0 +1,32 @@
from __future__ import annotations
import pytest
from core.agentscope.events.pipeline import AgentScopeEventPipeline
@pytest.mark.asyncio
async def test_pipeline_orders_codec_persist_publish() -> None:
calls: list[str] = []
class _Codec:
def to_wire(self, event: dict[str, object]) -> dict[str, object]:
calls.append("codec")
return {"type": "RUN_STARTED", **event}
class _Store:
async def persist(self, event: dict[str, object]) -> None:
calls.append("persist")
assert event["type"] == "RUN_STARTED"
class _Bus:
async def publish(self, *, session_id: str, event: dict[str, object]) -> str:
calls.append("publish")
assert session_id == "thread-1"
return "1-0"
pipeline = AgentScopeEventPipeline(codec=_Codec(), store=_Store(), bus=_Bus())
cursor = await pipeline.emit(session_id="thread-1", event={"id": "evt-1"})
assert cursor == "1-0"
assert calls == ["codec", "persist", "publish"]
@@ -0,0 +1,71 @@
from __future__ import annotations
from core.agentscope.events.redis_bus import RedisStreamBus
class _FakeRedis:
def __init__(self) -> None:
self._rows: list[tuple[str, str]] = []
def xadd(self, _stream: str, fields: dict[str, str]) -> str:
cursor = f"{len(self._rows) + 1}-0"
self._rows.append((cursor, fields["event"]))
return cursor
def xread(
self,
streams: dict[str, str],
count: int,
block: int,
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
del count, block
stream_name, last = next(iter(streams.items()))
rows: list[tuple[str, dict[str, str]]] = []
for cursor, payload in self._rows:
if cursor > last:
rows.append((cursor, {"event": payload}))
return [(stream_name, rows)]
class _FakeRedisBytes:
def __init__(self) -> None:
self._rows: list[tuple[str, str]] = []
def xadd(self, _stream: str, fields: dict[str, str]) -> str:
cursor = f"{len(self._rows) + 1}-0"
self._rows.append((cursor, fields["event"]))
return cursor
def xread(
self,
streams: dict[str, str],
count: int,
block: int,
) -> list[tuple[str, list[tuple[str, dict[str, bytes]]]]]:
del count, block
stream_name, last = next(iter(streams.items()))
rows: list[tuple[str, dict[str, bytes]]] = []
for cursor, payload in self._rows:
if cursor > last:
rows.append((cursor, {"event": payload.encode("utf-8")}))
return [(stream_name, rows)]
async def test_publish_then_read_after_cursor() -> None:
bus = RedisStreamBus(client=_FakeRedis(), stream_prefix="agent.events")
first_cursor = await bus.publish(
session_id="thread-1", event={"type": "RUN_STARTED"}
)
await bus.publish(session_id="thread-1", event={"type": "RUN_FINISHED"})
rows = await bus.read(session_id="thread-1", last_event_id=first_cursor)
assert len(rows) == 1
assert rows[0]["event"]["type"] == "RUN_FINISHED"
async def test_read_supports_bytes_payload() -> None:
bus = RedisStreamBus(client=_FakeRedisBytes(), stream_prefix="agent.events")
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
rows = await bus.read(session_id="thread-1", last_event_id=None)
assert rows[0]["event"]["type"] == "RUN_STARTED"
@@ -0,0 +1,25 @@
from __future__ import annotations
import json
from core.agentscope.events.sse import to_sse_event
def test_sse_frame_contains_event_and_json_payload() -> None:
payload = {"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"}
frame = to_sse_event("1-0", payload)
assert frame.startswith("id: 1-0\n")
assert "event: RUN_STARTED\n" in frame
assert frame.endswith("\n\n")
data_line = [line for line in frame.splitlines() if line.startswith("data: ")][0]
parsed = json.loads(data_line[len("data: ") :])
assert parsed["threadId"] == "t1"
def test_sse_sanitizes_stream_id_newlines() -> None:
payload = {"type": "RUN_STARTED"}
frame = to_sse_event("1-0\nmalicious: yes", payload)
assert frame.startswith("id: 1-0malicious: yes\n")
@@ -0,0 +1,139 @@
from __future__ import annotations
from typing import Any, cast
from uuid import uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
from core.agentscope.schemas import ReportOutput, RuntimeOutput
from core.agentscope.schemas.agent_runtime import RunCommand
from core.agentscope.schemas.execution import ExecutionBatchOutput
from core.agentscope.schemas.intent import IntentOutput
def _user_context() -> UserAgentContext:
return UserAgentContext(
user_id=uuid4(),
username="tester",
bio=None,
settings=parse_profile_settings(
{
"version": 1,
"preferences": {
"interface_language": "zh-CN",
"ai_language": "zh-CN",
"timezone": "Asia/Shanghai",
"country": "CN",
},
}
),
)
@pytest.mark.asyncio
async def test_runtime_emits_started_text_and_finished_events() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _FakeOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
return RuntimeOutput(
intent=IntentOutput(
route="DIRECT_RESPONSE",
intent_summary="summary",
direct_response="done",
tasks=[],
complexity="simple",
),
execution=ExecutionBatchOutput(
task_results=[],
overall_status="SUCCESS",
aggregate_summary="ok",
),
report=ReportOutput(
assistant_text="hello world",
response_metadata={},
),
)
runtime = AgentRouteRuntime(
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
assert [item["type"] for item in calls] == [
"run.started",
"step.start",
"step.finish",
"step.start",
"step.finish",
"step.start",
"text.start",
"text.delta",
"text.end",
"step.finish",
"run.finished",
]
assert calls[1]["data"]["stepName"] == "intent"
assert calls[2]["data"]["stepName"] == "intent"
assert calls[3]["data"]["stepName"] == "execution"
assert calls[4]["data"]["stepName"] == "execution"
assert calls[5]["data"]["stepName"] == "report"
assert calls[7]["data"]["delta"] == "hello world"
assert calls[6]["data"]["messageId"] == calls[7]["data"]["messageId"]
assert calls[7]["data"]["messageId"] == calls[8]["data"]["messageId"]
assert calls[9]["data"]["stepName"] == "report"
@pytest.mark.asyncio
async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
calls: list[dict[str, Any]] = []
class _FakePipeline:
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
assert session_id == "thread-1"
calls.append(event)
return f"{len(calls)}-0"
class _FailOrchestrator:
async def run(self, **_: object) -> RuntimeOutput:
raise RuntimeError("boom")
runtime = AgentRouteRuntime(
orchestrator=_FailOrchestrator(),
pipeline=_FakePipeline(),
)
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
with pytest.raises(RuntimeError, match="boom"):
await runtime.run(
command=command,
owner_id=uuid4(),
user_token="token",
user_context=_user_context(),
session=cast(AsyncSession, object()),
)
assert [item["type"] for item in calls] == [
"run.started",
"step.start",
"run.error",
]
assert calls[1]["data"]["stepName"] == "intent"
assert calls[2]["data"]["message"] == "runtime execution failed"
@@ -0,0 +1,138 @@
from __future__ import annotations
from typing import Any
from uuid import uuid4
import pytest
import core.agentscope.runtime.tasks as tasks_module
def _run_input_payload() -> dict[str, Any]:
return {
"threadId": str(uuid4()),
"runId": "run-1",
"messages": [],
"tools": [],
"context": {},
"forwardedProps": {},
}
class _FakeSessionCtx:
async def __aenter__(self) -> object:
return object()
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
del exc_type, exc, tb
@pytest.mark.asyncio
async def test_run_agentscope_task_calls_runtime_run(
monkeypatch: pytest.MonkeyPatch,
) -> None:
called: dict[str, int] = {"run": 0, "resume": 0}
class _FakeRuntime:
def __init__(self, **kwargs: object) -> None:
del kwargs
async def run(self, **kwargs: object) -> object:
del kwargs
called["run"] += 1
return object()
async def resume(self, **kwargs: object) -> object:
del kwargs
called["resume"] += 1
return object()
async def _fake_get_redis_client() -> object:
return object()
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
"get_or_init_redis_client",
_fake_get_redis_client,
)
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
result = await tasks_module.run_agentscope_task(
{
"command": "run",
"owner_id": str(uuid4()),
"run_input": _run_input_payload(),
}
)
assert result["status"] == "completed"
assert called["run"] == 1
assert called["resume"] == 0
@pytest.mark.asyncio
async def test_run_agentscope_task_calls_runtime_resume(
monkeypatch: pytest.MonkeyPatch,
) -> None:
called: dict[str, int] = {"run": 0, "resume": 0}
class _FakeRuntime:
def __init__(self, **kwargs: object) -> None:
del kwargs
async def run(self, **kwargs: object) -> object:
del kwargs
called["run"] += 1
return object()
async def resume(self, **kwargs: object) -> object:
del kwargs
called["resume"] += 1
return object()
async def _fake_get_redis_client() -> object:
return object()
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
monkeypatch.setattr(
tasks_module,
"get_or_init_redis_client",
_fake_get_redis_client,
)
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
result = await tasks_module.run_agentscope_task(
{
"command": "resume",
"owner_id": str(uuid4()),
"run_input": _run_input_payload(),
}
)
assert result["status"] == "completed"
assert called["run"] == 0
assert called["resume"] == 1
@pytest.mark.asyncio
async def test_run_agentscope_task_requires_owner_id() -> None:
with pytest.raises(ValueError, match="owner_id is required"):
await tasks_module.run_agentscope_task(
{
"command": "run",
"run_input": _run_input_payload(),
}
)
@pytest.mark.asyncio
async def test_run_agentscope_task_rejects_invalid_command_type() -> None:
with pytest.raises(ValueError, match="invalid command type"):
await tasks_module.run_agentscope_task(
{
"command": "unknown",
"owner_id": str(uuid4()),
"run_input": _run_input_payload(),
}
)
@@ -0,0 +1,106 @@
from __future__ import annotations
import pytest
from pydantic import ValidationError
from core.agentscope import schemas as exported_schemas
from core.agentscope.schemas.agent_runtime import (
AcceptedTaskResponse,
AgUiWireEvent,
HistorySnapshot,
HistorySnapshotResponse,
InternalRuntimeEvent,
ResumeCommand,
RunCommand,
)
def test_run_command_alias_roundtrip() -> None:
payload = {
"threadId": "thread-001",
"runId": "run-001",
"state": {"cursor": 1},
"messages": [{"role": "user", "content": "hi"}],
"tools": [{"name": "calendar.lookup"}],
"context": {"locale": "zh-CN"},
"forwardedProps": {"traceId": "trace-1"},
}
command = RunCommand.model_validate(payload)
assert command.thread_id == "thread-001"
assert command.run_id == "run-001"
assert command.forwarded_props == {"traceId": "trace-1"}
dumped = command.model_dump(mode="json", by_alias=True)
assert dumped["threadId"] == "thread-001"
assert dumped["runId"] == "run-001"
assert dumped["forwardedProps"] == {"traceId": "trace-1"}
def test_history_snapshot_response_shape() -> None:
response = HistorySnapshotResponse(
threadId="thread-123",
snapshot=HistorySnapshot(
threadId="thread-123",
day="2026-03-11",
hasMore=False,
messages=[{"id": "msg-1"}],
),
)
dumped = response.model_dump(mode="json", by_alias=True, exclude_none=True)
assert dumped["type"] == "STATE_SNAPSHOT"
assert dumped["threadId"] == "thread-123"
assert dumped["snapshot"]["scope"] == "history_day"
assert dumped["snapshot"]["hasMore"] is False
assert dumped["snapshot"]["messages"] == [{"id": "msg-1"}]
def test_runtime_event_validation_basics() -> None:
internal = InternalRuntimeEvent(type="RUN_STARTED", data={"step": 1})
assert internal.type == "RUN_STARTED"
assert internal.model_dump(mode="json", by_alias=True)["data"] == {"step": 1}
wire = AgUiWireEvent(type="TEXT_MESSAGE_CONTENT", payload={"delta": "hello"})
dumped = wire.model_dump(mode="json", by_alias=True, exclude_none=True)
assert dumped["type"] == "TEXT_MESSAGE_CONTENT"
assert dumped["payload"] == {"delta": "hello"}
with pytest.raises(ValidationError):
InternalRuntimeEvent.model_validate({"threadId": "t-1", "data": {}})
with pytest.raises(ValidationError):
AgUiWireEvent.model_validate({"payload": {"delta": "hello"}})
def test_task_response_and_resume_aliases() -> None:
accepted = AcceptedTaskResponse(
taskId="task-1",
threadId="thread-1",
runId="run-1",
created=False,
)
dumped = accepted.model_dump(mode="json", by_alias=True)
assert dumped["taskId"] == "task-1"
assert dumped["threadId"] == "thread-1"
assert dumped["runId"] == "run-1"
resumed = ResumeCommand.model_validate(
{
"threadId": "thread-1",
"runId": "run-2",
"messages": [],
"tools": [],
"context": {},
}
)
assert resumed.thread_id == "thread-1"
assert resumed.run_id == "run-2"
def test_schemas_exports_include_task_and_history_models() -> None:
assert exported_schemas.AcceptedTaskResponse is AcceptedTaskResponse
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
@@ -131,3 +131,50 @@ async def test_calendar_write_rejects_event_id_for_create(
assert result["data"]["ok"] is False
assert result["data"]["code"] == "INVALID_ARGUMENT"
@pytest.mark.asyncio
async def test_calendar_write_maps_reminder_minutes(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: dict[str, object] = {}
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
captured.update(cast(dict[str, object], kwargs["tool_args"]))
return {"type": "calendar_card.v1", "version": "v1", "data": {"ok": True}}
monkeypatch.setattr(
calendar_module,
"_execute_mutate_calendar_event",
_fake_execute,
)
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
await calendar_module.calendar_write(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="token-abc",
operation="create",
reminder_minutes=15,
)
assert captured["reminderMinutes"] == 15
@pytest.mark.asyncio
async def test_calendar_write_rejects_invalid_reminder_minutes(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
result = await calendar_module.calendar_write(
session=cast(AsyncSession, SimpleNamespace()),
owner_id=uuid4(),
user_token="token-abc",
operation="create",
reminder_minutes=10081,
)
assert result["data"]["ok"] is False
assert result["data"]["code"] == "INVALID_ARGUMENT"
@@ -103,6 +103,18 @@ def test_metadata_rejects_unknown_field() -> None:
ScheduleItemMetadata.model_validate({"color": "#FF6B6B", "unknown": True})
@pytest.mark.parametrize("value", [None, 0, 15, 10080])
def test_metadata_accepts_reminder_minutes(value: int | None) -> None:
metadata = ScheduleItemMetadata(reminder_minutes=value)
assert metadata.reminder_minutes == value
@pytest.mark.parametrize("value", [-1, 10081])
def test_metadata_rejects_out_of_range_reminder_minutes(value: int) -> None:
with pytest.raises(ValidationError):
ScheduleItemMetadata(reminder_minutes=value)
def test_metadata_attachment_rejects_unknown_field() -> None:
with pytest.raises(ValidationError):
ScheduleItemMetadataAttachment.model_validate(
@@ -221,7 +221,12 @@ async def test_create_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
request = ScheduleItemCreateRequest(
title="Roadmap",
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
metadata=ScheduleItemMetadata(location="会议室A", color="#4F46E5", version=1),
metadata=ScheduleItemMetadata(
location="会议室A",
color="#4F46E5",
reminder_minutes=15,
version=1,
),
)
service = ScheduleItemService(
repository=CaptureRepo(None),
@@ -234,6 +239,7 @@ async def test_create_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
assert captured is not None
assert "extra_metadata" in captured
assert captured["extra_metadata"]["location"] == "会议室A"
assert captured["extra_metadata"]["reminder_minutes"] == 15
assert "metadata" not in captured
@@ -261,7 +267,10 @@ async def test_update_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
item.id,
ScheduleItemUpdateRequest(
metadata=ScheduleItemMetadata(
location="线上会议", color="#3B82F6", version=1
location="线上会议",
color="#3B82F6",
reminder_minutes=30,
version=1,
)
),
)
@@ -269,4 +278,38 @@ async def test_update_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
assert captured is not None
assert "extra_metadata" in captured
assert captured["extra_metadata"]["location"] == "线上会议"
assert captured["extra_metadata"]["reminder_minutes"] == 30
assert "metadata" not in captured
@pytest.mark.asyncio
async def test_update_maps_null_metadata_to_extra_metadata_null(
mock_session: AsyncMock,
) -> None:
user_id = UUID("00000000-0000-0000-0000-000000000001")
item = _create_mock_schedule_item()
captured: dict | None = None
class CaptureRepo(FakeRepo):
async def update_by_item_id(
self, item_id: UUID, owner_id: UUID, data: dict
) -> ScheduleItem | None:
nonlocal captured
captured = data
return await super().update_by_item_id(item_id, owner_id, data)
service = ScheduleItemService(
repository=CaptureRepo(item),
session=mock_session,
current_user=CurrentUser(id=user_id),
)
await service.update(
item.id,
ScheduleItemUpdateRequest(metadata=None),
)
assert captured is not None
assert "extra_metadata" in captured
assert captured["extra_metadata"] is None
assert "metadata" not in captured