feat: 增强日历功能并集成 AgentScope 代理服务
This commit is contained in:
@@ -83,7 +83,23 @@ def _resolve_metadata(tool_args: dict[str, object]) -> ScheduleItemMetadata:
|
||||
color = tool_args.get("color")
|
||||
raw_color = color.strip() if isinstance(color, str) and color.strip() else "#4F46E5"
|
||||
color_value = raw_color if _HEX_COLOR_PATTERN.match(raw_color) else "#4F46E5"
|
||||
return ScheduleItemMetadata(location=location_value, color=color_value)
|
||||
reminder_raw = tool_args.get("reminderMinutes")
|
||||
reminder_value: int | None = None
|
||||
if isinstance(reminder_raw, bool):
|
||||
reminder_value = None
|
||||
elif isinstance(reminder_raw, (int, float, str)):
|
||||
try:
|
||||
parsed = int(str(reminder_raw).strip())
|
||||
if parsed < 0 or parsed > 10080:
|
||||
raise ValueError("reminderMinutes must be 0..10080")
|
||||
reminder_value = parsed
|
||||
except ValueError as exc:
|
||||
raise ValueError("reminderMinutes must be an integer in 0..10080") from exc
|
||||
return ScheduleItemMetadata(
|
||||
location=location_value,
|
||||
color=color_value,
|
||||
reminder_minutes=reminder_value,
|
||||
)
|
||||
|
||||
|
||||
def _event_payload(event: object) -> dict[str, object]:
|
||||
@@ -91,6 +107,7 @@ def _event_payload(event: object) -> dict[str, object]:
|
||||
metadata = getattr(event, "metadata", None)
|
||||
location_value = getattr(metadata, "location", None)
|
||||
color_value = getattr(metadata, "color", None) or "#4F46E5"
|
||||
reminder_minutes_value = getattr(metadata, "reminder_minutes", None)
|
||||
return {
|
||||
"id": event_id,
|
||||
"title": getattr(event, "title"),
|
||||
@@ -104,6 +121,7 @@ def _event_payload(event: object) -> dict[str, object]:
|
||||
"timezone": getattr(event, "timezone"),
|
||||
"location": location_value,
|
||||
"color": color_value,
|
||||
"reminderMinutes": reminder_minutes_value,
|
||||
}
|
||||
|
||||
|
||||
@@ -221,7 +239,8 @@ async def _execute_update(
|
||||
) from exc
|
||||
has_location = isinstance(tool_args.get("location"), str)
|
||||
has_color = isinstance(tool_args.get("color"), str)
|
||||
if has_location or has_color:
|
||||
has_reminder = "reminderMinutes" in tool_args
|
||||
if has_location or has_color or has_reminder:
|
||||
existing = await service.get_by_id(event_id)
|
||||
metadata_dump = (
|
||||
existing.metadata.model_dump() if existing.metadata is not None else {}
|
||||
@@ -236,6 +255,22 @@ async def _execute_update(
|
||||
metadata_dump["color"] = color
|
||||
else:
|
||||
raise ValueError("color must be a hex string like #RRGGBB")
|
||||
if has_reminder:
|
||||
reminder_raw = tool_args.get("reminderMinutes")
|
||||
if reminder_raw is None:
|
||||
metadata_dump["reminder_minutes"] = None
|
||||
elif isinstance(reminder_raw, bool):
|
||||
raise ValueError("reminderMinutes must be an integer in 0..10080")
|
||||
else:
|
||||
try:
|
||||
reminder = int(str(reminder_raw).strip())
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"reminderMinutes must be an integer in 0..10080"
|
||||
) from exc
|
||||
if reminder < 0 or reminder > 10080:
|
||||
raise ValueError("reminderMinutes must be 0..10080")
|
||||
metadata_dump["reminder_minutes"] = reminder
|
||||
update_data["metadata"] = ScheduleItemMetadata.model_validate(metadata_dump)
|
||||
|
||||
updated = await service.update(
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from core.agentscope.events.agui_codec import AgentScopeAgUiCodec, to_agui_wire_event
|
||||
from core.agentscope.events.pipeline import AgentScopeEventPipeline
|
||||
from core.agentscope.events.redis_bus import RedisStreamBus
|
||||
from core.agentscope.events.sse import to_sse_event
|
||||
from core.agentscope.events.store import NullEventStore
|
||||
|
||||
__all__ = [
|
||||
"AgentScopeAgUiCodec",
|
||||
"AgentScopeEventPipeline",
|
||||
"RedisStreamBus",
|
||||
"NullEventStore",
|
||||
"to_agui_wire_event",
|
||||
"to_sse_event",
|
||||
]
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
_TYPE_MAP: dict[str, str] = {
|
||||
"run.started": "RUN_STARTED",
|
||||
"run.finished": "RUN_FINISHED",
|
||||
"run.error": "RUN_ERROR",
|
||||
"step.start": "STEP_STARTED",
|
||||
"step.finish": "STEP_FINISHED",
|
||||
"text.start": "TEXT_MESSAGE_START",
|
||||
"text.delta": "TEXT_MESSAGE_CONTENT",
|
||||
"text.end": "TEXT_MESSAGE_END",
|
||||
"tool.start": "TOOL_CALL_START",
|
||||
"tool.args": "TOOL_CALL_ARGS",
|
||||
"tool.end": "TOOL_CALL_END",
|
||||
"tool.result": "TOOL_CALL_RESULT",
|
||||
"tool.error": "TOOL_CALL_ERROR",
|
||||
"state.snapshot": "STATE_SNAPSHOT",
|
||||
"messages.snapshot": "MESSAGES_SNAPSHOT",
|
||||
}
|
||||
|
||||
|
||||
def to_agui_wire_event(event: dict[str, Any]) -> dict[str, Any]:
|
||||
event_type = str(event.get("type", "")).strip()
|
||||
wire_type = _TYPE_MAP.get(event_type, event_type.upper().replace(".", "_"))
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"type": wire_type,
|
||||
}
|
||||
thread_id = event.get("threadId")
|
||||
run_id = event.get("runId")
|
||||
if isinstance(thread_id, str) and thread_id:
|
||||
payload["threadId"] = thread_id
|
||||
if isinstance(run_id, str) and run_id:
|
||||
payload["runId"] = run_id
|
||||
|
||||
data = event.get("data")
|
||||
if isinstance(data, dict):
|
||||
reserved = {"type", "threadId", "runId"}
|
||||
data_map = cast(dict[str, Any], data)
|
||||
payload.update({k: v for k, v in data_map.items() if k not in reserved})
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class AgentScopeAgUiCodec:
|
||||
def to_wire(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
return to_agui_wire_event(event)
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class CodecLike(Protocol):
|
||||
def to_wire(self, event: dict[str, Any]) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
class StoreLike(Protocol):
|
||||
async def persist(self, event: dict[str, Any]) -> None: ...
|
||||
|
||||
|
||||
class BusLike(Protocol):
|
||||
async def publish(self, *, session_id: str, event: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
class AgentScopeEventPipeline:
|
||||
_codec: CodecLike
|
||||
_store: StoreLike
|
||||
_bus: BusLike
|
||||
|
||||
def __init__(self, *, codec: CodecLike, store: StoreLike, bus: BusLike) -> None:
|
||||
self._codec = codec
|
||||
self._store = store
|
||||
self._bus = bus
|
||||
|
||||
async def emit(self, *, session_id: str, event: dict[str, Any]) -> str:
|
||||
wire_event = self._codec.to_wire(event)
|
||||
await self._store.persist(wire_event)
|
||||
return await self._bus.publish(session_id=session_id, event=wire_event)
|
||||
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
|
||||
class RedisStreamClient(Protocol):
|
||||
def xadd(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def xread(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
class RedisStreamBus:
|
||||
_client: RedisStreamClient
|
||||
_stream_prefix: str
|
||||
_read_count: int
|
||||
_block_ms: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
client: RedisStreamClient,
|
||||
stream_prefix: str,
|
||||
read_count: int = 100,
|
||||
block_ms: int = 5000,
|
||||
) -> None:
|
||||
self._client = client
|
||||
self._stream_prefix = stream_prefix
|
||||
self._read_count = read_count
|
||||
self._block_ms = block_ms
|
||||
|
||||
async def publish(self, *, session_id: str, event: dict[str, Any]) -> str:
|
||||
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
|
||||
result = self._client.xadd(self._stream_name(session_id), {"event": payload})
|
||||
if inspect.isawaitable(result):
|
||||
return str(await result)
|
||||
return str(result)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
stream = self._stream_name(session_id)
|
||||
start_id = "0-0" if last_event_id is None else last_event_id
|
||||
raw = self._client.xread(
|
||||
{stream: start_id},
|
||||
count=self._read_count,
|
||||
block=self._block_ms,
|
||||
)
|
||||
response = await raw if inspect.isawaitable(raw) else raw
|
||||
if not response:
|
||||
return []
|
||||
|
||||
first = response[0]
|
||||
if (
|
||||
not isinstance(first, tuple)
|
||||
or len(first) != 2
|
||||
or not isinstance(first[1], list)
|
||||
):
|
||||
return []
|
||||
|
||||
entries = cast(list[tuple[str, dict[str, Any]]], first[1])
|
||||
rows: list[dict[str, Any]] = []
|
||||
for entry in entries:
|
||||
if (
|
||||
not isinstance(entry, tuple)
|
||||
or len(entry) != 2
|
||||
or not isinstance(entry[0], str)
|
||||
or not isinstance(entry[1], dict)
|
||||
):
|
||||
continue
|
||||
payload_map = cast(dict[str, Any], entry[1])
|
||||
event_payload = payload_map.get("event")
|
||||
if isinstance(event_payload, bytes):
|
||||
event_payload = event_payload.decode("utf-8", errors="replace")
|
||||
if not isinstance(event_payload, str):
|
||||
continue
|
||||
try:
|
||||
decoded = json.loads(event_payload)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if not isinstance(decoded, dict):
|
||||
continue
|
||||
rows.append({"id": entry[0], "event": decoded})
|
||||
return rows
|
||||
|
||||
def _stream_name(self, session_id: str) -> str:
|
||||
return f"{self._stream_prefix}:{session_id}"
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from ag_ui.core.events import BaseEvent
|
||||
from ag_ui.encoder.encoder import EventEncoder
|
||||
|
||||
_EVENT_TYPE_RE = re.compile(r"^[A-Z0-9_]+$")
|
||||
_ENCODER = EventEncoder()
|
||||
|
||||
|
||||
def to_sse_event(stream_id: str, event: dict[str, Any]) -> str:
|
||||
safe_stream_id = str(stream_id).replace("\r", "").replace("\n", "")
|
||||
try:
|
||||
event_model = BaseEvent.model_validate(event)
|
||||
event_type = event_model.type.value
|
||||
encoded_data = _ENCODER.encode(event_model)
|
||||
return f"id: {safe_stream_id}\nevent: {event_type}\n{encoded_data}"
|
||||
except Exception: # noqa: BLE001
|
||||
raw_event_type = (
|
||||
str(event.get("type", "MESSAGE")).replace("\r", "").replace("\n", "")
|
||||
)
|
||||
event_type = (
|
||||
raw_event_type if _EVENT_TYPE_RE.fullmatch(raw_event_type) else "MESSAGE"
|
||||
)
|
||||
payload = json.dumps(event, ensure_ascii=True, separators=(",", ":"))
|
||||
return f"id: {safe_stream_id}\nevent: {event_type}\ndata: {payload}\n\n"
|
||||
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class EventStore(Protocol):
|
||||
async def persist(self, event: dict[str, Any]) -> None: ...
|
||||
|
||||
|
||||
class NullEventStore:
|
||||
async def persist(self, event: dict[str, Any]) -> None:
|
||||
del event
|
||||
@@ -1,4 +1,9 @@
|
||||
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
|
||||
from core.agentscope.runtime.orchestrator import AgentScopeRuntimeOrchestrator
|
||||
from core.agentscope.runtime.react_runner import AgentScopeReActRunner
|
||||
|
||||
__all__ = ["AgentScopeRuntimeOrchestrator", "AgentScopeReActRunner"]
|
||||
__all__ = [
|
||||
"AgentRouteRuntime",
|
||||
"AgentScopeRuntimeOrchestrator",
|
||||
"AgentScopeReActRunner",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.domain.user_context import UserAgentContext
|
||||
from core.logging import get_logger
|
||||
from core.agentscope.schemas import RuntimeOutput
|
||||
from core.agentscope.schemas.agent_runtime import ResumeCommand, RunCommand
|
||||
|
||||
|
||||
class OrchestratorLike(Protocol):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
session: AsyncSession,
|
||||
owner_id: UUID,
|
||||
user_token: str,
|
||||
user_context: UserAgentContext,
|
||||
user_input: str | list[dict[str, Any]],
|
||||
) -> RuntimeOutput: ...
|
||||
|
||||
|
||||
class PipelineLike(Protocol):
|
||||
async def emit(self, *, session_id: str, event: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
class AgentRouteRuntime:
|
||||
_orchestrator: OrchestratorLike
|
||||
_pipeline: PipelineLike
|
||||
_logger = get_logger("core.agentscope.runtime.agent_route_runtime")
|
||||
|
||||
def __init__(
|
||||
self, *, orchestrator: OrchestratorLike, pipeline: PipelineLike
|
||||
) -> None:
|
||||
self._orchestrator = orchestrator
|
||||
self._pipeline = pipeline
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
command: RunCommand,
|
||||
owner_id: UUID,
|
||||
user_token: str,
|
||||
user_context: UserAgentContext,
|
||||
session: AsyncSession,
|
||||
) -> RuntimeOutput:
|
||||
return await self._execute(
|
||||
command=command,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
user_context=user_context,
|
||||
session=session,
|
||||
)
|
||||
|
||||
async def resume(
|
||||
self,
|
||||
*,
|
||||
command: ResumeCommand,
|
||||
owner_id: UUID,
|
||||
user_token: str,
|
||||
user_context: UserAgentContext,
|
||||
session: AsyncSession,
|
||||
) -> RuntimeOutput:
|
||||
return await self._execute(
|
||||
command=command,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
user_context=user_context,
|
||||
session=session,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
*,
|
||||
command: RunCommand,
|
||||
owner_id: UUID,
|
||||
user_token: str,
|
||||
user_context: UserAgentContext,
|
||||
session: AsyncSession,
|
||||
) -> RuntimeOutput:
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "run.started",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.start",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "intent"},
|
||||
},
|
||||
)
|
||||
try:
|
||||
result = await self._orchestrator.run(
|
||||
session=session,
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
user_context=user_context,
|
||||
user_input=command.messages,
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.finish",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "intent"},
|
||||
},
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
self._logger.exception(
|
||||
"agentscope runtime execution failed",
|
||||
thread_id=command.thread_id,
|
||||
run_id=command.run_id,
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "run.error",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"message": "runtime execution failed"},
|
||||
},
|
||||
)
|
||||
raise
|
||||
|
||||
if result.execution is not None:
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.start",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.finish",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "execution"},
|
||||
},
|
||||
)
|
||||
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.start",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "report"},
|
||||
},
|
||||
)
|
||||
|
||||
report_message_id = f"assistant-{command.run_id}"
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.start",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"messageId": report_message_id, "role": "assistant"},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.delta",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {
|
||||
"messageId": report_message_id,
|
||||
"delta": result.report.assistant_text,
|
||||
},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "text.end",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"messageId": report_message_id},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "step.finish",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {"stepName": "report"},
|
||||
},
|
||||
)
|
||||
await self._pipeline.emit(
|
||||
session_id=command.thread_id,
|
||||
event={
|
||||
"type": "run.finished",
|
||||
"threadId": command.thread_id,
|
||||
"runId": command.run_id,
|
||||
"data": {},
|
||||
},
|
||||
)
|
||||
return result
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from core.agentscope.events import (
|
||||
AgentScopeAgUiCodec,
|
||||
AgentScopeEventPipeline,
|
||||
NullEventStore,
|
||||
RedisStreamBus,
|
||||
)
|
||||
from core.agentscope.runtime import AgentRouteRuntime, AgentScopeRuntimeOrchestrator
|
||||
from core.agentscope.schemas.agent_runtime import ResumeCommand, RunCommand
|
||||
from core.config.settings import config
|
||||
from core.db.session import AsyncSessionLocal
|
||||
from core.logging import get_logger
|
||||
from core.taskiq.app import bulk_broker, critical_broker, default_broker
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("core.agentscope.runtime.tasks")
|
||||
|
||||
|
||||
def _build_user_context(*, owner_id: UUID, run_input: RunCommand) -> UserAgentContext:
|
||||
forwarded = (
|
||||
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
|
||||
)
|
||||
username = str(forwarded.get("username", "user")).strip() or "user"
|
||||
bio_value = forwarded.get("bio")
|
||||
bio = str(bio_value).strip() if isinstance(bio_value, str) else None
|
||||
profile_settings = forwarded.get("profileSettings")
|
||||
settings_raw = profile_settings if isinstance(profile_settings, dict) else None
|
||||
return UserAgentContext(
|
||||
user_id=owner_id,
|
||||
username=username,
|
||||
bio=bio,
|
||||
settings=parse_profile_settings(settings_raw),
|
||||
)
|
||||
|
||||
|
||||
def _extract_user_token(
|
||||
*, command: dict[str, Any], run_input: RunCommand
|
||||
) -> str | None:
|
||||
raw_token = command.get("user_token")
|
||||
if isinstance(raw_token, str) and raw_token.strip():
|
||||
return raw_token.strip()
|
||||
forwarded = (
|
||||
run_input.forwarded_props if isinstance(run_input.forwarded_props, dict) else {}
|
||||
)
|
||||
for key in ("accessToken", "userToken", "token"):
|
||||
value = forwarded.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
async def run_agentscope_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
command_type = str(command.get("command", "run")).strip().lower()
|
||||
raw_run_input = command.get("run_input")
|
||||
raw_owner_id = command.get("owner_id")
|
||||
|
||||
if not isinstance(raw_run_input, dict):
|
||||
raise ValueError("run_input is required")
|
||||
if not isinstance(raw_owner_id, str) or not raw_owner_id.strip():
|
||||
raise ValueError("owner_id is required")
|
||||
|
||||
owner_id = UUID(raw_owner_id)
|
||||
parsed_run_input = (
|
||||
ResumeCommand.model_validate(raw_run_input)
|
||||
if command_type == "resume"
|
||||
else RunCommand.model_validate(raw_run_input)
|
||||
)
|
||||
user_context = _build_user_context(owner_id=owner_id, run_input=parsed_run_input)
|
||||
user_token = _extract_user_token(command=command, run_input=parsed_run_input) or ""
|
||||
|
||||
redis_client = await get_or_init_redis_client()
|
||||
bus = RedisStreamBus(
|
||||
client=redis_client,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
pipeline = AgentScopeEventPipeline(
|
||||
codec=AgentScopeAgUiCodec(),
|
||||
store=NullEventStore(),
|
||||
bus=bus,
|
||||
)
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=AgentScopeRuntimeOrchestrator(),
|
||||
pipeline=pipeline,
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
if command_type == "resume":
|
||||
await runtime.resume(
|
||||
command=ResumeCommand.model_validate(raw_run_input),
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
user_context=user_context,
|
||||
session=session,
|
||||
)
|
||||
elif command_type == "run":
|
||||
await runtime.run(
|
||||
command=RunCommand.model_validate(raw_run_input),
|
||||
owner_id=owner_id,
|
||||
user_token=user_token,
|
||||
user_context=user_context,
|
||||
session=session,
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid command type")
|
||||
|
||||
logger.info(
|
||||
"agentscope runtime task completed",
|
||||
command_type=command_type,
|
||||
thread_id=parsed_run_input.thread_id,
|
||||
run_id=parsed_run_input.run_id,
|
||||
)
|
||||
return {
|
||||
"thread_id": parsed_run_input.thread_id,
|
||||
"run_id": parsed_run_input.run_id,
|
||||
"status": "completed",
|
||||
}
|
||||
|
||||
|
||||
@default_broker.task(task_name="tasks.agentscope.run_command")
|
||||
async def run_command_task(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
|
||||
@critical_broker.task(task_name="tasks.agentscope.run_command.critical")
|
||||
async def run_command_task_critical(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
|
||||
|
||||
@bulk_broker.task(task_name="tasks.agentscope.run_command.bulk")
|
||||
async def run_command_task_bulk(command: dict[str, Any]) -> dict[str, object]:
|
||||
return await run_agentscope_task(command)
|
||||
@@ -1,13 +1,31 @@
|
||||
from core.agentscope.schemas.agent_runtime import (
|
||||
AcceptedTaskResponse,
|
||||
AgUiWireEvent,
|
||||
HistorySnapshotResponse,
|
||||
InternalRuntimeEvent,
|
||||
ResumeCommand,
|
||||
RunCommand,
|
||||
TaskAccepted,
|
||||
TaskAcceptedResponse,
|
||||
)
|
||||
from core.agentscope.schemas.execution import ExecutionBatchOutput, ExecutionTaskOutput
|
||||
from core.agentscope.schemas.intent import IntentOutput, IntentTask
|
||||
from core.agentscope.schemas.report import ReportOutput
|
||||
from core.agentscope.schemas.runtime import RuntimeOutput
|
||||
|
||||
__all__ = [
|
||||
"AgUiWireEvent",
|
||||
"AcceptedTaskResponse",
|
||||
"ExecutionBatchOutput",
|
||||
"ExecutionTaskOutput",
|
||||
"HistorySnapshotResponse",
|
||||
"IntentOutput",
|
||||
"IntentTask",
|
||||
"InternalRuntimeEvent",
|
||||
"ReportOutput",
|
||||
"ResumeCommand",
|
||||
"RuntimeOutput",
|
||||
"RunCommand",
|
||||
"TaskAccepted",
|
||||
"TaskAcceptedResponse",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class _AliasModel(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
populate_by_name=True, serialize_by_alias=True, extra="forbid"
|
||||
)
|
||||
|
||||
|
||||
class AcceptedTaskResponse(_AliasModel):
|
||||
task_id: str = Field(alias="taskId", min_length=1)
|
||||
thread_id: str = Field(alias="threadId", min_length=1)
|
||||
run_id: str = Field(alias="runId", min_length=1)
|
||||
created: bool
|
||||
|
||||
|
||||
class RunCommand(_AliasModel):
|
||||
thread_id: str = Field(alias="threadId", min_length=1)
|
||||
run_id: str = Field(alias="runId", min_length=1)
|
||||
state: dict[str, Any] | None = None
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
tools: list[dict[str, Any]] = Field(default_factory=list)
|
||||
context: dict[str, Any] = Field(default_factory=dict)
|
||||
forwarded_props: dict[str, Any] = Field(
|
||||
default_factory=dict, alias="forwardedProps"
|
||||
)
|
||||
|
||||
|
||||
class ResumeCommand(RunCommand):
|
||||
pass
|
||||
|
||||
|
||||
# Backward compatibility alias during migration.
|
||||
TaskAcceptedResponse = AcceptedTaskResponse
|
||||
TaskAccepted = AcceptedTaskResponse
|
||||
|
||||
|
||||
class InternalRuntimeEvent(_AliasModel):
|
||||
type: str = Field(min_length=1)
|
||||
thread_id: str | None = Field(default=None, alias="threadId")
|
||||
run_id: str | None = Field(default=None, alias="runId")
|
||||
data: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgUiWireEvent(_AliasModel):
|
||||
type: str = Field(min_length=1)
|
||||
thread_id: str | None = Field(default=None, alias="threadId")
|
||||
run_id: str | None = Field(default=None, alias="runId")
|
||||
payload: Any = None
|
||||
|
||||
|
||||
class HistorySnapshot(_AliasModel):
|
||||
scope: Literal["history_day"] = "history_day"
|
||||
thread_id: str | None = Field(default=None, alias="threadId")
|
||||
day: str | None = None
|
||||
has_more: bool = Field(default=False, alias="hasMore")
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HistorySnapshotResponse(_AliasModel):
|
||||
type: Literal["STATE_SNAPSHOT"] = "STATE_SNAPSHOT"
|
||||
thread_id: str | None = Field(default=None, alias="threadId")
|
||||
run_id: str | None = Field(default=None, alias="runId")
|
||||
snapshot: HistorySnapshot
|
||||
@@ -134,6 +134,14 @@ async def calendar_write(
|
||||
str | None,
|
||||
Field(description="Event color value, for example #4F46E5."),
|
||||
] = None,
|
||||
reminder_minutes: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Minutes before start time to trigger reminder (0-10080).",
|
||||
ge=0,
|
||||
le=10080,
|
||||
),
|
||||
] = None,
|
||||
status: Annotated[
|
||||
Literal["active", "completed", "canceled", "archived"] | None,
|
||||
Field(description="Event status: active, completed, canceled, or archived."),
|
||||
@@ -158,6 +166,7 @@ async def calendar_write(
|
||||
timezone: Event timezone.
|
||||
location: Event location.
|
||||
color: Event color.
|
||||
reminder_minutes: Reminder minutes before event start.
|
||||
status: Event lifecycle status.
|
||||
replace: Replace-strategy flag for conflict handling.
|
||||
session: Runtime-injected database session.
|
||||
@@ -193,6 +202,12 @@ async def calendar_write(
|
||||
return build_tool_response(
|
||||
_invalid_argument_response(message="timezone length must be <= 50")
|
||||
)
|
||||
if reminder_minutes is not None and (
|
||||
reminder_minutes < 0 or reminder_minutes > 10080
|
||||
):
|
||||
return build_tool_response(
|
||||
_invalid_argument_response(message="reminder_minutes must be 0..10080")
|
||||
)
|
||||
|
||||
if session is None or owner_id is None:
|
||||
raise ValueError("calendar.write missing runtime preset arguments")
|
||||
@@ -221,6 +236,8 @@ async def calendar_write(
|
||||
tool_args["location"] = location
|
||||
if color is not None:
|
||||
tool_args["color"] = color
|
||||
if reminder_minutes is not None:
|
||||
tool_args["reminderMinutes"] = reminder_minutes
|
||||
if status is not None:
|
||||
tool_args["status"] = status
|
||||
|
||||
|
||||
@@ -2,21 +2,20 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.infrastructure.events.redis_stream import RedisStreamEventStore
|
||||
from core.agent.infrastructure.storage.tool_result_storage import (
|
||||
create_tool_result_storage,
|
||||
)
|
||||
from core.agent.infrastructure.queue.tasks import (
|
||||
from core.agentscope.events import RedisStreamBus
|
||||
from core.agentscope.runtime.tasks import (
|
||||
run_command_task,
|
||||
run_command_task_bulk,
|
||||
run_command_task_critical,
|
||||
)
|
||||
from core.agent.infrastructure.storage.tool_result_storage import (
|
||||
create_tool_result_storage,
|
||||
)
|
||||
from core.config.settings import config
|
||||
from core.db import get_db
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
@@ -84,18 +83,18 @@ class TaskiqQueueClient:
|
||||
|
||||
class RedisEventStream:
|
||||
def __init__(self) -> None:
|
||||
self._store: RedisStreamEventStore | None = None
|
||||
self._bus: RedisStreamBus | None = None
|
||||
|
||||
async def _get_store(self) -> RedisStreamEventStore:
|
||||
if self._store is None:
|
||||
async def _get_bus(self) -> RedisStreamBus:
|
||||
if self._bus is None:
|
||||
client = await get_or_init_redis_client()
|
||||
self._store = RedisStreamEventStore(
|
||||
self._bus = RedisStreamBus(
|
||||
client=client,
|
||||
stream_prefix=config.agent_runtime.redis_stream_prefix,
|
||||
read_count=config.agent_runtime.redis_stream_read_count,
|
||||
block_ms=config.agent_runtime.redis_stream_block_ms,
|
||||
)
|
||||
return self._store
|
||||
return self._bus
|
||||
|
||||
async def read(
|
||||
self,
|
||||
@@ -103,12 +102,9 @@ class RedisEventStream:
|
||||
session_id: str,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
store = await self._get_store()
|
||||
rows = await store.read_events(
|
||||
session_id=UUID(session_id),
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
return [{**row, "cursor": last_event_id} for row in rows]
|
||||
bus = await self._get_bus()
|
||||
rows = await bus.read(session_id=session_id, last_event_id=last_event_id)
|
||||
return [{**row, "cursor": row.get("id")} for row in rows]
|
||||
|
||||
|
||||
def get_agent_service(session: AsyncSession = Depends(get_db)) -> AgentService:
|
||||
|
||||
@@ -14,7 +14,7 @@ from fastapi import APIRouter, Depends, Header, Query, Request, status, UploadFi
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from core.agent.infrastructure.agui.stream import to_sse_event
|
||||
from core.agentscope.events import to_sse_event
|
||||
from core.agent.domain.agui_input import (
|
||||
parse_run_input,
|
||||
validate_run_request_messages_contract,
|
||||
|
||||
@@ -18,6 +18,17 @@ from core.logging import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_user_token_from_run_input(run_input: RunAgentInput) -> str | None:
|
||||
forwarded = run_input.forwarded_props
|
||||
if not isinstance(forwarded, dict):
|
||||
return None
|
||||
for key in ("accessToken", "userToken", "token"):
|
||||
value = forwarded.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskAccepted:
|
||||
task_id: str
|
||||
@@ -65,6 +76,10 @@ def ensure_session_owner(*, owner_id: str, current_user: CurrentUser) -> None:
|
||||
|
||||
|
||||
class AgentService:
|
||||
_repository: AgentRepositoryLike
|
||||
_queue: QueueClientLike
|
||||
_stream: EventStreamLike
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -107,6 +122,8 @@ class AgentService:
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "run",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _extract_user_token_from_run_input(run_input),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=None,
|
||||
@@ -132,6 +149,8 @@ class AgentService:
|
||||
task_id = await self._queue.enqueue(
|
||||
command={
|
||||
"command": "resume",
|
||||
"owner_id": str(current_user.id),
|
||||
"user_token": _extract_user_token_from_run_input(run_input),
|
||||
"run_input": run_input.model_dump(mode="json", by_alias=True),
|
||||
},
|
||||
dedup_key=dedup_key,
|
||||
|
||||
@@ -32,6 +32,7 @@ class ScheduleItemMetadata(BaseModel):
|
||||
location: str | None = None
|
||||
notes: str | None = None
|
||||
attachments: list[ScheduleItemMetadataAttachment] = Field(default_factory=list)
|
||||
reminder_minutes: int | None = Field(default=None, ge=0, le=10080)
|
||||
version: Literal[1] = 1
|
||||
|
||||
|
||||
|
||||
@@ -135,14 +135,13 @@ class ScheduleItemService(BaseService):
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle metadata separately (model_dump returns dict)
|
||||
if "metadata" in update_data and update_data["metadata"] is not None:
|
||||
metadata_value = update_data["metadata"]
|
||||
if "metadata" in update_data:
|
||||
metadata_value = update_data.pop("metadata")
|
||||
update_data["extra_metadata"] = (
|
||||
metadata_value.model_dump()
|
||||
if hasattr(metadata_value, "model_dump")
|
||||
else metadata_value
|
||||
)
|
||||
del update_data["metadata"]
|
||||
|
||||
# Validate time range
|
||||
next_start = update_data.get("start_at", existing.start_at)
|
||||
|
||||
@@ -68,6 +68,8 @@ async def _invoke_tool(
|
||||
else:
|
||||
text = getattr(first, "text", None)
|
||||
assert isinstance(text, str)
|
||||
if text.startswith("Error:"):
|
||||
raise AssertionError(f"tool {tool_name} failed: {text}")
|
||||
payload = json.loads(text)
|
||||
assert isinstance(payload, dict)
|
||||
return payload
|
||||
@@ -101,40 +103,45 @@ class _SmokeRunner:
|
||||
|
||||
if stage_config.stage == "execution":
|
||||
assert toolkit is not None
|
||||
created = await _invoke_tool(
|
||||
toolkit,
|
||||
tool_name="calendar.write",
|
||||
tool_input={
|
||||
"operation": "create",
|
||||
"title": "agentscope smoke event",
|
||||
"description": "agentscope runtime smoke",
|
||||
"start_at": datetime.now(timezone.utc).isoformat(),
|
||||
"timezone": "Asia/Shanghai",
|
||||
},
|
||||
)
|
||||
created_data = created.get("data")
|
||||
assert isinstance(created_data, dict)
|
||||
created_id = created_data.get("id")
|
||||
assert isinstance(created_id, str) and created_id
|
||||
created_id: str | None = None
|
||||
items: list[object] = []
|
||||
try:
|
||||
created = await _invoke_tool(
|
||||
toolkit,
|
||||
tool_name="calendar.write",
|
||||
tool_input={
|
||||
"operation": "create",
|
||||
"title": "agentscope smoke event",
|
||||
"description": "agentscope runtime smoke",
|
||||
"start_at": datetime.now(timezone.utc).isoformat(),
|
||||
"timezone": "Asia/Shanghai",
|
||||
},
|
||||
)
|
||||
created_data = created.get("data")
|
||||
assert isinstance(created_data, dict)
|
||||
created_id = created_data.get("id")
|
||||
assert isinstance(created_id, str) and created_id
|
||||
|
||||
read_payload = await _invoke_tool(
|
||||
toolkit,
|
||||
tool_name="calendar.read",
|
||||
tool_input={"page": 1, "page_size": 10},
|
||||
)
|
||||
read_data = read_payload.get("data")
|
||||
assert isinstance(read_data, dict)
|
||||
items = read_data.get("items")
|
||||
assert isinstance(items, list)
|
||||
|
||||
deleted = await _invoke_tool(
|
||||
toolkit,
|
||||
tool_name="calendar.write",
|
||||
tool_input={"operation": "delete", "event_id": created_id},
|
||||
)
|
||||
deleted_data = deleted.get("data")
|
||||
assert isinstance(deleted_data, dict)
|
||||
assert deleted_data.get("ok") is True
|
||||
read_payload = await _invoke_tool(
|
||||
toolkit,
|
||||
tool_name="calendar.read",
|
||||
tool_input={"page": 1, "page_size": 10},
|
||||
)
|
||||
read_data = read_payload.get("data")
|
||||
assert isinstance(read_data, dict)
|
||||
parsed_items = read_data.get("items")
|
||||
assert isinstance(parsed_items, list)
|
||||
items = parsed_items
|
||||
finally:
|
||||
if created_id:
|
||||
deleted = await _invoke_tool(
|
||||
toolkit,
|
||||
tool_name="calendar.write",
|
||||
tool_input={"operation": "delete", "event_id": created_id},
|
||||
)
|
||||
deleted_data = deleted.get("data")
|
||||
assert isinstance(deleted_data, dict)
|
||||
assert deleted_data.get("ok") is True
|
||||
|
||||
return {
|
||||
"task_id": "smoke-task-1",
|
||||
|
||||
@@ -25,6 +25,8 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
|
||||
|
||||
async def create_agent_generated(self, payload):
|
||||
assert payload.title == "晨会"
|
||||
assert payload.metadata is not None
|
||||
assert payload.metadata.reminder_minutes == 15
|
||||
return SimpleNamespace(
|
||||
id=created_id,
|
||||
title="晨会",
|
||||
@@ -32,7 +34,11 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
|
||||
start_at=datetime(2026, 3, 8, 1, 0, tzinfo=timezone.utc),
|
||||
end_at=datetime(2026, 3, 8, 2, 0, tzinfo=timezone.utc),
|
||||
timezone="Asia/Shanghai",
|
||||
metadata=SimpleNamespace(location="会议室A", color="#4F46E5"),
|
||||
metadata=SimpleNamespace(
|
||||
location="会议室A",
|
||||
color="#4F46E5",
|
||||
reminder_minutes=15,
|
||||
),
|
||||
)
|
||||
|
||||
class _FakeRepository:
|
||||
@@ -61,6 +67,7 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
|
||||
"endAt": "2026-03-08T10:00:00+08:00",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"location": "会议室A",
|
||||
"reminderMinutes": 15,
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -69,6 +76,77 @@ async def test_mutate_calendar_event_create_returns_calendar_card_v1(
|
||||
data = cast(dict[str, object], result["data"])
|
||||
assert data["id"] == str(created_id)
|
||||
assert data["ok"] is True
|
||||
assert data["reminderMinutes"] == 15
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mutate_calendar_event_update_maps_reminder_minutes(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
event_id = uuid4()
|
||||
|
||||
class _FakeService:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
del kwargs
|
||||
|
||||
async def get_by_id(self, item_id):
|
||||
assert item_id == event_id
|
||||
return SimpleNamespace(
|
||||
metadata=SimpleNamespace(
|
||||
model_dump=lambda: {
|
||||
"color": "#4F46E5",
|
||||
"location": "会议室A",
|
||||
"version": 1,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def update(self, item_id, payload):
|
||||
assert item_id == event_id
|
||||
assert payload.metadata is not None
|
||||
assert payload.metadata.reminder_minutes == 30
|
||||
return SimpleNamespace(
|
||||
id=event_id,
|
||||
title="更新后",
|
||||
description=None,
|
||||
start_at=datetime(2026, 3, 8, 1, 0, tzinfo=timezone.utc),
|
||||
end_at=None,
|
||||
timezone="Asia/Shanghai",
|
||||
metadata=SimpleNamespace(
|
||||
location="会议室A",
|
||||
color="#4F46E5",
|
||||
reminder_minutes=30,
|
||||
),
|
||||
)
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self, session) -> None:
|
||||
del session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.tools.create_calendar_event_tool.ScheduleItemService",
|
||||
_FakeService,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.agent.infrastructure.crewai.tools.create_calendar_event_tool.SQLAlchemyScheduleItemRepository",
|
||||
_FakeRepository,
|
||||
)
|
||||
|
||||
result = cast(
|
||||
dict[str, object],
|
||||
await _execute_mutate_calendar_event(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
tool_args={
|
||||
"operation": "update",
|
||||
"eventId": str(event_id),
|
||||
"reminderMinutes": 30,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
data = cast(dict[str, object], result["data"])
|
||||
assert data["reminderMinutes"] == 30
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agentscope.events.agui_codec import to_agui_wire_event
|
||||
|
||||
|
||||
def test_maps_internal_text_delta_to_agui_wire_event() -> None:
|
||||
internal = {
|
||||
"id": "e1",
|
||||
"type": "text.delta",
|
||||
"threadId": "t1",
|
||||
"runId": "r1",
|
||||
"data": {"delta": "hel"},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "TEXT_MESSAGE_CONTENT"
|
||||
assert result["threadId"] == "t1"
|
||||
assert result["runId"] == "r1"
|
||||
assert result["delta"] == "hel"
|
||||
|
||||
|
||||
def test_reserved_keys_in_data_cannot_override_wire_fields() -> None:
|
||||
internal = {
|
||||
"id": "e2",
|
||||
"type": "run.started",
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-1",
|
||||
"data": {
|
||||
"type": "RUN_ERROR",
|
||||
"threadId": "thread-override",
|
||||
"runId": "run-override",
|
||||
"message": "ok",
|
||||
},
|
||||
}
|
||||
|
||||
result = to_agui_wire_event(internal)
|
||||
|
||||
assert result["type"] == "RUN_STARTED"
|
||||
assert result["threadId"] == "thread-1"
|
||||
assert result["runId"] == "run-1"
|
||||
assert result["message"] == "ok"
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agentscope.events.pipeline import AgentScopeEventPipeline
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_orders_codec_persist_publish() -> None:
|
||||
calls: list[str] = []
|
||||
|
||||
class _Codec:
|
||||
def to_wire(self, event: dict[str, object]) -> dict[str, object]:
|
||||
calls.append("codec")
|
||||
return {"type": "RUN_STARTED", **event}
|
||||
|
||||
class _Store:
|
||||
async def persist(self, event: dict[str, object]) -> None:
|
||||
calls.append("persist")
|
||||
assert event["type"] == "RUN_STARTED"
|
||||
|
||||
class _Bus:
|
||||
async def publish(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
calls.append("publish")
|
||||
assert session_id == "thread-1"
|
||||
return "1-0"
|
||||
|
||||
pipeline = AgentScopeEventPipeline(codec=_Codec(), store=_Store(), bus=_Bus())
|
||||
cursor = await pipeline.emit(session_id="thread-1", event={"id": "evt-1"})
|
||||
|
||||
assert cursor == "1-0"
|
||||
assert calls == ["codec", "persist", "publish"]
|
||||
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.agentscope.events.redis_bus import RedisStreamBus
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self._rows: list[tuple[str, str]] = []
|
||||
|
||||
def xadd(self, _stream: str, fields: dict[str, str]) -> str:
|
||||
cursor = f"{len(self._rows) + 1}-0"
|
||||
self._rows.append((cursor, fields["event"]))
|
||||
return cursor
|
||||
|
||||
def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, str]]]]]:
|
||||
del count, block
|
||||
stream_name, last = next(iter(streams.items()))
|
||||
rows: list[tuple[str, dict[str, str]]] = []
|
||||
for cursor, payload in self._rows:
|
||||
if cursor > last:
|
||||
rows.append((cursor, {"event": payload}))
|
||||
return [(stream_name, rows)]
|
||||
|
||||
|
||||
class _FakeRedisBytes:
|
||||
def __init__(self) -> None:
|
||||
self._rows: list[tuple[str, str]] = []
|
||||
|
||||
def xadd(self, _stream: str, fields: dict[str, str]) -> str:
|
||||
cursor = f"{len(self._rows) + 1}-0"
|
||||
self._rows.append((cursor, fields["event"]))
|
||||
return cursor
|
||||
|
||||
def xread(
|
||||
self,
|
||||
streams: dict[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[str, dict[str, bytes]]]]]:
|
||||
del count, block
|
||||
stream_name, last = next(iter(streams.items()))
|
||||
rows: list[tuple[str, dict[str, bytes]]] = []
|
||||
for cursor, payload in self._rows:
|
||||
if cursor > last:
|
||||
rows.append((cursor, {"event": payload.encode("utf-8")}))
|
||||
return [(stream_name, rows)]
|
||||
|
||||
|
||||
async def test_publish_then_read_after_cursor() -> None:
|
||||
bus = RedisStreamBus(client=_FakeRedis(), stream_prefix="agent.events")
|
||||
|
||||
first_cursor = await bus.publish(
|
||||
session_id="thread-1", event={"type": "RUN_STARTED"}
|
||||
)
|
||||
await bus.publish(session_id="thread-1", event={"type": "RUN_FINISHED"})
|
||||
|
||||
rows = await bus.read(session_id="thread-1", last_event_id=first_cursor)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["event"]["type"] == "RUN_FINISHED"
|
||||
|
||||
|
||||
async def test_read_supports_bytes_payload() -> None:
|
||||
bus = RedisStreamBus(client=_FakeRedisBytes(), stream_prefix="agent.events")
|
||||
await bus.publish(session_id="thread-1", event={"type": "RUN_STARTED"})
|
||||
rows = await bus.read(session_id="thread-1", last_event_id=None)
|
||||
assert rows[0]["event"]["type"] == "RUN_STARTED"
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from core.agentscope.events.sse import to_sse_event
|
||||
|
||||
|
||||
def test_sse_frame_contains_event_and_json_payload() -> None:
|
||||
payload = {"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"}
|
||||
|
||||
frame = to_sse_event("1-0", payload)
|
||||
|
||||
assert frame.startswith("id: 1-0\n")
|
||||
assert "event: RUN_STARTED\n" in frame
|
||||
assert frame.endswith("\n\n")
|
||||
|
||||
data_line = [line for line in frame.splitlines() if line.startswith("data: ")][0]
|
||||
parsed = json.loads(data_line[len("data: ") :])
|
||||
assert parsed["threadId"] == "t1"
|
||||
|
||||
|
||||
def test_sse_sanitizes_stream_id_newlines() -> None:
|
||||
payload = {"type": "RUN_STARTED"}
|
||||
frame = to_sse_event("1-0\nmalicious: yes", payload)
|
||||
assert frame.startswith("id: 1-0malicious: yes\n")
|
||||
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.agent.domain.user_context import UserAgentContext, parse_profile_settings
|
||||
from core.agentscope.runtime.agent_route_runtime import AgentRouteRuntime
|
||||
from core.agentscope.schemas import ReportOutput, RuntimeOutput
|
||||
from core.agentscope.schemas.agent_runtime import RunCommand
|
||||
from core.agentscope.schemas.execution import ExecutionBatchOutput
|
||||
from core.agentscope.schemas.intent import IntentOutput
|
||||
|
||||
|
||||
def _user_context() -> UserAgentContext:
|
||||
return UserAgentContext(
|
||||
user_id=uuid4(),
|
||||
username="tester",
|
||||
bio=None,
|
||||
settings=parse_profile_settings(
|
||||
{
|
||||
"version": 1,
|
||||
"preferences": {
|
||||
"interface_language": "zh-CN",
|
||||
"ai_language": "zh-CN",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"country": "CN",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_emits_started_text_and_finished_events() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FakeOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
return RuntimeOutput(
|
||||
intent=IntentOutput(
|
||||
route="DIRECT_RESPONSE",
|
||||
intent_summary="summary",
|
||||
direct_response="done",
|
||||
tasks=[],
|
||||
complexity="simple",
|
||||
),
|
||||
execution=ExecutionBatchOutput(
|
||||
task_results=[],
|
||||
overall_status="SUCCESS",
|
||||
aggregate_summary="ok",
|
||||
),
|
||||
report=ReportOutput(
|
||||
assistant_text="hello world",
|
||||
response_metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FakeOrchestrator(), pipeline=_FakePipeline()
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert [item["type"] for item in calls] == [
|
||||
"run.started",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"step.finish",
|
||||
"step.start",
|
||||
"text.start",
|
||||
"text.delta",
|
||||
"text.end",
|
||||
"step.finish",
|
||||
"run.finished",
|
||||
]
|
||||
assert calls[1]["data"]["stepName"] == "intent"
|
||||
assert calls[2]["data"]["stepName"] == "intent"
|
||||
assert calls[3]["data"]["stepName"] == "execution"
|
||||
assert calls[4]["data"]["stepName"] == "execution"
|
||||
assert calls[5]["data"]["stepName"] == "report"
|
||||
assert calls[7]["data"]["delta"] == "hello world"
|
||||
assert calls[6]["data"]["messageId"] == calls[7]["data"]["messageId"]
|
||||
assert calls[7]["data"]["messageId"] == calls[8]["data"]["messageId"]
|
||||
assert calls[9]["data"]["stepName"] == "report"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_emits_run_error_when_orchestrator_fails() -> None:
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
class _FakePipeline:
|
||||
async def emit(self, *, session_id: str, event: dict[str, object]) -> str:
|
||||
assert session_id == "thread-1"
|
||||
calls.append(event)
|
||||
return f"{len(calls)}-0"
|
||||
|
||||
class _FailOrchestrator:
|
||||
async def run(self, **_: object) -> RuntimeOutput:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
runtime = AgentRouteRuntime(
|
||||
orchestrator=_FailOrchestrator(),
|
||||
pipeline=_FakePipeline(),
|
||||
)
|
||||
command = RunCommand(threadId="thread-1", runId="run-1", messages=[])
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await runtime.run(
|
||||
command=command,
|
||||
owner_id=uuid4(),
|
||||
user_token="token",
|
||||
user_context=_user_context(),
|
||||
session=cast(AsyncSession, object()),
|
||||
)
|
||||
|
||||
assert [item["type"] for item in calls] == [
|
||||
"run.started",
|
||||
"step.start",
|
||||
"run.error",
|
||||
]
|
||||
assert calls[1]["data"]["stepName"] == "intent"
|
||||
assert calls[2]["data"]["message"] == "runtime execution failed"
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import core.agentscope.runtime.tasks as tasks_module
|
||||
|
||||
|
||||
def _run_input_payload() -> dict[str, Any]:
|
||||
return {
|
||||
"threadId": str(uuid4()),
|
||||
"runId": "run-1",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": {},
|
||||
"forwardedProps": {},
|
||||
}
|
||||
|
||||
|
||||
class _FakeSessionCtx:
|
||||
async def __aenter__(self) -> object:
|
||||
return object()
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
del exc_type, exc, tb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_calls_runtime_run(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
called: dict[str, int] = {"run": 0, "resume": 0}
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
del kwargs
|
||||
|
||||
async def run(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
called["run"] += 1
|
||||
return object()
|
||||
|
||||
async def resume(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
called["resume"] += 1
|
||||
return object()
|
||||
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"get_or_init_redis_client",
|
||||
_fake_get_redis_client,
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
|
||||
result = await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": _run_input_payload(),
|
||||
}
|
||||
)
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert called["run"] == 1
|
||||
assert called["resume"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_calls_runtime_resume(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
called: dict[str, int] = {"run": 0, "resume": 0}
|
||||
|
||||
class _FakeRuntime:
|
||||
def __init__(self, **kwargs: object) -> None:
|
||||
del kwargs
|
||||
|
||||
async def run(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
called["run"] += 1
|
||||
return object()
|
||||
|
||||
async def resume(self, **kwargs: object) -> object:
|
||||
del kwargs
|
||||
called["resume"] += 1
|
||||
return object()
|
||||
|
||||
async def _fake_get_redis_client() -> object:
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(tasks_module, "AgentRouteRuntime", _FakeRuntime)
|
||||
monkeypatch.setattr(
|
||||
tasks_module,
|
||||
"get_or_init_redis_client",
|
||||
_fake_get_redis_client,
|
||||
)
|
||||
monkeypatch.setattr(tasks_module, "AsyncSessionLocal", lambda: _FakeSessionCtx())
|
||||
|
||||
result = await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "resume",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": _run_input_payload(),
|
||||
}
|
||||
)
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert called["run"] == 0
|
||||
assert called["resume"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_requires_owner_id() -> None:
|
||||
with pytest.raises(ValueError, match="owner_id is required"):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "run",
|
||||
"run_input": _run_input_payload(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agentscope_task_rejects_invalid_command_type() -> None:
|
||||
with pytest.raises(ValueError, match="invalid command type"):
|
||||
await tasks_module.run_agentscope_task(
|
||||
{
|
||||
"command": "unknown",
|
||||
"owner_id": str(uuid4()),
|
||||
"run_input": _run_input_payload(),
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agentscope import schemas as exported_schemas
|
||||
from core.agentscope.schemas.agent_runtime import (
|
||||
AcceptedTaskResponse,
|
||||
AgUiWireEvent,
|
||||
HistorySnapshot,
|
||||
HistorySnapshotResponse,
|
||||
InternalRuntimeEvent,
|
||||
ResumeCommand,
|
||||
RunCommand,
|
||||
)
|
||||
|
||||
|
||||
def test_run_command_alias_roundtrip() -> None:
|
||||
payload = {
|
||||
"threadId": "thread-001",
|
||||
"runId": "run-001",
|
||||
"state": {"cursor": 1},
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"tools": [{"name": "calendar.lookup"}],
|
||||
"context": {"locale": "zh-CN"},
|
||||
"forwardedProps": {"traceId": "trace-1"},
|
||||
}
|
||||
|
||||
command = RunCommand.model_validate(payload)
|
||||
|
||||
assert command.thread_id == "thread-001"
|
||||
assert command.run_id == "run-001"
|
||||
assert command.forwarded_props == {"traceId": "trace-1"}
|
||||
|
||||
dumped = command.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["threadId"] == "thread-001"
|
||||
assert dumped["runId"] == "run-001"
|
||||
assert dumped["forwardedProps"] == {"traceId": "trace-1"}
|
||||
|
||||
|
||||
def test_history_snapshot_response_shape() -> None:
|
||||
response = HistorySnapshotResponse(
|
||||
threadId="thread-123",
|
||||
snapshot=HistorySnapshot(
|
||||
threadId="thread-123",
|
||||
day="2026-03-11",
|
||||
hasMore=False,
|
||||
messages=[{"id": "msg-1"}],
|
||||
),
|
||||
)
|
||||
|
||||
dumped = response.model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
assert dumped["type"] == "STATE_SNAPSHOT"
|
||||
assert dumped["threadId"] == "thread-123"
|
||||
assert dumped["snapshot"]["scope"] == "history_day"
|
||||
assert dumped["snapshot"]["hasMore"] is False
|
||||
assert dumped["snapshot"]["messages"] == [{"id": "msg-1"}]
|
||||
|
||||
|
||||
def test_runtime_event_validation_basics() -> None:
|
||||
internal = InternalRuntimeEvent(type="RUN_STARTED", data={"step": 1})
|
||||
assert internal.type == "RUN_STARTED"
|
||||
assert internal.model_dump(mode="json", by_alias=True)["data"] == {"step": 1}
|
||||
|
||||
wire = AgUiWireEvent(type="TEXT_MESSAGE_CONTENT", payload={"delta": "hello"})
|
||||
dumped = wire.model_dump(mode="json", by_alias=True, exclude_none=True)
|
||||
assert dumped["type"] == "TEXT_MESSAGE_CONTENT"
|
||||
assert dumped["payload"] == {"delta": "hello"}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
InternalRuntimeEvent.model_validate({"threadId": "t-1", "data": {}})
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AgUiWireEvent.model_validate({"payload": {"delta": "hello"}})
|
||||
|
||||
|
||||
def test_task_response_and_resume_aliases() -> None:
|
||||
accepted = AcceptedTaskResponse(
|
||||
taskId="task-1",
|
||||
threadId="thread-1",
|
||||
runId="run-1",
|
||||
created=False,
|
||||
)
|
||||
dumped = accepted.model_dump(mode="json", by_alias=True)
|
||||
assert dumped["taskId"] == "task-1"
|
||||
assert dumped["threadId"] == "thread-1"
|
||||
assert dumped["runId"] == "run-1"
|
||||
|
||||
resumed = ResumeCommand.model_validate(
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"runId": "run-2",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"context": {},
|
||||
}
|
||||
)
|
||||
assert resumed.thread_id == "thread-1"
|
||||
assert resumed.run_id == "run-2"
|
||||
|
||||
|
||||
def test_schemas_exports_include_task_and_history_models() -> None:
|
||||
assert exported_schemas.AcceptedTaskResponse is AcceptedTaskResponse
|
||||
assert exported_schemas.TaskAccepted is AcceptedTaskResponse
|
||||
assert exported_schemas.TaskAcceptedResponse is AcceptedTaskResponse
|
||||
assert exported_schemas.HistorySnapshotResponse is HistorySnapshotResponse
|
||||
@@ -131,3 +131,50 @@ async def test_calendar_write_rejects_event_id_for_create(
|
||||
|
||||
assert result["data"]["ok"] is False
|
||||
assert result["data"]["code"] == "INVALID_ARGUMENT"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calendar_write_maps_reminder_minutes(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def _fake_execute(**kwargs: Any) -> dict[str, object]:
|
||||
captured.update(cast(dict[str, object], kwargs["tool_args"]))
|
||||
return {"type": "calendar_card.v1", "version": "v1", "data": {"ok": True}}
|
||||
|
||||
monkeypatch.setattr(
|
||||
calendar_module,
|
||||
"_execute_mutate_calendar_event",
|
||||
_fake_execute,
|
||||
)
|
||||
monkeypatch.setattr(calendar_module, "_verify_user_token", lambda **_: True)
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
await calendar_module.calendar_write(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-abc",
|
||||
operation="create",
|
||||
reminder_minutes=15,
|
||||
)
|
||||
|
||||
assert captured["reminderMinutes"] == 15
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calendar_write_rejects_invalid_reminder_minutes(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(calendar_module, "build_tool_response", lambda payload: payload)
|
||||
|
||||
result = await calendar_module.calendar_write(
|
||||
session=cast(AsyncSession, SimpleNamespace()),
|
||||
owner_id=uuid4(),
|
||||
user_token="token-abc",
|
||||
operation="create",
|
||||
reminder_minutes=10081,
|
||||
)
|
||||
|
||||
assert result["data"]["ok"] is False
|
||||
assert result["data"]["code"] == "INVALID_ARGUMENT"
|
||||
|
||||
@@ -103,6 +103,18 @@ def test_metadata_rejects_unknown_field() -> None:
|
||||
ScheduleItemMetadata.model_validate({"color": "#FF6B6B", "unknown": True})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", [None, 0, 15, 10080])
|
||||
def test_metadata_accepts_reminder_minutes(value: int | None) -> None:
|
||||
metadata = ScheduleItemMetadata(reminder_minutes=value)
|
||||
assert metadata.reminder_minutes == value
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", [-1, 10081])
|
||||
def test_metadata_rejects_out_of_range_reminder_minutes(value: int) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ScheduleItemMetadata(reminder_minutes=value)
|
||||
|
||||
|
||||
def test_metadata_attachment_rejects_unknown_field() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ScheduleItemMetadataAttachment.model_validate(
|
||||
|
||||
@@ -221,7 +221,12 @@ async def test_create_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
|
||||
request = ScheduleItemCreateRequest(
|
||||
title="Roadmap",
|
||||
start_at=datetime(2026, 2, 28, 16, 0, 0, tzinfo=timezone.utc),
|
||||
metadata=ScheduleItemMetadata(location="会议室A", color="#4F46E5", version=1),
|
||||
metadata=ScheduleItemMetadata(
|
||||
location="会议室A",
|
||||
color="#4F46E5",
|
||||
reminder_minutes=15,
|
||||
version=1,
|
||||
),
|
||||
)
|
||||
service = ScheduleItemService(
|
||||
repository=CaptureRepo(None),
|
||||
@@ -234,6 +239,7 @@ async def test_create_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
|
||||
assert captured is not None
|
||||
assert "extra_metadata" in captured
|
||||
assert captured["extra_metadata"]["location"] == "会议室A"
|
||||
assert captured["extra_metadata"]["reminder_minutes"] == 15
|
||||
assert "metadata" not in captured
|
||||
|
||||
|
||||
@@ -261,7 +267,10 @@ async def test_update_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
|
||||
item.id,
|
||||
ScheduleItemUpdateRequest(
|
||||
metadata=ScheduleItemMetadata(
|
||||
location="线上会议", color="#3B82F6", version=1
|
||||
location="线上会议",
|
||||
color="#3B82F6",
|
||||
reminder_minutes=30,
|
||||
version=1,
|
||||
)
|
||||
),
|
||||
)
|
||||
@@ -269,4 +278,38 @@ async def test_update_maps_metadata_to_extra_metadata(mock_session: AsyncMock) -
|
||||
assert captured is not None
|
||||
assert "extra_metadata" in captured
|
||||
assert captured["extra_metadata"]["location"] == "线上会议"
|
||||
assert captured["extra_metadata"]["reminder_minutes"] == 30
|
||||
assert "metadata" not in captured
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_maps_null_metadata_to_extra_metadata_null(
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
item = _create_mock_schedule_item()
|
||||
captured: dict | None = None
|
||||
|
||||
class CaptureRepo(FakeRepo):
|
||||
async def update_by_item_id(
|
||||
self, item_id: UUID, owner_id: UUID, data: dict
|
||||
) -> ScheduleItem | None:
|
||||
nonlocal captured
|
||||
captured = data
|
||||
return await super().update_by_item_id(item_id, owner_id, data)
|
||||
|
||||
service = ScheduleItemService(
|
||||
repository=CaptureRepo(item),
|
||||
session=mock_session,
|
||||
current_user=CurrentUser(id=user_id),
|
||||
)
|
||||
|
||||
await service.update(
|
||||
item.id,
|
||||
ScheduleItemUpdateRequest(metadata=None),
|
||||
)
|
||||
|
||||
assert captured is not None
|
||||
assert "extra_metadata" in captured
|
||||
assert captured["extra_metadata"] is None
|
||||
assert "metadata" not in captured
|
||||
|
||||
Reference in New Issue
Block a user