feat: 重构 Reminder Notification 系统并更新应用包名

This commit is contained in:
qzl
2026-03-30 18:36:57 +08:00
parent 9fb2a6857b
commit 91bf3c3f96
90 changed files with 5133 additions and 3017 deletions
@@ -346,7 +346,7 @@ class AgentScopeRunner:
*messages_for_router,
],
output_model=RouterAgentOutput,
retries=0,
retries=3,
)
response_msg = Msg(
name="router",
@@ -1,3 +1,4 @@
import json
from typing import Annotated, Any, Literal, cast
from uuid import UUID
@@ -21,6 +22,7 @@ from core.agentscope.tools.tool_call_context import get_current_tool_call_id
from schemas.agent.runtime_models import ErrorInfo, ToolAgentOutput, ToolStatus
from v1.schedule_items.schemas import (
ScheduleItemCreateRequest,
ScheduleItemListRequest,
ScheduleItemShareRequest,
ScheduleItemStatus,
ScheduleItemUpdateRequest,
@@ -98,6 +100,13 @@ class CalendarWriteBatchArgs(BaseModel):
operations: list[CalendarWriteOperation] = Field(min_length=1, max_length=20)
class CalendarShareArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
event_id: str
invitees: list[CalendarShareInvitee] = Field(min_length=1)
def _validate_runtime_context(
*,
tool_name: str,
@@ -116,69 +125,60 @@ def _validate_runtime_context(
return None
def _format_event_brief(event_items: list[dict[str, Any]], limit: int = 3) -> str:
briefs: list[str] = []
for item in event_items[:limit]:
event_id = str(item.get("id") or "")
title = str(item.get("title") or "")
start_at = str(item.get("startAt") or "")
end_at = str(item.get("endAt") or "")
timezone = str(item.get("timezone") or "")
status = str(item.get("status") or "")
description = str(item.get("description") or "")
location = str(item.get("location") or "")
reminder_minutes = item.get("reminderMinutes")
color = str(item.get("color") or "")
source_type = str(item.get("sourceType") or "")
updated_at = str(item.get("updatedAt") or "")
permission = item.get("permission")
is_owner = item.get("isOwner")
if not event_id:
continue
briefs.append(
"{"
f"id={event_id},title={title},startAt={start_at},endAt={end_at},"
f"timezone={timezone},status={status},description={description},"
f"location={location},reminderMinutes={reminder_minutes},color={color},"
f"sourceType={source_type},updatedAt={updated_at},permission={permission},"
f"isOwner={is_owner}"
"}"
)
return ",".join(briefs)
async def calendar_read(
query: Annotated[
str | None,
Field(description="Optional keyword to filter calendar events."),
] = None,
page: Annotated[
int,
Field(description="Page number, starting from 1.", ge=1),
] = 1,
page_size: Annotated[
int,
Field(description="Number of items per page (1-100).", ge=1, le=100),
] = 20,
start_at: Annotated[
str,
Field(
description="Start of date range in ISO8601 with timezone, e.g. 2026-03-30T00:00:00+08:00."
),
],
end_at: Annotated[
str,
Field(
description="End of date range in ISO8601 with timezone, e.g. 2026-03-30T23:59:59+08:00."
),
],
session: Any = None,
owner_id: Any = None,
) -> ToolResponse:
"""Read calendar events with optional keyword filtering and pagination.
"""Read calendar events within a date range.
Status semantics for returned events:
- active: Event is actionable.
- archived: Event is historical/expired and should not trigger reminders.
Returns subscribed calendar events (owned or shared) with permission info.
Status: active=actionable, archived=past/expired.
Permission flags: is_owner, can_view, can_edit, can_invite, can_delete.
Args:
query: Optional keyword used to filter events by text fields.
page: Page number starting from 1.
page_size: Number of items per page, between 1 and 100.
start_at: Start of date range (required).
end_at: End of date range (required).
Returns:
ToolResponse with serialized ToolAgentOutput payload.
ToolResponse with JSON result:
{
"total": int,
"items": [{
"id": "uuid",
"owner_id": "uuid",
"title": "string",
"description": "string|null",
"start_at": "ISO8601 datetime",
"end_at": "ISO8601 datetime|null",
"timezone": "IANA timezone",
"status": "active|archived",
"source_type": "manual|imported|agent_generated",
"permission": {"can_view", "can_edit", "can_invite", "can_delete", "is_owner"},
"is_owner": boolean,
"metadata": {color, location, reminder_minutes}|null,
"subscribers": [{user_id, username, phone, permission, status}],
"created_at": "ISO8601 datetime",
"updated_at": "ISO8601 datetime"
}]
}
"""
tool_name = "calendar_read"
tool_call_args = {"query": query, "page": page, "page_size": page_size}
tool_call_args: dict[str, Any] = {"start_at": start_at, "end_at": end_at}
runtime_error = _validate_runtime_context(
tool_name=tool_name,
tool_call_args=tool_call_args,
@@ -189,30 +189,30 @@ async def calendar_read(
return runtime_error
try:
parsed_start = parse_iso_datetime(start_at)
parsed_end = parse_iso_datetime(end_at)
if parsed_start is None or parsed_end is None:
raise ValueError("start_at 和 end_at 都是必填项")
if parsed_start >= parsed_end:
raise ValueError("start_at 必须早于 end_at")
service = create_schedule_service(
cast(AsyncSession, session), cast(UUID, owner_id)
)
items, total = await service.list_paginated(
page=page,
page_size=page_size,
query=query,
)
total_pages = (total + page_size - 1) // page_size if total else 0
request = ScheduleItemListRequest(start_at=parsed_start, end_at=parsed_end)
items = await service.list_by_date_range(request)
event_items = [schedule_event_to_dict(item) for item in items]
event_brief = _format_event_brief(event_items)
summary = (
f"status=success total={total} total_pages={total_pages or 1} "
f"returned={len(event_items)} has_next={str(page < (total_pages or 1)).lower()}"
result = json.dumps(
{"total": len(event_items), "items": event_items},
ensure_ascii=False,
)
if event_brief:
summary = f"{summary} items=[{event_brief}]"
return dump_tool_output(
ToolAgentOutput(
tool_name=tool_name,
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
tool_call_args=tool_call_args,
status=ToolStatus.SUCCESS,
result=summary,
result=result,
)
)
except Exception as exc:
@@ -532,10 +532,25 @@ async def calendar_share(
ToolResponse with serialized ToolAgentOutput payload.
"""
tool_name = "calendar_share"
try:
parsed_args = CalendarShareArgs.model_validate(
{"event_id": event_id, "invitees": invitees}
)
except Exception as exc: # noqa: BLE001
code, message, retryable = map_calendar_exception(exc)
return calendar_error_output(
tool_name=tool_name,
tool_call_args={"event_id": event_id, "invitees": invitees},
code=code,
message=message,
retryable=retryable,
)
tool_call_args = {
"event_id": event_id,
"event_id": parsed_args.event_id,
"invitees": [
invitee.model_dump(mode="json", by_alias=True) for invitee in invitees
invitee.model_dump(mode="json", by_alias=True)
for invitee in parsed_args.invitees
],
}
runtime_error = _validate_runtime_context(
@@ -551,11 +566,11 @@ async def calendar_share(
service = create_schedule_service(
cast(AsyncSession, session), cast(UUID, owner_id)
)
target_uuid = UUID(event_id)
target_uuid = UUID(parsed_args.event_id)
invited: list[str] = []
result_items: list[dict[str, str]] = []
for invitee in invitees:
for invitee in parsed_args.invitees:
raw_phone = invitee.phone.strip()
normalized_phone = raw_phone
for separator in (" ", "-", "(", ")"):
@@ -4,6 +4,7 @@ import re
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from zoneinfo import ZoneInfo
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
@@ -12,8 +13,9 @@ from core.auth.models import CurrentUser
from core.http.errors import ApiProblemError
from v1.inbox_messages.repository import SQLAlchemyInboxMessageRepository
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
from v1.schedule_items.schemas import ScheduleItemMetadata
from v1.schedule_items.schemas import ScheduleItemMetadata, parse_permission
from v1.schedule_items.service import ScheduleItemService
from v1.users.repository import SQLAlchemyUserRepository
_HEX_COLOR_PATTERN = re.compile(r"^#[0-9A-Fa-f]{6}$")
@@ -39,31 +41,66 @@ def create_schedule_service(
session=session,
current_user=CurrentUser(id=owner_id),
inbox_repository=SQLAlchemyInboxMessageRepository(session),
user_repository=SQLAlchemyUserRepository(session),
)
def _convert_to_event_timezone(dt: datetime, event_timezone: str) -> datetime:
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
tz = ZoneInfo(event_timezone) if event_timezone else ZoneInfo("UTC")
return dt.astimezone(tz)
def schedule_event_to_dict(event: object) -> dict[str, Any]:
event_id = str(getattr(event, "id"))
event_timezone = str(getattr(event, "timezone") or "UTC")
start_at_utc = getattr(event, "start_at")
end_at_utc = getattr(event, "end_at")
permission_int = getattr(event, "permission", 1)
is_owner = getattr(event, "is_owner", permission_int == 15)
metadata = getattr(event, "metadata", None)
subscribers = getattr(event, "subscribers", []) or []
def _serialize_dt(dt: datetime | None) -> str | None:
if dt is None:
return None
return _convert_to_event_timezone(dt, event_timezone).isoformat()
def _serialize_subscriber(sub: object) -> dict[str, Any]:
return {
"user_id": str(getattr(sub, "user_id", "")),
"username": getattr(sub, "username", None),
"avatar_url": getattr(sub, "avatar_url", None),
"phone": getattr(sub, "phone", None),
"permission": getattr(sub, "permission", 1),
"status": str(getattr(sub, "status", "active")),
"subscribed_at": _serialize_dt(getattr(sub, "subscribed_at", None)),
}
status_value = getattr(event, "status", None)
if hasattr(status_value, "value"):
status_value = getattr(status_value, "value")
location_value = getattr(metadata, "location", None)
color_value = getattr(metadata, "color", None) or "#4F46E5"
reminder_minutes_value = getattr(metadata, "reminder_minutes", None)
if status_value is not None and hasattr(status_value, "value"):
status_value = status_value.value
source_type_value = getattr(event, "source_type", None)
if source_type_value is not None and hasattr(source_type_value, "value"):
source_type_value = source_type_value.value
return {
"id": event_id,
"title": getattr(event, "title"),
"description": getattr(event, "description"),
"startAt": getattr(event, "start_at").isoformat(),
"endAt": getattr(event, "end_at").isoformat()
if getattr(event, "end_at") is not None
else None,
"timezone": getattr(event, "timezone"),
"id": str(getattr(event, "id", "")),
"owner_id": str(getattr(event, "owner_id", "")),
"title": getattr(event, "title", ""),
"description": getattr(event, "description", None),
"start_at": _serialize_dt(start_at_utc),
"end_at": _serialize_dt(end_at_utc),
"timezone": event_timezone,
"metadata": metadata.model_dump(mode="json") if metadata else None,
"status": status_value,
"location": location_value,
"color": color_value,
"reminderMinutes": reminder_minutes_value,
"source_type": source_type_value,
"created_at": _serialize_dt(getattr(event, "created_at", None)),
"updated_at": _serialize_dt(getattr(event, "updated_at", None)),
"permission": parse_permission(permission_int),
"is_owner": is_owner,
"subscribers": [_serialize_subscriber(sub) for sub in subscribers],
}
@@ -1,33 +1,34 @@
input_template: |
你正在执行一次自动化记忆回顾与整理任务。
你正在执行一次"自动化记忆回顾与整理"任务。
任务目标:
1) 回顾最近两天的聊天与上下文,识别用户长期偏好、习惯和关键事实的变化。
2) 对已经失效、被否定或明显过期的信息执行遗忘。
3) 对新增且有证据支持的信息执行写入。
4) 严禁编造;没有证据就不要写入。
5) 只更新最小必要字段,避免过度覆盖。
任务目标:
1) 回顾最近两天的聊天与上下文,识别用户长期偏好、习惯和关键事实的变化。
2) 对已经失效、被否定或明显过期的信息执行遗忘。
3) 对新增且有证据支持的信息执行写入。
4) 严禁编造;没有证据就不要写入。
5) 只更新最小必要字段,避免过度覆盖。
输出要求:
- 必须使用以下固定格式输出;每一行都要有
【记忆回顾】<一句人性化总结,说明今天主要发生了什么>
【新增记忆】<按“X条:要点1;要点2”描述;没有则写“0条”>
【遗忘记忆】<按X条:要点1;要点2描述;没有则写0条>
【未来展望】<基于本次记忆变化,给出1-2条温和、可执行的后续建议;若暂无建议则说明“可继续观察”>
输出要求:
- 必须使用以下固定格式输出:
<----------【周期任务输出】---------->
【记忆回顾】<一句人性化总结,说明今天主要发生了什么>
【新增记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
【遗忘记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
【未来展望】<基于本次记忆变化,给出1-2条温和、可执行的后续建议;若暂无建议则说明"可继续观察">
表达风格:
- 语言自然、温和、可读,像助理在做每日回顾。
- 结论先行,避免空话,不要输出与任务无关的闲聊内容。
表达风格:
- 语言自然、温和、可读,像助理在做每日回顾。
- 结论先行,避免空话,不要输出与任务无关的闲聊内容。
enabled_tools:
- memory.write
- memory.forget
- memory.write
- memory.forget
context:
source: latest_chat
window_mode: day
window_count: 2
source: latest_chat
window_mode: day
window_count: 2
schedule:
type: daily
run_at:
hour: 8
minute: 0
weekdays: null
type: daily
run_at:
hour: 8
minute: 0
weekdays: null
+12 -4
View File
@@ -5,7 +5,7 @@ import subprocess
import sys
from pathlib import Path
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from core.automation.scheduler import run_automation_scheduler_scan
from core.config.initial.init_data import initialize_data
@@ -118,22 +118,30 @@ async def run_automation_scheduler_forever() -> None:
batch_limit=batch_limit,
)
def scan_job() -> None:
async def scan_job() -> None:
try:
asyncio.run(run_automation_scheduler_scan(limit=batch_limit))
await run_automation_scheduler_scan(limit=batch_limit)
except Exception as exc:
logger.exception("Automation scheduler scan failed", error=str(exc))
scheduler = BlockingScheduler()
scheduler = AsyncIOScheduler()
scheduler.add_job(
scan_job,
trigger=IntervalTrigger(seconds=interval_seconds),
id="automation_scheduler_scan",
name="Automation scheduler scan",
replace_existing=True,
max_instances=1,
coalesce=True,
)
scheduler.start()
stop_event = asyncio.Event()
try:
await stop_event.wait()
finally:
scheduler.shutdown(wait=False)
def main() -> int:
"""CLI entry point."""
@@ -100,6 +100,13 @@ class ConstraintItem(BaseModel):
value: str
required: bool = True
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: object) -> object:
if isinstance(value, bool | int | float):
return str(value)
return value
class NormalizedTaskInput(BaseModel):
model_config = ConfigDict(extra="forbid")
+14
View File
@@ -211,6 +211,20 @@ class UiHintListItem(UiHintBaseModel):
status: UiHintStatus | None = Field(default=None)
actions: list[UiHintAction] = Field(default_factory=list)
@field_validator("status", mode="before")
@classmethod
def normalize_status(cls, value: object) -> object:
if value is None:
return None
if isinstance(value, dict):
status_type = value.get("type")
if isinstance(status_type, str):
return status_type
status_value = value.get("status")
if isinstance(status_value, str):
return status_value
return value
class UiHintSection(UiHintBaseModel):
title: str | None = Field(default=None, description="Section title.")
@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from core.auth.models import CurrentUser
from core.db import get_db
from v1.auth.gateway import SupabaseAuthGateway
from v1.friendships.repository import SQLAlchemyFriendshipRepository
from v1.friendships.service import FriendshipService
from v1.users.dependencies import get_current_user
@@ -25,9 +26,11 @@ async def get_friendship_service(
) -> FriendshipService:
friendship_repository = SQLAlchemyFriendshipRepository(session)
user_repository = SQLAlchemyUserRepository(session)
auth_gateway = SupabaseAuthGateway()
return FriendshipService(
repository=friendship_repository,
user_repository=user_repository,
session=session,
current_user=current_user,
auth_gateway=auth_gateway,
)
+91 -17
View File
@@ -9,10 +9,17 @@ from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.http.errors import ApiProblemError, problem_payload
from v1.inbox_messages.realtime import (
InboxMessageEventSnapshot,
publish_inbox_message_created,
publish_inbox_message_status_changed,
snapshot_from_inbox_message,
)
from core.logging import get_logger
from models.friendships import Friendship
from models.inbox_messages import InboxMessage
from schemas.enums import FriendshipStatus, InboxMessageStatus, InboxMessageType
from v1.auth.gateway import SupabaseAuthGateway
from v1.friendships.repository import FriendshipRepository
from v1.friendships.schemas import (
FriendRequestCreate,
@@ -42,6 +49,7 @@ class FriendshipService(BaseService):
_repository: FriendshipRepository
_user_repository: UserRepository
_session: AsyncSession
_auth_gateway: SupabaseAuthGateway | None
def __init__(
self,
@@ -49,11 +57,13 @@ class FriendshipService(BaseService):
user_repository: UserRepository,
session: AsyncSession,
current_user: CurrentUser | None,
auth_gateway: SupabaseAuthGateway | None = None,
) -> None:
super().__init__(current_user=current_user)
self._repository = repository
self._user_repository = user_repository
self._session = session
self._auth_gateway = auth_gateway
async def send_request(self, request: FriendRequestCreate) -> FriendRequestResponse:
user_id = self.require_user_id()
@@ -103,6 +113,8 @@ class FriendshipService(BaseService):
friendship, inbox = await self._repository.reactivate_request(
existing, user_id, request.content
)
await self._session.flush()
inbox_event = snapshot_from_inbox_message(inbox)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
@@ -121,6 +133,7 @@ class FriendshipService(BaseService):
"target_id": str(target_user_id),
},
)
await self._publish_created_events([inbox_event])
return await self._build_friend_request_response(
friendship, inbox, user_id, target_user_id
)
@@ -129,6 +142,8 @@ class FriendshipService(BaseService):
friendship, inbox = await self._repository.create_request(
user_id, target_user_id, request.content
)
await self._session.flush()
inbox_event = snapshot_from_inbox_message(inbox)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
@@ -144,6 +159,7 @@ class FriendshipService(BaseService):
"friend_request_sent",
extra={"initiator_id": str(user_id), "target_id": str(target_user_id)},
)
await self._publish_created_events([inbox_event])
return await self._build_friend_request_response(
friendship, inbox, user_id, target_user_id
@@ -172,11 +188,7 @@ class FriendshipService(BaseService):
),
)
recipient_id = (
friendship.user_low_id
if friendship.initiator_id == friendship.user_high_id
else friendship.user_high_id
)
recipient_id = self._get_recipient_id(friendship)
if recipient_id != user_id:
logger.warning(
@@ -218,6 +230,7 @@ class FriendshipService(BaseService):
friendship.status = FriendshipStatus.ACCEPTED
inbox.status = InboxMessageStatus.ACCEPTED
inbox_event = snapshot_from_inbox_message(inbox)
try:
await self._session.commit()
@@ -249,6 +262,7 @@ class FriendshipService(BaseService):
"initiator_id": str(sender_id),
},
)
await self._publish_status_events([inbox_event])
sender = await self._user_repository.get_by_user_id(sender_id)
recipient = await self._user_repository.get_by_user_id(user_id)
@@ -285,11 +299,7 @@ class FriendshipService(BaseService):
),
)
recipient_id = (
friendship.user_low_id
if friendship.initiator_id == friendship.user_high_id
else friendship.user_high_id
)
recipient_id = self._get_recipient_id(friendship)
if recipient_id != user_id:
logger.warning(
@@ -322,8 +332,10 @@ class FriendshipService(BaseService):
)
friendship.status = FriendshipStatus.DECLINED
inbox_event: InboxMessageEventSnapshot | None = None
if inbox:
inbox.status = InboxMessageStatus.REJECTED
inbox_event = snapshot_from_inbox_message(inbox)
try:
await self._session.commit()
@@ -355,6 +367,8 @@ class FriendshipService(BaseService):
"initiator_id": str(sender_id),
},
)
if inbox_event:
await self._publish_status_events([inbox_event])
sender = await self._user_repository.get_by_user_id(sender_id)
recipient = await self._user_repository.get_by_user_id(user_id)
@@ -422,8 +436,10 @@ class FriendshipService(BaseService):
)
friendship.status = FriendshipStatus.CANCELED
inbox_event: InboxMessageEventSnapshot | None = None
if inbox:
inbox.status = InboxMessageStatus.DISMISSED
inbox_event = snapshot_from_inbox_message(inbox)
try:
await self._session.commit()
@@ -457,6 +473,8 @@ class FriendshipService(BaseService):
"target_id": str(recipient_id),
},
)
if inbox_event:
await self._publish_status_events([inbox_event])
return FriendRequestResponse(
id=friendship.id,
@@ -583,11 +601,7 @@ class FriendshipService(BaseService):
)
sender = await self._user_repository.get_by_user_id(initiator_id)
recipient_id = (
friendship.user_low_id
if friendship.user_low_id != initiator_id
else friendship.user_high_id
)
recipient_id = self._get_recipient_id(friendship)
recipient = await self._user_repository.get_by_user_id(recipient_id)
# Map FriendshipStatus to response status
@@ -676,15 +690,24 @@ class FriendshipService(BaseService):
]
profiles_by_id = await self._user_repository.get_by_user_ids(friend_ids)
auth_users_by_id: dict[str, str | None] = {}
if self._auth_gateway is not None:
auth_users = await self._auth_gateway.get_users_by_ids(
[str(fid) for fid in friend_ids]
)
for uid, auth_user in auth_users.items():
auth_users_by_id[uid] = auth_user.phone
result: list[FriendResponse] = []
for friendship in friendships:
friend_id = self._get_other_user_id(friendship, user_id)
friend = profiles_by_id.get(friend_id)
phone = auth_users_by_id.get(str(friend_id))
result.append(
FriendResponse(
id=friendship.id,
friend=self._build_user_basic_info(friend),
friend=self._build_user_basic_info(friend, phone),
status="active",
created_at=friendship.created_at,
accepted_at=friendship.updated_at,
@@ -760,7 +783,9 @@ class FriendshipService(BaseService):
accepted_at=friendship.updated_at,
)
def _build_user_basic_info(self, profile: Any) -> "UserContext":
def _build_user_basic_info(
self, profile: Any, phone: str | None = None
) -> "UserContext":
from schemas.shared.user import UserContext
if profile is None:
@@ -770,7 +795,9 @@ class FriendshipService(BaseService):
return UserContext(
id=str(p.id),
username=p.username,
phone=phone,
avatar_url=p.avatar_url if hasattr(p, "avatar_url") else None,
bio=p.bio if hasattr(p, "bio") else None,
)
async def _build_friend_request_response(
@@ -800,3 +827,50 @@ class FriendshipService(BaseService):
if friendship.user_low_id == current_user_id
else friendship.user_low_id
)
def _get_recipient_id(self, friendship: Friendship) -> UUID:
return (
friendship.user_low_id
if friendship.initiator_id == friendship.user_high_id
else friendship.user_high_id
)
async def _publish_created_events(
self, messages: list[InboxMessageEventSnapshot]
) -> None:
for message in messages:
try:
await publish_inbox_message_created(message)
except Exception as exc: # noqa: BLE001
logger.exception(
"Failed to publish inbox created event",
message_id=str(message.message_id),
recipient_id=str(message.recipient_id),
)
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="INBOX_EVENT_STREAM_UNAVAILABLE",
detail="Inbox event stream unavailable",
),
) from exc
async def _publish_status_events(
self, messages: list[InboxMessageEventSnapshot]
) -> None:
for message in messages:
try:
await publish_inbox_message_status_changed(message)
except Exception as exc: # noqa: BLE001
logger.exception(
"Failed to publish inbox status event",
message_id=str(message.message_id),
recipient_id=str(message.recipient_id),
)
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="INBOX_EVENT_STREAM_UNAVAILABLE",
detail="Inbox event stream unavailable",
),
) from exc
+280
View File
@@ -0,0 +1,280 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import UTC, datetime
import inspect
import json
from typing import Any
from uuid import UUID, uuid4
from redis.exceptions import TimeoutError as RedisTimeoutError
from core.config.settings import config
from core.logging import get_logger
from models.inbox_messages import InboxMessage
from services.base.redis import get_or_init_redis_client
logger = get_logger("v1.inbox_messages.realtime")
INBOX_STREAM_PREFIX = "inbox:events"
EVENT_MESSAGE_CREATED = "INBOX_MESSAGE_CREATED"
EVENT_MESSAGE_READ_CHANGED = "INBOX_MESSAGE_READ_CHANGED"
EVENT_MESSAGE_STATUS_CHANGED = "INBOX_MESSAGE_STATUS_CHANGED"
EVENT_SNAPSHOT_REQUIRED = "INBOX_SNAPSHOT_REQUIRED"
@dataclass(frozen=True)
class InboxMessageEventSnapshot:
message_id: UUID
recipient_id: UUID
sender_id: UUID | None
message_type: str
schedule_item_id: UUID | None
friendship_id: UUID | None
content: dict[str, Any] | None
is_read: bool
status: str
created_at: datetime
occurred_at: datetime
def snapshot_from_inbox_message(message: InboxMessage) -> InboxMessageEventSnapshot:
message_type = (
message.message_type.value
if hasattr(message.message_type, "value")
else str(message.message_type)
)
status = (
message.status.value
if hasattr(message.status, "value")
else str(message.status)
)
if status in {"None", ""}:
status = "pending"
created_at = (
message.created_at
if isinstance(message.created_at, datetime)
else datetime.now(UTC)
)
occurred_at = (
message.updated_at if isinstance(message.updated_at, datetime) else created_at
)
message_id = message.id if isinstance(message.id, UUID) else uuid4()
return InboxMessageEventSnapshot(
message_id=message_id,
recipient_id=message.recipient_id,
sender_id=message.sender_id,
message_type=message_type,
schedule_item_id=message.schedule_item_id,
friendship_id=message.friendship_id,
content=message.content,
is_read=bool(message.is_read),
status=status,
created_at=created_at,
occurred_at=occurred_at,
)
def to_inbox_sse_event(stream_id: str, event_type: str, payload: dict[str, Any]) -> str:
safe_stream_id = str(stream_id).replace("\r", "").replace("\n", "")
safe_event_type = str(event_type).replace("\r", "").replace("\n", "")
data = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
return f"id: {safe_stream_id}\nevent: {safe_event_type}\ndata: {data}\n\n"
def _stream_name(recipient_id: UUID) -> str:
return f"{INBOX_STREAM_PREFIX}:{recipient_id}"
def _to_epoch_ms(value: datetime) -> int:
normalized = value.astimezone(UTC)
return int(normalized.timestamp() * 1000)
def _resolve_occurred_at(snapshot: InboxMessageEventSnapshot) -> datetime:
if isinstance(snapshot.occurred_at, datetime):
return snapshot.occurred_at
if isinstance(snapshot.created_at, datetime):
return snapshot.created_at
return datetime.now(UTC)
def _safe_stream_block_ms(requested_ms: int) -> int:
try:
socket_timeout_ms = max(int(float(config.redis.socket_timeout) * 1000), 1)
except (TypeError, ValueError):
socket_timeout_ms = 5000
safe_max = max(socket_timeout_ms - 100, 1)
return max(1, min(int(requested_ms), safe_max))
def _message_to_payload(snapshot: InboxMessageEventSnapshot) -> dict[str, Any]:
return {
"id": str(snapshot.message_id),
"recipient_id": str(snapshot.recipient_id),
"sender_id": str(snapshot.sender_id) if snapshot.sender_id else None,
"message_type": snapshot.message_type,
"schedule_item_id": str(snapshot.schedule_item_id)
if snapshot.schedule_item_id
else None,
"friendship_id": str(snapshot.friendship_id)
if snapshot.friendship_id
else None,
"content": snapshot.content,
"is_read": bool(snapshot.is_read),
"status": snapshot.status,
"created_at": snapshot.created_at.isoformat(),
}
async def _publish_event(recipient_id: UUID, payload: dict[str, Any]) -> str:
redis = await get_or_init_redis_client()
stream_name = _stream_name(recipient_id)
event_json = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
result = redis.xadd(stream_name, {"event": event_json})
if inspect.isawaitable(result):
return str(await result)
return str(result)
async def publish_inbox_message_created(
message: InboxMessage | InboxMessageEventSnapshot,
) -> str:
snapshot = (
message
if isinstance(message, InboxMessageEventSnapshot)
else snapshot_from_inbox_message(message)
)
occurred_at = _resolve_occurred_at(snapshot)
version = _to_epoch_ms(occurred_at)
payload = {
"event_id": str(uuid4()),
"occurred_at": occurred_at.isoformat(),
"user_id": str(snapshot.recipient_id),
"message_id": str(snapshot.message_id),
"event_type": EVENT_MESSAGE_CREATED,
"op": "created",
"version": version,
"data": {"message": _message_to_payload(snapshot)},
}
return await _publish_event(snapshot.recipient_id, payload)
async def publish_inbox_message_read_changed(
message: InboxMessage | InboxMessageEventSnapshot,
) -> str:
snapshot = (
message
if isinstance(message, InboxMessageEventSnapshot)
else snapshot_from_inbox_message(message)
)
occurred_at = _resolve_occurred_at(snapshot)
payload = {
"event_id": str(uuid4()),
"occurred_at": occurred_at.isoformat(),
"user_id": str(snapshot.recipient_id),
"message_id": str(snapshot.message_id),
"event_type": EVENT_MESSAGE_READ_CHANGED,
"op": "read_changed",
"version": _to_epoch_ms(occurred_at),
"data": {"is_read": bool(snapshot.is_read)},
}
return await _publish_event(snapshot.recipient_id, payload)
async def publish_inbox_message_status_changed(
message: InboxMessage | InboxMessageEventSnapshot,
) -> str:
snapshot = (
message
if isinstance(message, InboxMessageEventSnapshot)
else snapshot_from_inbox_message(message)
)
occurred_at = _resolve_occurred_at(snapshot)
payload = {
"event_id": str(uuid4()),
"occurred_at": occurred_at.isoformat(),
"user_id": str(snapshot.recipient_id),
"message_id": str(snapshot.message_id),
"event_type": EVENT_MESSAGE_STATUS_CHANGED,
"op": "status_changed",
"version": _to_epoch_ms(occurred_at),
"data": {"status": snapshot.status},
}
return await _publish_event(snapshot.recipient_id, payload)
async def publish_inbox_snapshot_required(
*, recipient_id: UUID, message_id: UUID
) -> str:
now = datetime.now(UTC)
payload = {
"event_id": str(uuid4()),
"occurred_at": now.isoformat(),
"user_id": str(recipient_id),
"message_id": str(message_id),
"event_type": EVENT_SNAPSHOT_REQUIRED,
"op": "snapshot_required",
"version": _to_epoch_ms(now),
"data": {},
}
return await _publish_event(recipient_id, payload)
async def read_inbox_events(
*,
recipient_id: UUID,
last_event_id: str | None,
count: int = 100,
block_ms: int = 5000,
) -> list[dict[str, Any]]:
redis = await get_or_init_redis_client()
stream = _stream_name(recipient_id)
start_id = "0-0" if not last_event_id else last_event_id
safe_block_ms = _safe_stream_block_ms(block_ms)
try:
raw = redis.xread({stream: start_id}, count=count, block=safe_block_ms)
response = await raw if inspect.isawaitable(raw) else raw
except (TimeoutError, asyncio.TimeoutError, RedisTimeoutError):
return []
if not response:
return []
first = response[0]
if not isinstance(first, (list, tuple)) or len(first) != 2:
return []
entries_raw = first[1]
if not isinstance(entries_raw, list):
return []
rows: list[dict[str, Any]] = []
for entry in entries_raw:
if not isinstance(entry, (list, tuple)) or len(entry) != 2:
continue
entry_id_raw, fields = entry
if isinstance(entry_id_raw, bytes):
stream_id = entry_id_raw.decode("utf-8", errors="replace")
elif isinstance(entry_id_raw, str):
stream_id = entry_id_raw
else:
continue
if not isinstance(fields, dict):
continue
payload_raw = fields.get("event")
if isinstance(payload_raw, bytes):
payload_raw = payload_raw.decode("utf-8", errors="replace")
if not isinstance(payload_raw, str):
continue
try:
payload = json.loads(payload_raw)
except (TypeError, ValueError):
logger.warning(
"Discard malformed inbox stream payload", stream_id=stream_id
)
continue
if not isinstance(payload, dict):
continue
rows.append({"id": stream_id, "event": payload})
return rows
+101 -1
View File
@@ -1,15 +1,54 @@
from __future__ import annotations
import asyncio
import re
from collections.abc import AsyncIterator
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Header, Query, Request
from fastapi.responses import StreamingResponse
from core.http.errors import ApiProblemError, problem_payload
from services.base.redis import get_or_init_redis_client
from v1.inbox_messages.realtime import to_inbox_sse_event
from v1.inbox_messages.dependencies import get_inbox_message_service
from v1.inbox_messages.schemas import InboxMessageResponse
from v1.inbox_messages.service import InboxMessageService
router = APIRouter(prefix="/inbox/messages", tags=["inbox-messages"])
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
_MAX_SSE_CONNECTIONS_PER_USER = 3
_SSE_SLOT_TTL_SECONDS = 15 * 60
async def _acquire_sse_slot(*, user_id: str) -> bool:
redis = await get_or_init_redis_client()
key = f"inbox:sse-active:{user_id}"
count = await redis.incr(key)
if count == 1:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
elif count > _MAX_SSE_CONNECTIONS_PER_USER:
await redis.decr(key)
return False
else:
ttl = await redis.ttl(key)
if ttl < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
return True
async def _release_sse_slot(*, user_id: str) -> None:
redis = await get_or_init_redis_client()
key = f"inbox:sse-active:{user_id}"
count = await redis.decr(key)
if count <= 0:
await redis.delete(key)
else:
ttl = await redis.ttl(key)
if ttl < 0:
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
@router.get("", response_model=list[InboxMessageResponse])
@@ -26,3 +65,64 @@ async def mark_as_read(
service: Annotated[InboxMessageService, Depends(get_inbox_message_service)],
) -> InboxMessageResponse:
return await service.mark_as_read(message_id)
@router.get("/stream")
async def stream_inbox_events(
request: Request,
service: Annotated[InboxMessageService, Depends(get_inbox_message_service)],
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
idle_limit: int = Query(default=300, ge=1, le=3600),
) -> StreamingResponse:
if last_event_id is not None and (
len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
):
raise ApiProblemError(
status_code=422,
detail=problem_payload(
code="INBOX_INVALID_LAST_EVENT_ID",
detail="Invalid Last-Event-ID",
),
)
user_id = str(service.require_user_id())
slot_acquired = await _acquire_sse_slot(user_id=user_id)
if not slot_acquired:
raise ApiProblemError(
status_code=429,
detail=problem_payload(
code="INBOX_SSE_CONNECTION_LIMIT",
detail="Too many SSE connections",
),
)
async def _event_iter() -> AsyncIterator[str]:
cursor = last_event_id
idle_polls = 0
try:
while not await request.is_disconnected() and idle_polls < idle_limit:
rows = await service.stream_events(last_event_id=cursor)
if not rows:
idle_polls += 1
yield ": keep-alive\n\n"
await asyncio.sleep(0.2)
continue
idle_polls = 0
for row in rows:
stream_id = row.get("id")
event = row.get("event")
if not isinstance(stream_id, str) or not isinstance(event, dict):
continue
cursor = stream_id
event_type = event.get("event_type")
if not isinstance(event_type, str) or not event_type:
event_type = "INBOX_MESSAGE"
yield to_inbox_sse_event(stream_id, event_type, event)
finally:
await _release_sse_slot(user_id=user_id)
response = StreamingResponse(_event_iter(), media_type="text/event-stream")
response.headers["Cache-Control"] = "no-cache"
response.headers["Connection"] = "keep-alive"
return response
+46 -1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from uuid import UUID
import json
@@ -10,6 +11,11 @@ from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.http.errors import ApiProblemError
from v1.inbox_messages.realtime import (
publish_inbox_message_read_changed,
read_inbox_events,
snapshot_from_inbox_message,
)
from core.logging import get_logger
from models.inbox_messages import InboxMessage
from v1.inbox_messages.repository import InboxMessageRepository
@@ -71,6 +77,8 @@ class InboxMessageService(BaseService):
code="INBOX_MESSAGE_NOT_FOUND",
detail="Inbox message not found",
)
event_snapshot = snapshot_from_inbox_message(updated)
response = self._to_response(updated)
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
@@ -85,7 +93,44 @@ class InboxMessageService(BaseService):
detail="Inbox message store unavailable",
)
return self._to_response(updated)
try:
await publish_inbox_message_read_changed(event_snapshot)
except Exception as exc: # noqa: BLE001
logger.exception(
"Failed to publish inbox read-changed event",
message_id=str(event_snapshot.message_id),
user_id=str(event_snapshot.recipient_id),
)
raise _inbox_error(
status_code=503,
code="INBOX_EVENT_STREAM_UNAVAILABLE",
detail="Inbox event stream unavailable",
) from exc
return response
async def stream_events(
self,
*,
last_event_id: str | None,
) -> list[dict[str, Any]]:
user_id = self.require_user_id()
try:
return await read_inbox_events(
recipient_id=user_id,
last_event_id=last_event_id,
)
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to read inbox event stream",
user_id=str(user_id),
reason=str(exc),
)
raise _inbox_error(
status_code=503,
code="INBOX_EVENT_STREAM_UNAVAILABLE",
detail="Inbox event stream unavailable",
) from exc
def _to_response(self, message: InboxMessage) -> InboxMessageResponse:
status_value = (
+1 -65
View File
@@ -4,7 +4,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Protocol, Sequence
from uuid import UUID
from sqlalchemy import func, or_, select, update, delete
from sqlalchemy import or_, select, update, delete
from sqlalchemy.exc import SQLAlchemyError
from core.db.base_repository import BaseRepository
@@ -28,14 +28,6 @@ class ScheduleItemRepository(Protocol):
async def list_by_date_range(
self, owner_id: UUID, start_at: datetime, end_at: datetime
) -> list[ScheduleItem]: ...
async def list_paginated(
self,
owner_id: UUID,
*,
page: int,
page_size: int,
query: str | None = None,
) -> tuple[list[ScheduleItem], int]: ...
async def create_subscription(self, data: dict) -> ScheduleSubscription: ...
async def get_subscriptions_by_item_id(
self, item_id: UUID
@@ -154,62 +146,6 @@ class SQLAlchemyScheduleItemRepository(BaseRepository[ScheduleItem]):
logger.exception("Schedule item list failed", owner_id=str(owner_id))
raise
async def list_paginated(
self,
owner_id: UUID,
*,
page: int,
page_size: int,
query: str | None = None,
) -> tuple[list[ScheduleItem], int]:
offset = (page - 1) * page_size
normalized_query = (query or "").strip()
has_query = bool(normalized_query)
query_like = f"%{normalized_query}%"
try:
count_stmt = (
select(func.count())
.select_from(ScheduleItem)
.where(ScheduleItem.owner_id == owner_id)
.where(ScheduleItem.deleted_at.is_(None))
)
if has_query:
count_stmt = count_stmt.where(
or_(
ScheduleItem.title.ilike(query_like),
ScheduleItem.description.ilike(query_like),
)
)
count_result = await self._session.execute(count_stmt)
total = int(count_result.scalar_one() or 0)
items_stmt = (
select(ScheduleItem)
.where(ScheduleItem.owner_id == owner_id)
.where(ScheduleItem.deleted_at.is_(None))
.order_by(ScheduleItem.start_at.asc(), ScheduleItem.id.asc())
.offset(offset)
.limit(page_size)
)
if has_query:
items_stmt = items_stmt.where(
or_(
ScheduleItem.title.ilike(query_like),
ScheduleItem.description.ilike(query_like),
)
)
items_result = await self._session.execute(items_stmt)
items = list(items_result.scalars().all())
return items, total
except SQLAlchemyError:
logger.exception(
"Schedule item paginated list failed",
owner_id=str(owner_id),
page=page,
page_size=page_size,
)
raise
async def create_subscription(self, data: dict) -> ScheduleSubscription:
sub = ScheduleSubscription(**data)
self._session.add(sub)
+15 -8
View File
@@ -21,6 +21,7 @@ from schemas.domain.schedule import (
ScheduleItemSourceType,
ScheduleItemStatus,
)
from schemas.enums import SubscriptionPermission
__all__ = [
"AttachmentType",
@@ -41,9 +42,20 @@ __all__ = [
"ScheduleItemShareRequest",
"ScheduleItemShareResponse",
"SubscriberInfo",
"parse_permission",
]
def parse_permission(permission_int: int) -> dict[str, bool]:
return {
"can_view": bool(permission_int & SubscriptionPermission.VIEW.value),
"can_invite": bool(permission_int & SubscriptionPermission.INVITE.value),
"can_edit": bool(permission_int & SubscriptionPermission.EDIT.value),
"can_delete": bool(permission_int & SubscriptionPermission.DELETE.value),
"is_owner": permission_int == SubscriptionPermission.OWNER.value,
}
class ScheduleItemCreateRequest(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@@ -160,11 +172,6 @@ class ScheduleItemListRequest(BaseModel):
return value
_PERMISSION_VIEW = 1
_PERMISSION_INVITE = 2
_PERMISSION_EDIT = 4
class ScheduleItemShareRequest(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@@ -180,11 +187,11 @@ class ScheduleItemShareRequest(BaseModel):
def _permission_value(self) -> int:
value = 0
if self.permission_view:
value |= _PERMISSION_VIEW
value |= SubscriptionPermission.VIEW.value
if self.permission_edit:
value |= _PERMISSION_EDIT
value |= SubscriptionPermission.EDIT.value
if self.permission_invite:
value |= _PERMISSION_INVITE
value |= SubscriptionPermission.INVITE.value
return value
+244 -66
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Protocol, Literal
from typing import TYPE_CHECKING, Any, Protocol, Literal
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
@@ -9,6 +9,12 @@ from sqlalchemy.exc import SQLAlchemyError
from core.auth.models import CurrentUser
from core.db.base_service import BaseService
from core.http.errors import ApiProblemError, problem_payload
from v1.inbox_messages.realtime import (
InboxMessageEventSnapshot,
publish_inbox_message_created,
publish_inbox_message_status_changed,
snapshot_from_inbox_message,
)
from core.logging import get_logger
from models.inbox_messages import InboxMessage
from models.profile import Profile
@@ -196,7 +202,14 @@ class ScheduleItemService(BaseService):
subscriber_ids
)
except SQLAlchemyError:
logger.exception("Failed to get subscriber profiles")
logger.exception("Failed to fetch subscriber profiles")
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="SCHEDULE_ITEM_STORE_UNAVAILABLE",
detail="Schedule item store unavailable",
),
)
resolved_contacts = await resolve_contacts_by_user_ids(
user_ids=subscriber_ids,
profiles_by_id=profiles,
@@ -302,10 +315,30 @@ class ScheduleItemService(BaseService):
if not update_data:
return self._to_response(existing, is_owner=is_owner)
before_state = self._capture_calendar_state(existing)
item = await self._repository.update_item(item_id, update_data)
await self._notify_subscribers(item_id, existing.title, "updated")
if item is None:
raise ApiProblemError(
status_code=404,
detail=problem_payload(
code="SCHEDULE_ITEM_NOT_FOUND",
detail="Schedule item not found",
),
)
changes = self._build_calendar_changes(before_state, item, update_data)
created_messages = await self._notify_subscribers(
item_id,
item.title,
"updated",
changes=changes,
)
if created_messages:
await self._session.flush()
created_snapshots = [
snapshot_from_inbox_message(message) for message in created_messages
]
await self._session.commit()
await self._publish_created_events(created_snapshots)
except SQLAlchemyError:
await self._session.rollback()
logger.exception("Failed to update schedule item", item_id=str(item_id))
@@ -317,15 +350,6 @@ class ScheduleItemService(BaseService):
),
)
if item is None:
raise ApiProblemError(
status_code=404,
detail=problem_payload(
code="SCHEDULE_ITEM_NOT_FOUND",
detail="Schedule item not found",
),
)
return self._to_response(item, is_owner=is_owner)
async def delete(self, item_id: UUID) -> None:
@@ -359,9 +383,20 @@ class ScheduleItemService(BaseService):
title = existing.title
await self._repository.delete_subscriptions_by_item_id(item_id)
await self._notify_subscribers(item_id, title, "deleted")
created_messages = await self._notify_subscribers(
item_id,
title,
"deleted",
changes=[],
)
await self._repository.delete_item(item_id)
if created_messages:
await self._session.flush()
created_snapshots = [
snapshot_from_inbox_message(message) for message in created_messages
]
await self._session.commit()
await self._publish_created_events(created_snapshots)
except SQLAlchemyError:
await self._session.rollback()
logger.exception("Failed to delete schedule item", item_id=str(item_id))
@@ -427,54 +462,6 @@ class ScheduleItemService(BaseService):
),
)
async def list_paginated(
self,
*,
page: int,
page_size: int,
query: str | None = None,
) -> tuple[list[ScheduleItemResponse], int]:
user_id = self.require_user_id()
if page < 1:
raise ApiProblemError(
status_code=400,
detail=problem_payload(
code="SCHEDULE_ITEM_PAGE_INVALID",
detail="page must be >= 1",
params={"min": 1, "page": page},
),
)
if page_size < 1 or page_size > 100:
raise ApiProblemError(
status_code=400,
detail=problem_payload(
code="SCHEDULE_ITEM_PAGE_SIZE_INVALID",
detail="page_size must be 1..100",
params={"min": 1, "max": 100, "page_size": page_size},
),
)
try:
items, total = await self._repository.list_paginated(
user_id,
page=page,
page_size=page_size,
query=query,
)
except SQLAlchemyError:
logger.exception(
"Failed to list schedule items with pagination",
page=page,
page_size=page_size,
)
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="SCHEDULE_ITEM_STORE_UNAVAILABLE",
detail="Schedule item store unavailable",
),
)
return [self._to_response(item) for item in items], total
async def share(
self, item_id: UUID, request: ScheduleItemShareRequest
) -> ScheduleItemShareResponse:
@@ -518,6 +505,10 @@ class ScheduleItemService(BaseService):
),
)
actor_username = await self._resolve_actor_username(user_id)
actor = await self._auth_gateway.get_user_by_id(str(user_id))
actor_phone = actor.phone
target_user = await self._auth_gateway.get_user_by_phone(request.phone)
recipient_id = UUID(target_user.id)
@@ -563,6 +554,8 @@ class ScheduleItemService(BaseService):
existing_msg = await self._inbox_repository.get_calendar_invite(
item.id, recipient_id
)
event_target_message: InboxMessage | None = None
event_is_created = False
if existing_msg:
if existing_msg.status == InboxMessageStatus.ACCEPTED:
raise ApiProblemError(
@@ -584,9 +577,25 @@ class ScheduleItemService(BaseService):
existing_msg.status = InboxMessageStatus.PENDING
existing_msg.content = {
"type": "invite",
"schema_version": 2,
"item": {
"id": str(item.id),
"title": item.title,
"description": item.description,
"start_at": item.start_at.isoformat(),
"end_at": item.end_at.isoformat() if item.end_at else None,
"timezone": item.timezone,
},
"actor": {
"user_id": str(user_id),
"username": actor_username,
"phone": actor_phone,
},
"summary": f"{item.title} 邀请您加入日历",
"permission": request_permission,
"action": "pending",
}
event_target_message = existing_msg
else:
message = InboxMessage(
recipient_id=recipient_id,
@@ -595,14 +604,44 @@ class ScheduleItemService(BaseService):
schedule_item_id=item.id,
content={
"type": "invite",
"schema_version": 2,
"item": {
"id": str(item.id),
"title": item.title,
"description": item.description,
"start_at": item.start_at.isoformat(),
"end_at": item.end_at.isoformat() if item.end_at else None,
"timezone": item.timezone,
},
"actor": {
"user_id": str(user_id),
"username": actor_username,
"phone": actor_phone,
},
"summary": f"{item.title} 邀请您加入日历",
"permission": request_permission,
"action": "pending",
},
created_by=user_id,
)
self._session.add(message)
event_target_message = message
event_is_created = True
if event_target_message is not None and event_is_created:
await self._session.flush()
event_snapshot = (
snapshot_from_inbox_message(event_target_message)
if event_target_message is not None
else None
)
await self._session.commit()
if event_target_message is not None:
assert event_snapshot is not None
if event_is_created:
await self._publish_created_events([event_snapshot])
else:
await self._publish_status_events([event_snapshot])
except ApiProblemError:
raise
except SQLAlchemyError:
@@ -703,7 +742,9 @@ class ScheduleItemService(BaseService):
)
inbox.status = InboxMessageStatus.ACCEPTED
event_snapshot = snapshot_from_inbox_message(inbox)
await self._session.commit()
await self._publish_status_events([event_snapshot])
return {"message": "Subscription accepted"}
except ApiProblemError:
@@ -742,7 +783,9 @@ class ScheduleItemService(BaseService):
)
inbox.status = InboxMessageStatus.REJECTED
event_snapshot = snapshot_from_inbox_message(inbox)
await self._session.commit()
await self._publish_status_events([event_snapshot])
return {"message": "Subscription rejected"}
except ApiProblemError:
@@ -763,18 +806,36 @@ class ScheduleItemService(BaseService):
item_id: UUID,
title: str,
action_type: Literal["updated", "deleted"],
):
*,
changes: list[dict[str, Any]],
) -> list[InboxMessage]:
user_id = self.require_user_id()
subscriptions = await self._repository.get_subscriptions_by_item_id(item_id)
if not subscriptions:
return []
actor_username = await self._resolve_actor_username(user_id)
created_messages: list[InboxMessage] = []
for sub in subscriptions:
if sub.subscriber_id == user_id:
continue
content = {
"type": action_type,
"title": title,
"schema_version": 2,
"item": {
"id": str(item_id),
"title": title,
},
"actor": {
"user_id": str(user_id),
"username": actor_username,
},
"summary": f"{actor_username} 更新了日历 {title}"
if action_type == "updated"
else f"{actor_username} 删除了日历 {title}",
"changes": changes,
"action": action_type,
}
@@ -787,9 +848,126 @@ class ScheduleItemService(BaseService):
created_by=user_id,
)
self._session.add(message)
created_messages.append(message)
return created_messages
if subscriptions:
await self._session.commit()
async def _resolve_actor_username(self, user_id: UUID) -> str:
if self._user_repository is None:
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="SCHEDULE_ITEM_ACTOR_LOOKUP_UNAVAILABLE",
detail="Actor lookup unavailable",
),
)
profile = await self._user_repository.get_by_user_id(user_id)
if profile is not None and profile.username:
return profile.username
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="SCHEDULE_ITEM_ACTOR_LOOKUP_UNAVAILABLE",
detail="Actor lookup unavailable",
),
)
def _capture_calendar_state(self, item: ScheduleItem) -> dict[str, Any]:
return {
"title": item.title,
"description": item.description,
"start_at": item.start_at,
"end_at": item.end_at,
"timezone": item.timezone,
"status": item.status,
}
def _build_calendar_changes(
self,
before: dict[str, Any],
after: ScheduleItem,
update_data: dict[str, Any],
) -> list[dict[str, Any]]:
label_map = {
"title": "标题",
"description": "描述",
"start_at": "开始时间",
"end_at": "结束时间",
"timezone": "时区",
"status": "状态",
}
changes: list[dict[str, Any]] = []
for field in label_map:
if field not in update_data:
continue
before_value = before.get(field)
after_value = getattr(after, field)
if before_value == after_value:
continue
change_type = "modified"
if before_value is None and after_value is not None:
change_type = "added"
elif before_value is not None and after_value is None:
change_type = "removed"
changes.append(
{
"field": field,
"label": label_map[field],
"before": before_value.isoformat()
if isinstance(before_value, datetime)
else before_value,
"after": after_value.isoformat()
if isinstance(after_value, datetime)
else after_value,
"display_before": str(before_value)
if before_value is not None
else None,
"display_after": str(after_value)
if after_value is not None
else None,
"change_type": change_type,
}
)
return changes
async def _publish_created_events(
self, messages: list[InboxMessageEventSnapshot]
) -> None:
for message in messages:
try:
await publish_inbox_message_created(message)
except Exception as exc: # noqa: BLE001
logger.exception(
"Failed to publish inbox created event",
message_id=str(message.message_id),
recipient_id=str(message.recipient_id),
)
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="INBOX_EVENT_STREAM_UNAVAILABLE",
detail="Inbox event stream unavailable",
),
) from exc
async def _publish_status_events(
self, messages: list[InboxMessageEventSnapshot]
) -> None:
for message in messages:
try:
await publish_inbox_message_status_changed(message)
except Exception as exc: # noqa: BLE001
logger.exception(
"Failed to publish inbox status event",
message_id=str(message.message_id),
recipient_id=str(message.recipient_id),
)
raise ApiProblemError(
status_code=503,
detail=problem_payload(
code="INBOX_EVENT_STREAM_UNAVAILABLE",
detail="Inbox event stream unavailable",
),
) from exc
def _to_utc(self, dt: datetime | None) -> datetime | None:
if dt is None:
+1 -1
View File
@@ -19,7 +19,7 @@ class UserSearchRequest(BaseModel):
class UserUpdateRequest(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
username: str | None = Field(default=None, min_length=3, max_length=30)
username: str | None = Field(default=None, max_length=30)
avatar_url: str | None = Field(default=None)
bio: str | None = Field(default=None, max_length=200)
+118 -81
View File
@@ -12,7 +12,7 @@ from core.agentscope.caches.user_context_cache import (
from core.auth.models import CurrentUser
from core.config.settings import config
from core.db.base_service import BaseService
from core.http.errors import ApiProblemError
from core.http.errors import ApiProblemError, problem_payload
from core.logging import get_logger
from schemas.shared.user import UserContext, parse_profile_settings
from services.base.supabase import supabase_service
@@ -23,27 +23,13 @@ if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from schemas.shared.user import UserContext
from v1.auth.schemas import UserByIdResponse
logger = get_logger("v1.users.service")
_PHONE_QUERY_PATTERN = re.compile(r"^[+()\-\s\d]{4,32}$")
def _user_error(
*,
status_code: int,
code: str,
detail: str,
params: dict[str, object] | None = None,
) -> ApiProblemError:
return ApiProblemError(
status_code=status_code,
code=code,
detail=detail,
params=params,
)
def _mime_to_suffix(mime_type: str) -> str:
"""Convert MIME type to file suffix."""
mapping = {
@@ -59,11 +45,7 @@ class AuthLookupGateway(Protocol):
self, query: str, limit: int = 20
) -> list[str]: ...
class AuthByPhoneGateway(Protocol):
async def search_user_ids_by_phone(
self, query: str, limit: int = 20
) -> list[str]: ...
async def get_user_by_id(self, user_id: str) -> "UserByIdResponse": ...
class UserContextInvalidator(Protocol):
@@ -71,7 +53,7 @@ class UserContextInvalidator(Protocol):
class AuthLookupAdapter:
def __init__(self, gateway: AuthByPhoneGateway) -> None:
def __init__(self, gateway: AuthLookupGateway) -> None:
self._gateway = gateway
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
@@ -80,6 +62,12 @@ class AuthLookupAdapter:
except ApiProblemError:
return []
async def get_user_by_id(self, user_id: str) -> "UserByIdResponse | None":
try:
return await self._gateway.get_user_by_id(user_id)
except ApiProblemError:
return None
class UserService(BaseService):
"""User service handling business logic and transactions.
@@ -117,17 +105,21 @@ class UserService(BaseService):
try:
user = await self._repository.get_by_user_id(user_id)
except SQLAlchemyError:
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
detail=problem_payload(
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
),
)
if user is None:
raise _user_error(
raise ApiProblemError(
status_code=404,
code="USER_NOT_FOUND",
detail="User not found",
detail=problem_payload(
code="USER_NOT_FOUND",
detail="User not found",
),
)
phone = self._current_user.phone if self._current_user else None
return UserContext(
@@ -145,22 +137,38 @@ class UserService(BaseService):
try:
profile = await self._repository.get_by_user_id(user_id)
except SQLAlchemyError:
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
detail=problem_payload(
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
),
)
if profile is None:
raise _user_error(
raise ApiProblemError(
status_code=404,
code="USER_NOT_FOUND",
detail="User not found",
detail=problem_payload(
code="USER_NOT_FOUND",
detail="User not found",
),
)
phone: str | None = None
if self._auth_gateway is not None:
try:
auth_user = await self._auth_gateway.get_user_by_id(str(user_id))
phone = auth_user.phone
except Exception:
logger.warning(
"Failed to resolve auth phone",
user_id=str(user_id),
)
return UserContext(
id=str(profile.id),
username=profile.username,
avatar_url=profile.avatar_url,
phone=phone,
bio=profile.bio,
)
async def update_me(self, update: UserUpdateRequest) -> UserContext:
@@ -176,10 +184,12 @@ class UserService(BaseService):
}
if not update_data:
raise _user_error(
raise ApiProblemError(
status_code=400,
code="USER_UPDATE_FIELDS_EMPTY",
detail="No fields to update",
detail=problem_payload(
code="USER_UPDATE_FIELDS_EMPTY",
detail="No fields to update",
),
)
try:
@@ -187,17 +197,21 @@ class UserService(BaseService):
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
detail=problem_payload(
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
),
)
if user is None:
raise _user_error(
raise ApiProblemError(
status_code=404,
code="USER_NOT_FOUND",
detail="User not found",
detail=problem_payload(
code="USER_NOT_FOUND",
detail="User not found",
),
)
try:
@@ -229,38 +243,46 @@ class UserService(BaseService):
user_id = self.require_user_id()
if not isinstance(content_type, str):
raise _user_error(
raise ApiProblemError(
status_code=422,
code="USER_AVATAR_UNSUPPORTED_TYPE",
detail="Unsupported image type",
detail=problem_payload(
code="USER_AVATAR_UNSUPPORTED_TYPE",
detail="Unsupported image type",
),
)
mime_type = content_type.lower()
allowed_types = {"image/jpeg", "image/png", "image/webp"}
if mime_type not in allowed_types:
raise _user_error(
raise ApiProblemError(
status_code=422,
code="USER_AVATAR_UNSUPPORTED_TYPE",
detail="Unsupported image type. Allowed: JPEG, PNG, WebP",
params={"allowed": ["image/jpeg", "image/png", "image/webp"]},
detail=problem_payload(
code="USER_AVATAR_UNSUPPORTED_TYPE",
detail="Unsupported image type. Allowed: JPEG, PNG, WebP",
params={"allowed": ["image/jpeg", "image/png", "image/webp"]},
),
)
max_size_bytes = config.storage.avatar.max_size_mb * 1024 * 1024
if len(payload) > max_size_bytes:
raise _user_error(
raise ApiProblemError(
status_code=413,
code="USER_AVATAR_TOO_LARGE",
detail=(
f"Image too large. Maximum size: {config.storage.avatar.max_size_mb}MB"
detail=problem_payload(
code="USER_AVATAR_TOO_LARGE",
detail=(
f"Image too large. Maximum size: {config.storage.avatar.max_size_mb}MB"
),
params={"max_size_mb": config.storage.avatar.max_size_mb},
),
params={"max_size_mb": config.storage.avatar.max_size_mb},
)
if not payload:
raise _user_error(
raise ApiProblemError(
status_code=422,
code="USER_AVATAR_EMPTY",
detail="Empty image",
detail=problem_payload(
code="USER_AVATAR_EMPTY",
detail="Empty image",
),
)
suffix = _mime_to_suffix(mime_type)
@@ -284,13 +306,16 @@ class UserService(BaseService):
"user_id": str(user_id),
},
)
raise _user_error(
raise ApiProblemError(
status_code=502,
code="USER_AVATAR_UPLOAD_FAILED",
detail="Failed to upload avatar",
detail=problem_payload(
code="USER_AVATAR_UPLOAD_FAILED",
detail="Failed to upload avatar",
),
)
public_url = f"{config.supabase.public_url}/storage/v1/object/public/{bucket_name}/{stored_path}"
base_url = str(config.supabase.public_url).rstrip("/")
public_url = f"{base_url}/storage/v1/object/public/{bucket_name}/{stored_path}"
update_data: dict[str, str | None] = {"avatar_url": public_url}
try:
@@ -298,17 +323,21 @@ class UserService(BaseService):
await self._session.commit()
except SQLAlchemyError:
await self._session.rollback()
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
detail=problem_payload(
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
),
)
if user is None:
raise _user_error(
raise ApiProblemError(
status_code=404,
code="USER_NOT_FOUND",
detail="User not found",
detail=problem_payload(
code="USER_NOT_FOUND",
detail="User not found",
),
)
try:
@@ -326,17 +355,19 @@ class UserService(BaseService):
try:
user = await self._repository.get_by_username(username)
except SQLAlchemyError:
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
)
if user is None:
raise _user_error(
raise ApiProblemError(
status_code=404,
code="USER_NOT_FOUND",
detail="User not found",
detail=problem_payload(
code="USER_NOT_FOUND",
detail="User not found",
),
)
return UserContext(
id=str(user.id),
@@ -365,10 +396,12 @@ class UserService(BaseService):
async def _search_by_phone(self, phone: str) -> list[UserContext]:
if self._auth_gateway is None:
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_AUTH_LOOKUP_UNAVAILABLE",
detail="Auth lookup unavailable",
detail=problem_payload(
code="USER_AUTH_LOOKUP_UNAVAILABLE",
detail="Auth lookup unavailable",
),
)
user_id_values = await self._auth_gateway.search_user_ids_by_phone(
@@ -389,10 +422,12 @@ class UserService(BaseService):
try:
users_by_id = await self._repository.get_by_user_ids(user_ids)
except SQLAlchemyError:
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
detail=problem_payload(
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
),
)
results: list[UserContext] = []
@@ -415,10 +450,12 @@ class UserService(BaseService):
try:
users = await self._repository.search_users(query, limit=20)
except SQLAlchemyError:
raise _user_error(
raise ApiProblemError(
status_code=503,
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
detail=problem_payload(
code="USER_STORE_UNAVAILABLE",
detail="User store unavailable",
),
)
return [
@@ -25,6 +25,13 @@ class FakeInboxMessageService:
) -> None:
self._messages = messages
self._read_message = read_message
self._stream_rows: list[dict[str, object]] = []
def set_stream_rows(self, rows: list[dict[str, object]]) -> None:
self._stream_rows = rows
def require_user_id(self) -> UUID:
return self._read_message.recipient_id
async def list_messages(
self, is_read: bool | None = None
@@ -38,6 +45,16 @@ class FakeInboxMessageService:
raise HTTPException(status_code=404, detail="Inbox message not found")
return self._read_message
async def stream_events(
self,
*,
last_event_id: str | None,
) -> list[dict[str, object]]:
del last_event_id
rows = self._stream_rows
self._stream_rows = []
return rows
def _override_inbox_message_service(
service: FakeInboxMessageService,
@@ -58,7 +75,7 @@ def _build_message(
sender_id=uuid4(),
message_type=InboxMessageType.CALENDAR,
schedule_item_id=uuid4(),
content='{"permission": 1}',
content={"permission": 1},
is_read=False,
status=status,
created_at=datetime(2026, 2, 28, 9, 0, 0, tzinfo=timezone.utc),
@@ -108,3 +125,62 @@ def test_mark_as_read_returns_200() -> None:
assert body["is_read"] is True
finally:
app.dependency_overrides = {}
def test_stream_inbox_events_returns_sse_payload() -> None:
read_message = _build_message(uuid4(), InboxMessageStatus.PENDING)
service = FakeInboxMessageService(
messages=[read_message], read_message=read_message
)
service.set_stream_rows(
[
{
"id": "1743313300000-0",
"event": {
"event_id": str(uuid4()),
"occurred_at": "2026-03-30T07:00:00+00:00",
"user_id": str(read_message.recipient_id),
"message_id": str(read_message.id),
"event_type": "INBOX_MESSAGE_CREATED",
"op": "created",
"version": 1743313300000,
"data": {"message": {"id": str(read_message.id)}},
},
}
]
)
app.dependency_overrides[get_inbox_message_service] = (
_override_inbox_message_service(service)
)
client = TestClient(app)
try:
response = client.get("/api/v1/inbox/messages/stream?idle_limit=1")
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
payload = response.text
assert "event: INBOX_MESSAGE_CREATED" in payload
assert '"op":"created"' in payload
finally:
app.dependency_overrides = {}
def test_stream_inbox_events_rejects_invalid_last_event_id() -> None:
read_message = _build_message(uuid4(), InboxMessageStatus.PENDING)
service = FakeInboxMessageService(
messages=[read_message], read_message=read_message
)
app.dependency_overrides[get_inbox_message_service] = (
_override_inbox_message_service(service)
)
client = TestClient(app)
try:
response = client.get(
"/api/v1/inbox/messages/stream",
headers={"Last-Event-ID": "not-a-stream-id"},
)
assert response.status_code == 422
body = response.json()
assert body.get("code") == "INBOX_INVALID_LAST_EVENT_ID"
finally:
app.dependency_overrides = {}
@@ -4,7 +4,7 @@ import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import Any
from typing import Any, cast
from uuid import UUID, uuid4
import pytest
@@ -27,6 +27,7 @@ class _FakeService:
created_request: Any = None
created_id: str = field(default_factory=lambda: str(uuid4()))
list_calls: list[dict[str, Any]] = field(default_factory=list)
range_calls: list[dict[str, Any]] = field(default_factory=list)
deleted_ids: list[str] = field(default_factory=list)
async def list_paginated(
@@ -47,6 +48,29 @@ class _FakeService:
)
return [item], 1
async def list_by_date_range(self, request: Any):
self.range_calls.append(
{
"start_at": request.start_at,
"end_at": request.end_at,
}
)
return [
SimpleNamespace(
id=UUID(self.created_id),
owner_id=uuid4(),
title="会议",
description="今天下午五点的会议",
start_at=datetime(2026, 3, 17, 9, 0, tzinfo=timezone.utc),
end_at=datetime(2026, 3, 17, 9, 30, tzinfo=timezone.utc),
timezone="Asia/Shanghai",
status="active",
source_type="manual",
metadata=None,
subscribers=[],
)
]
async def create_agent_generated(self, request):
self.created_request = request
return SimpleNamespace(
@@ -235,22 +259,48 @@ async def test_calendar_read_returns_structured_result_with_ids(
)
result = await calendar_module.calendar_read(
query="会议",
page=1,
page_size=20,
start_at="2026-03-17T00:00:00+08:00",
end_at="2026-03-18T00:00:00+08:00",
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(result)
result_data = json.loads(payload["result"])
assert payload["status"] == "success"
assert result_data["total"] == 1
assert result_data["items"][0]["id"] == fake_service.created_id
assert result_data["items"][0]["timezone"] == "Asia/Shanghai"
assert result_data["items"][0]["description"] == "今天下午五点的会议"
assert result_data["items"][0]["status"] == "active"
assert fake_service.range_calls == [
{
"start_at": datetime(2026, 3, 16, 16, 0, tzinfo=timezone.utc),
"end_at": datetime(2026, 3, 17, 16, 0, tzinfo=timezone.utc),
}
]
@pytest.mark.asyncio
async def test_calendar_read_rejects_naive_datetime_string(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeService()
monkeypatch.setattr(
calendar_module, "create_schedule_service", lambda *_: fake_service
)
result = await calendar_module.calendar_read(
start_at="2026-03-17T00:00:00",
end_at="2026-03-18T00:00:00+08:00",
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(result)
assert payload["status"] == "success"
assert payload["result"].startswith("status=success")
assert "total=1" in payload["result"]
assert "timezone=Asia/Shanghai" in payload["result"]
assert "description=今天下午五点的会议" in payload["result"]
assert "status=active" in payload["result"]
assert fake_service.created_id in payload["result"]
assert fake_service.list_calls == [{"page": 1, "page_size": 20, "query": "会议"}]
assert payload["status"] == "failure"
assert payload["error"]["code"] == "INVALID_ARGUMENT"
assert "时区" in payload["error"]["message"]
@pytest.mark.asyncio
@@ -312,3 +362,39 @@ async def test_calendar_share_rejects_invalid_phone(
assert payload["status"] == "failure"
assert payload["error"]["code"] == "INVALID_ARGUMENT"
@pytest.mark.asyncio
async def test_calendar_share_accepts_json_invitee_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_service = _FakeService()
monkeypatch.setattr(
calendar_module, "create_schedule_service", lambda *_: fake_service
)
event_id = str(uuid4())
result = await calendar_module.calendar_share(
event_id=event_id,
invitees=cast(
Any,
[
{
"phone": "8613900001234",
"permissionView": True,
"permissionEdit": False,
"permissionInvite": False,
}
],
),
session=SimpleNamespace(),
owner_id=uuid4(),
)
payload = _decode_tool_response(result)
assert payload["status"] == "success"
assert payload["result"].startswith("status=success success=1 failed=0")
assert len(fake_service.share_calls) == 1
share_call = fake_service.share_calls[0]
assert share_call["item_id"] == event_id
assert share_call["request"].phone == "+8613900001234"
@@ -0,0 +1,96 @@
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from core.runtime import cli
class _FakeScheduler:
def __init__(self) -> None:
self.started = False
self.shutdown_called = False
self.jobs: list[dict[str, Any]] = []
def add_job(self, func: Any, **kwargs: Any) -> None:
self.jobs.append({"func": func, **kwargs})
def start(self) -> None:
self.started = True
def shutdown(self, *, wait: bool) -> None:
self.shutdown_called = True
self.shutdown_wait = wait
class _StopEvent:
async def wait(self) -> None:
raise asyncio.CancelledError
@pytest.mark.asyncio
async def test_run_automation_scheduler_forever_uses_async_scheduler(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_scheduler = _FakeScheduler()
dispatch_limits: list[int] = []
async def _fake_scan(*, limit: int) -> None:
dispatch_limits.append(limit)
monkeypatch.setattr(cli, "AsyncIOScheduler", lambda: fake_scheduler)
monkeypatch.setattr(cli, "run_automation_scheduler_scan", _fake_scan)
monkeypatch.setattr(cli.asyncio, "Event", lambda: _StopEvent())
settings = cli.config.automation_scheduler
old_enabled = settings.enabled
old_interval = settings.interval_seconds
old_limit = settings.batch_limit
settings.enabled = True
settings.interval_seconds = 9
settings.batch_limit = 7
try:
with pytest.raises(asyncio.CancelledError):
await cli.run_automation_scheduler_forever()
finally:
settings.enabled = old_enabled
settings.interval_seconds = old_interval
settings.batch_limit = old_limit
assert fake_scheduler.started is True
assert fake_scheduler.shutdown_called is True
assert len(fake_scheduler.jobs) == 1
assert fake_scheduler.jobs[0]["max_instances"] == 1
assert fake_scheduler.jobs[0]["coalesce"] is True
scan_job = fake_scheduler.jobs[0]["func"]
await scan_job()
assert dispatch_limits == [7]
@pytest.mark.asyncio
async def test_run_automation_scheduler_forever_disabled_noop(
monkeypatch: pytest.MonkeyPatch,
) -> None:
settings = cli.config.automation_scheduler
old_enabled = settings.enabled
settings.enabled = False
called = False
def _unexpected_scheduler() -> _FakeScheduler:
nonlocal called
called = True
return _FakeScheduler()
monkeypatch.setattr(cli, "AsyncIOScheduler", _unexpected_scheduler)
try:
await cli.run_automation_scheduler_forever()
finally:
settings.enabled = old_enabled
assert called is False
@@ -1,6 +1,6 @@
from __future__ import annotations
from schemas.agent.runtime_models import RouterAgentOutput
from schemas.agent.runtime_models import RouterAgentOutput, WorkerAgentOutputRich
def test_router_agent_output_coerces_key_entity_value_to_string() -> None:
@@ -32,3 +32,59 @@ def test_router_agent_output_coerces_key_entity_value_to_string() -> None:
model = RouterAgentOutput.model_validate(payload)
assert model.key_entities[0].value == "8"
def test_router_agent_output_coerces_constraint_value_to_string() -> None:
payload = {
"normalized_task_input": {
"user_text": "test",
"multimodal_summary": [],
"context_summary": "",
},
"key_entities": [],
"constraints": [
{
"key": "strict_mode",
"value": True,
"required": True,
}
],
"task_typing": {
"primary": "planning",
"secondary": [],
},
"execution_mode": "onestep",
"result_typing": {
"primary": "summary",
"secondary": [],
},
}
model = RouterAgentOutput.model_validate(payload)
assert model.constraints[0].value == "True"
def test_worker_agent_output_rich_accepts_list_item_status_object() -> None:
payload = {
"status": "success",
"answer": "done",
"result_type": "summary",
"ui_hints": {
"intent": "status",
"status": "info",
"title": "状态",
"listItems": [
{
"title": "任务A",
"status": {"type": "info", "value": "已归档"},
}
],
},
}
model = WorkerAgentOutputRich.model_validate(payload)
assert model.ui_hints is not None
assert model.ui_hints.list_items[0].status is not None
assert model.ui_hints.list_items[0].status.value == "info"
@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import datetime
from datetime import datetime, timezone
from typing import cast
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
@@ -64,10 +64,14 @@ class FakeFriendshipRepo:
inbox.id = uuid4()
inbox.recipient_id = recipient_id
inbox.sender_id = initiator_id
inbox.schedule_item_id = None
inbox.status = InboxMessageStatus.PENDING
inbox.message_type = InboxMessageType.FRIEND_REQUEST
inbox.friendship_id = friendship.id
inbox.content = {"type": "request", "message": content}
inbox.is_read = False
inbox.created_at = datetime.now(timezone.utc)
inbox.updated_at = datetime.now(timezone.utc)
self._inbox_messages.append(inbox)
return friendship, inbox
@@ -91,10 +95,14 @@ class FakeFriendshipRepo:
inbox.id = uuid4()
inbox.recipient_id = recipient_id
inbox.sender_id = initiator_id
inbox.schedule_item_id = None
inbox.status = InboxMessageStatus.PENDING
inbox.message_type = InboxMessageType.FRIEND_REQUEST
inbox.friendship_id = friendship.id
inbox.content = {"type": "request", "message": content}
inbox.is_read = False
inbox.created_at = datetime.now(timezone.utc)
inbox.updated_at = datetime.now(timezone.utc)
self._inbox_messages.append(inbox)
return friendship, inbox
@@ -0,0 +1,109 @@
from __future__ import annotations
from datetime import UTC, datetime
from uuid import uuid4
import pytest
from models.inbox_messages import InboxMessage
from schemas.enums import InboxMessageStatus, InboxMessageType
from v1.inbox_messages import realtime
class _FakeRedis:
def __init__(self) -> None:
self.last_stream: str | None = None
self.last_payload: str | None = None
self.last_block: int | None = None
async def xadd(self, stream: str, fields: dict[str, str]) -> str:
self.last_stream = stream
self.last_payload = fields.get("event")
return "1743313300000-0"
async def xread(self, _streams: dict[str, str], count: int, block: int):
del count
self.last_block = block
return [
(
"inbox:events:test",
[
(
"1743313300000-0",
{
"event": '{"event_id":"e1","event_type":"INBOX_MESSAGE_CREATED","op":"created"}',
},
)
],
)
]
@pytest.mark.asyncio
async def test_publish_inbox_message_created_writes_stream(monkeypatch) -> None:
fake_redis = _FakeRedis()
async def _fake_get_redis():
return fake_redis
monkeypatch.setattr(realtime, "get_or_init_redis_client", _fake_get_redis)
message = InboxMessage(
id=uuid4(),
recipient_id=uuid4(),
sender_id=uuid4(),
message_type=InboxMessageType.CALENDAR,
friendship_id=None,
schedule_item_id=uuid4(),
group_id=None,
content={"type": "invite"},
is_read=False,
status=InboxMessageStatus.PENDING,
created_by=uuid4(),
)
message.created_at = datetime(2026, 3, 30, 7, 0, tzinfo=UTC)
message.updated_at = datetime(2026, 3, 30, 7, 0, tzinfo=UTC)
stream_id = await realtime.publish_inbox_message_created(message)
assert stream_id == "1743313300000-0"
assert fake_redis.last_stream == f"inbox:events:{message.recipient_id}"
assert fake_redis.last_payload is not None
assert '"op":"created"' in fake_redis.last_payload
@pytest.mark.asyncio
async def test_read_inbox_events_decodes_rows(monkeypatch) -> None:
fake_redis = _FakeRedis()
async def _fake_get_redis():
return fake_redis
monkeypatch.setattr(realtime, "get_or_init_redis_client", _fake_get_redis)
rows = await realtime.read_inbox_events(
recipient_id=uuid4(),
last_event_id=None,
)
assert len(rows) == 1
assert rows[0]["id"] == "1743313300000-0"
assert rows[0]["event"]["event_type"] == "INBOX_MESSAGE_CREATED"
@pytest.mark.asyncio
async def test_read_inbox_events_handles_redis_timeout(monkeypatch) -> None:
class _TimeoutRedis(_FakeRedis):
async def xread(self, _streams: dict[str, str], count: int, block: int):
del _streams, count, block
raise TimeoutError("read timeout")
fake_redis = _TimeoutRedis()
async def _fake_get_redis():
return fake_redis
monkeypatch.setattr(realtime, "get_or_init_redis_client", _fake_get_redis)
rows = await realtime.read_inbox_events(recipient_id=uuid4(), last_event_id=None)
assert rows == []
@@ -59,6 +59,9 @@ class FakeRepo:
return self._item
return None
async def get_item(self, item_id: UUID) -> ScheduleItem | None:
return await self.get_by_id(item_id)
async def create(self, data: dict) -> ScheduleItem:
return _create_mock_schedule_item(
owner_id=data["owner_id"],
@@ -74,6 +77,23 @@ class FakeRepo:
self._item.title = data["title"]
return self._item
async def update_item(self, item_id: UUID, data: dict) -> ScheduleItem | None:
if self._item is None:
return None
if "title" in data:
self._item.title = data["title"]
if "description" in data:
self._item.description = data["description"]
if "start_at" in data:
self._item.start_at = data["start_at"]
if "end_at" in data:
self._item.end_at = data["end_at"]
if "timezone" in data:
self._item.timezone = data["timezone"]
if "extra_metadata" in data:
self._item.extra_metadata = data["extra_metadata"]
return self._item
async def delete_by_item_id(
self, item_id: UUID, owner_id: UUID
) -> ScheduleItem | None:
@@ -81,6 +101,9 @@ class FakeRepo:
return None
return self._item
async def delete_item(self, item_id: UUID) -> None:
del item_id
async def list_by_date_range(
self, owner_id: UUID, start_at: datetime, end_at: datetime
) -> list[ScheduleItem]:
@@ -327,12 +350,11 @@ async def test_update_maps_metadata_to_extra_metadata(
captured: dict | None = None
class CaptureRepo(FakeRepo):
async def update_by_item_id(
self, item_id: UUID, owner_id: UUID, data: dict
) -> ScheduleItem | None:
async def update_item(self, item_id: UUID, data: dict) -> ScheduleItem | None:
nonlocal captured
del item_id
captured = data
return await super().update_by_item_id(item_id, owner_id, data)
return await super().update_item(item.id, data)
service = ScheduleItemService(
repository=CaptureRepo(item),
@@ -370,12 +392,11 @@ async def test_update_maps_null_metadata_to_extra_metadata_null(
captured: dict | None = None
class CaptureRepo(FakeRepo):
async def update_by_item_id(
self, item_id: UUID, owner_id: UUID, data: dict
) -> ScheduleItem | None:
async def update_item(self, item_id: UUID, data: dict) -> ScheduleItem | None:
nonlocal captured
del item_id
captured = data
return await super().update_by_item_id(item_id, owner_id, data)
return await super().update_item(item.id, data)
service = ScheduleItemService(
repository=CaptureRepo(item),
@@ -157,6 +157,14 @@ class FriendshipRepoStub:
return friendship
class UserRepoStub:
async def get_by_user_id(self, user_id: UUID):
profile = MagicMock()
profile.id = user_id
profile.username = "owner"
return profile
@pytest.mark.asyncio
async def test_share_forbidden_when_not_owner() -> None:
owner_id = UUID("00000000-0000-0000-0000-000000000001")
@@ -172,6 +180,7 @@ async def test_share_forbidden_when_not_owner() -> None:
auth_gateway=cast(Any, AuthGatewayStub()),
inbox_repository=InboxRepoStub(),
friendship_repository=cast(Any, FriendshipRepoStub()),
user_repository=cast(Any, UserRepoStub()),
)
with pytest.raises(ApiProblemError) as exc_info:
@@ -204,6 +213,7 @@ async def test_share_success_creates_calendar_invitation_message() -> None:
auth_gateway=cast(Any, AuthGatewayStub()),
inbox_repository=InboxRepoStub(),
friendship_repository=cast(Any, FriendshipRepoStub()),
user_repository=cast(Any, UserRepoStub()),
)
result = await service.share(
@@ -223,7 +233,17 @@ async def test_share_success_creates_calendar_invitation_message() -> None:
assert message.sender_id == owner_id
assert message.schedule_item_id == item_id
assert message.message_type == InboxMessageType.CALENDAR
assert message.content == {"type": "invite", "permission": 5, "action": "pending"}
assert message.content is not None
assert message.content["type"] == "invite"
assert message.content["schema_version"] == 2
assert message.content["permission"] == 5
assert message.content["item"]["id"] == str(item_id)
assert message.content["item"]["title"] == "test"
assert message.content["item"]["start_at"] == "2026-02-28T16:00:00+00:00"
assert message.content["item"]["end_at"] is None
assert message.content["item"]["timezone"] == "UTC"
assert message.content["actor"]["username"] == "owner"
assert message.content["actor"]["phone"] == "+8613810000000"
session.commit.assert_awaited_once()
@@ -237,6 +257,7 @@ async def test_share_returns_not_found_when_item_missing() -> None:
auth_gateway=cast(Any, AuthGatewayStub()),
inbox_repository=InboxRepoStub(),
friendship_repository=cast(Any, FriendshipRepoStub()),
user_repository=cast(Any, UserRepoStub()),
)
with pytest.raises(ApiProblemError) as exc_info:
@@ -268,6 +289,7 @@ async def test_share_invalid_auth_user_id_returns_503() -> None:
auth_gateway=cast(Any, AuthGatewayInvalidIdStub()),
inbox_repository=InboxRepoStub(),
friendship_repository=cast(Any, FriendshipRepoStub()),
user_repository=cast(Any, UserRepoStub()),
)
with pytest.raises(ApiProblemError) as exc_info:
@@ -302,6 +324,7 @@ async def test_share_sqlalchemy_error_rolls_back() -> None:
auth_gateway=cast(Any, AuthGatewayStub()),
inbox_repository=InboxRepoStub(),
friendship_repository=cast(Any, FriendshipRepoStub()),
user_repository=cast(Any, UserRepoStub()),
)
with pytest.raises(ApiProblemError) as exc_info:
@@ -334,6 +357,7 @@ async def test_share_returns_forbidden_when_target_is_not_friend() -> None:
auth_gateway=cast(Any, AuthGatewayStub()),
inbox_repository=InboxRepoStub(),
friendship_repository=cast(Any, FriendshipRepoStub(accepted=False)),
user_repository=cast(Any, UserRepoStub()),
)
with pytest.raises(ApiProblemError) as exc_info: