feat: 重构 Reminder Notification 系统并更新应用包名
This commit is contained in:
@@ -346,7 +346,7 @@ class AgentScopeRunner:
|
||||
*messages_for_router,
|
||||
],
|
||||
output_model=RouterAgentOutput,
|
||||
retries=0,
|
||||
retries=3,
|
||||
)
|
||||
response_msg = Msg(
|
||||
name="router",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Annotated, Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
@@ -21,6 +22,7 @@ from core.agentscope.tools.tool_call_context import get_current_tool_call_id
|
||||
from schemas.agent.runtime_models import ErrorInfo, ToolAgentOutput, ToolStatus
|
||||
from v1.schedule_items.schemas import (
|
||||
ScheduleItemCreateRequest,
|
||||
ScheduleItemListRequest,
|
||||
ScheduleItemShareRequest,
|
||||
ScheduleItemStatus,
|
||||
ScheduleItemUpdateRequest,
|
||||
@@ -98,6 +100,13 @@ class CalendarWriteBatchArgs(BaseModel):
|
||||
operations: list[CalendarWriteOperation] = Field(min_length=1, max_length=20)
|
||||
|
||||
|
||||
class CalendarShareArgs(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
event_id: str
|
||||
invitees: list[CalendarShareInvitee] = Field(min_length=1)
|
||||
|
||||
|
||||
def _validate_runtime_context(
|
||||
*,
|
||||
tool_name: str,
|
||||
@@ -116,69 +125,60 @@ def _validate_runtime_context(
|
||||
return None
|
||||
|
||||
|
||||
def _format_event_brief(event_items: list[dict[str, Any]], limit: int = 3) -> str:
|
||||
briefs: list[str] = []
|
||||
for item in event_items[:limit]:
|
||||
event_id = str(item.get("id") or "")
|
||||
title = str(item.get("title") or "")
|
||||
start_at = str(item.get("startAt") or "")
|
||||
end_at = str(item.get("endAt") or "")
|
||||
timezone = str(item.get("timezone") or "")
|
||||
status = str(item.get("status") or "")
|
||||
description = str(item.get("description") or "")
|
||||
location = str(item.get("location") or "")
|
||||
reminder_minutes = item.get("reminderMinutes")
|
||||
color = str(item.get("color") or "")
|
||||
source_type = str(item.get("sourceType") or "")
|
||||
updated_at = str(item.get("updatedAt") or "")
|
||||
permission = item.get("permission")
|
||||
is_owner = item.get("isOwner")
|
||||
if not event_id:
|
||||
continue
|
||||
briefs.append(
|
||||
"{"
|
||||
f"id={event_id},title={title},startAt={start_at},endAt={end_at},"
|
||||
f"timezone={timezone},status={status},description={description},"
|
||||
f"location={location},reminderMinutes={reminder_minutes},color={color},"
|
||||
f"sourceType={source_type},updatedAt={updated_at},permission={permission},"
|
||||
f"isOwner={is_owner}"
|
||||
"}"
|
||||
)
|
||||
return ",".join(briefs)
|
||||
|
||||
|
||||
async def calendar_read(
|
||||
query: Annotated[
|
||||
str | None,
|
||||
Field(description="Optional keyword to filter calendar events."),
|
||||
] = None,
|
||||
page: Annotated[
|
||||
int,
|
||||
Field(description="Page number, starting from 1.", ge=1),
|
||||
] = 1,
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(description="Number of items per page (1-100).", ge=1, le=100),
|
||||
] = 20,
|
||||
start_at: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description="Start of date range in ISO8601 with timezone, e.g. 2026-03-30T00:00:00+08:00."
|
||||
),
|
||||
],
|
||||
end_at: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description="End of date range in ISO8601 with timezone, e.g. 2026-03-30T23:59:59+08:00."
|
||||
),
|
||||
],
|
||||
session: Any = None,
|
||||
owner_id: Any = None,
|
||||
) -> ToolResponse:
|
||||
"""Read calendar events with optional keyword filtering and pagination.
|
||||
"""Read calendar events within a date range.
|
||||
|
||||
Status semantics for returned events:
|
||||
- active: Event is actionable.
|
||||
- archived: Event is historical/expired and should not trigger reminders.
|
||||
Returns subscribed calendar events (owned or shared) with permission info.
|
||||
|
||||
Status: active=actionable, archived=past/expired.
|
||||
|
||||
Permission flags: is_owner, can_view, can_edit, can_invite, can_delete.
|
||||
|
||||
Args:
|
||||
query: Optional keyword used to filter events by text fields.
|
||||
page: Page number starting from 1.
|
||||
page_size: Number of items per page, between 1 and 100.
|
||||
start_at: Start of date range (required).
|
||||
end_at: End of date range (required).
|
||||
|
||||
Returns:
|
||||
ToolResponse with serialized ToolAgentOutput payload.
|
||||
ToolResponse with JSON result:
|
||||
{
|
||||
"total": int,
|
||||
"items": [{
|
||||
"id": "uuid",
|
||||
"owner_id": "uuid",
|
||||
"title": "string",
|
||||
"description": "string|null",
|
||||
"start_at": "ISO8601 datetime",
|
||||
"end_at": "ISO8601 datetime|null",
|
||||
"timezone": "IANA timezone",
|
||||
"status": "active|archived",
|
||||
"source_type": "manual|imported|agent_generated",
|
||||
"permission": {"can_view", "can_edit", "can_invite", "can_delete", "is_owner"},
|
||||
"is_owner": boolean,
|
||||
"metadata": {color, location, reminder_minutes}|null,
|
||||
"subscribers": [{user_id, username, phone, permission, status}],
|
||||
"created_at": "ISO8601 datetime",
|
||||
"updated_at": "ISO8601 datetime"
|
||||
}]
|
||||
}
|
||||
"""
|
||||
tool_name = "calendar_read"
|
||||
tool_call_args = {"query": query, "page": page, "page_size": page_size}
|
||||
tool_call_args: dict[str, Any] = {"start_at": start_at, "end_at": end_at}
|
||||
|
||||
runtime_error = _validate_runtime_context(
|
||||
tool_name=tool_name,
|
||||
tool_call_args=tool_call_args,
|
||||
@@ -189,30 +189,30 @@ async def calendar_read(
|
||||
return runtime_error
|
||||
|
||||
try:
|
||||
parsed_start = parse_iso_datetime(start_at)
|
||||
parsed_end = parse_iso_datetime(end_at)
|
||||
if parsed_start is None or parsed_end is None:
|
||||
raise ValueError("start_at 和 end_at 都是必填项")
|
||||
if parsed_start >= parsed_end:
|
||||
raise ValueError("start_at 必须早于 end_at")
|
||||
|
||||
service = create_schedule_service(
|
||||
cast(AsyncSession, session), cast(UUID, owner_id)
|
||||
)
|
||||
items, total = await service.list_paginated(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
query=query,
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size if total else 0
|
||||
request = ScheduleItemListRequest(start_at=parsed_start, end_at=parsed_end)
|
||||
items = await service.list_by_date_range(request)
|
||||
event_items = [schedule_event_to_dict(item) for item in items]
|
||||
event_brief = _format_event_brief(event_items)
|
||||
summary = (
|
||||
f"status=success total={total} total_pages={total_pages or 1} "
|
||||
f"returned={len(event_items)} has_next={str(page < (total_pages or 1)).lower()}"
|
||||
result = json.dumps(
|
||||
{"total": len(event_items), "items": event_items},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if event_brief:
|
||||
summary = f"{summary} items=[{event_brief}]"
|
||||
return dump_tool_output(
|
||||
ToolAgentOutput(
|
||||
tool_name=tool_name,
|
||||
tool_call_id=get_current_tool_call_id(tool_name=tool_name),
|
||||
tool_call_args=tool_call_args,
|
||||
status=ToolStatus.SUCCESS,
|
||||
result=summary,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
@@ -532,10 +532,25 @@ async def calendar_share(
|
||||
ToolResponse with serialized ToolAgentOutput payload.
|
||||
"""
|
||||
tool_name = "calendar_share"
|
||||
try:
|
||||
parsed_args = CalendarShareArgs.model_validate(
|
||||
{"event_id": event_id, "invitees": invitees}
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
code, message, retryable = map_calendar_exception(exc)
|
||||
return calendar_error_output(
|
||||
tool_name=tool_name,
|
||||
tool_call_args={"event_id": event_id, "invitees": invitees},
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
)
|
||||
|
||||
tool_call_args = {
|
||||
"event_id": event_id,
|
||||
"event_id": parsed_args.event_id,
|
||||
"invitees": [
|
||||
invitee.model_dump(mode="json", by_alias=True) for invitee in invitees
|
||||
invitee.model_dump(mode="json", by_alias=True)
|
||||
for invitee in parsed_args.invitees
|
||||
],
|
||||
}
|
||||
runtime_error = _validate_runtime_context(
|
||||
@@ -551,11 +566,11 @@ async def calendar_share(
|
||||
service = create_schedule_service(
|
||||
cast(AsyncSession, session), cast(UUID, owner_id)
|
||||
)
|
||||
target_uuid = UUID(event_id)
|
||||
target_uuid = UUID(parsed_args.event_id)
|
||||
|
||||
invited: list[str] = []
|
||||
result_items: list[dict[str, str]] = []
|
||||
for invitee in invitees:
|
||||
for invitee in parsed_args.invitees:
|
||||
raw_phone = invitee.phone.strip()
|
||||
normalized_phone = raw_phone
|
||||
for separator in (" ", "-", "(", ")"):
|
||||
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -12,8 +13,9 @@ from core.auth.models import CurrentUser
|
||||
from core.http.errors import ApiProblemError
|
||||
from v1.inbox_messages.repository import SQLAlchemyInboxMessageRepository
|
||||
from v1.schedule_items.repository import SQLAlchemyScheduleItemRepository
|
||||
from v1.schedule_items.schemas import ScheduleItemMetadata
|
||||
from v1.schedule_items.schemas import ScheduleItemMetadata, parse_permission
|
||||
from v1.schedule_items.service import ScheduleItemService
|
||||
from v1.users.repository import SQLAlchemyUserRepository
|
||||
|
||||
_HEX_COLOR_PATTERN = re.compile(r"^#[0-9A-Fa-f]{6}$")
|
||||
|
||||
@@ -39,31 +41,66 @@ def create_schedule_service(
|
||||
session=session,
|
||||
current_user=CurrentUser(id=owner_id),
|
||||
inbox_repository=SQLAlchemyInboxMessageRepository(session),
|
||||
user_repository=SQLAlchemyUserRepository(session),
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_event_timezone(dt: datetime, event_timezone: str) -> datetime:
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
tz = ZoneInfo(event_timezone) if event_timezone else ZoneInfo("UTC")
|
||||
return dt.astimezone(tz)
|
||||
|
||||
|
||||
def schedule_event_to_dict(event: object) -> dict[str, Any]:
|
||||
event_id = str(getattr(event, "id"))
|
||||
event_timezone = str(getattr(event, "timezone") or "UTC")
|
||||
start_at_utc = getattr(event, "start_at")
|
||||
end_at_utc = getattr(event, "end_at")
|
||||
permission_int = getattr(event, "permission", 1)
|
||||
is_owner = getattr(event, "is_owner", permission_int == 15)
|
||||
metadata = getattr(event, "metadata", None)
|
||||
subscribers = getattr(event, "subscribers", []) or []
|
||||
|
||||
def _serialize_dt(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
return _convert_to_event_timezone(dt, event_timezone).isoformat()
|
||||
|
||||
def _serialize_subscriber(sub: object) -> dict[str, Any]:
|
||||
return {
|
||||
"user_id": str(getattr(sub, "user_id", "")),
|
||||
"username": getattr(sub, "username", None),
|
||||
"avatar_url": getattr(sub, "avatar_url", None),
|
||||
"phone": getattr(sub, "phone", None),
|
||||
"permission": getattr(sub, "permission", 1),
|
||||
"status": str(getattr(sub, "status", "active")),
|
||||
"subscribed_at": _serialize_dt(getattr(sub, "subscribed_at", None)),
|
||||
}
|
||||
|
||||
status_value = getattr(event, "status", None)
|
||||
if hasattr(status_value, "value"):
|
||||
status_value = getattr(status_value, "value")
|
||||
location_value = getattr(metadata, "location", None)
|
||||
color_value = getattr(metadata, "color", None) or "#4F46E5"
|
||||
reminder_minutes_value = getattr(metadata, "reminder_minutes", None)
|
||||
if status_value is not None and hasattr(status_value, "value"):
|
||||
status_value = status_value.value
|
||||
|
||||
source_type_value = getattr(event, "source_type", None)
|
||||
if source_type_value is not None and hasattr(source_type_value, "value"):
|
||||
source_type_value = source_type_value.value
|
||||
|
||||
return {
|
||||
"id": event_id,
|
||||
"title": getattr(event, "title"),
|
||||
"description": getattr(event, "description"),
|
||||
"startAt": getattr(event, "start_at").isoformat(),
|
||||
"endAt": getattr(event, "end_at").isoformat()
|
||||
if getattr(event, "end_at") is not None
|
||||
else None,
|
||||
"timezone": getattr(event, "timezone"),
|
||||
"id": str(getattr(event, "id", "")),
|
||||
"owner_id": str(getattr(event, "owner_id", "")),
|
||||
"title": getattr(event, "title", ""),
|
||||
"description": getattr(event, "description", None),
|
||||
"start_at": _serialize_dt(start_at_utc),
|
||||
"end_at": _serialize_dt(end_at_utc),
|
||||
"timezone": event_timezone,
|
||||
"metadata": metadata.model_dump(mode="json") if metadata else None,
|
||||
"status": status_value,
|
||||
"location": location_value,
|
||||
"color": color_value,
|
||||
"reminderMinutes": reminder_minutes_value,
|
||||
"source_type": source_type_value,
|
||||
"created_at": _serialize_dt(getattr(event, "created_at", None)),
|
||||
"updated_at": _serialize_dt(getattr(event, "updated_at", None)),
|
||||
"permission": parse_permission(permission_int),
|
||||
"is_owner": is_owner,
|
||||
"subscribers": [_serialize_subscriber(sub) for sub in subscribers],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +1,34 @@
|
||||
input_template: |
|
||||
你正在执行一次“自动化记忆回顾与整理”任务。
|
||||
你正在执行一次"自动化记忆回顾与整理"任务。
|
||||
|
||||
任务目标:
|
||||
1) 回顾最近两天的聊天与上下文,识别用户长期偏好、习惯和关键事实的变化。
|
||||
2) 对已经失效、被否定或明显过期的信息执行遗忘。
|
||||
3) 对新增且有证据支持的信息执行写入。
|
||||
4) 严禁编造;没有证据就不要写入。
|
||||
5) 只更新最小必要字段,避免过度覆盖。
|
||||
任务目标:
|
||||
1) 回顾最近两天的聊天与上下文,识别用户长期偏好、习惯和关键事实的变化。
|
||||
2) 对已经失效、被否定或明显过期的信息执行遗忘。
|
||||
3) 对新增且有证据支持的信息执行写入。
|
||||
4) 严禁编造;没有证据就不要写入。
|
||||
5) 只更新最小必要字段,避免过度覆盖。
|
||||
|
||||
输出要求:
|
||||
- 必须使用以下固定格式输出;每一行都要有:
|
||||
【记忆回顾】<一句人性化总结,说明今天主要发生了什么>
|
||||
【新增记忆】<按“X条:要点1;要点2”描述;没有则写“0条”>
|
||||
【遗忘记忆】<按“X条:要点1;要点2”描述;没有则写“0条”>
|
||||
【未来展望】<基于本次记忆变化,给出1-2条温和、可执行的后续建议;若暂无建议则说明“可继续观察”>
|
||||
输出要求:
|
||||
- 必须使用以下固定格式输出:
|
||||
<----------【周期任务输出】---------->
|
||||
【记忆回顾】<一句人性化总结,说明今天主要发生了什么>
|
||||
【新增记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
|
||||
【遗忘记忆】<按"X条:要点1;要点2"描述;没有则写"0条">
|
||||
【未来展望】<基于本次记忆变化,给出1-2条温和、可执行的后续建议;若暂无建议则说明"可继续观察">
|
||||
|
||||
表达风格:
|
||||
- 语言自然、温和、可读,像助理在做每日回顾。
|
||||
- 结论先行,避免空话,不要输出与任务无关的闲聊内容。
|
||||
表达风格:
|
||||
- 语言自然、温和、可读,像助理在做每日回顾。
|
||||
- 结论先行,避免空话,不要输出与任务无关的闲聊内容。
|
||||
enabled_tools:
|
||||
- memory.write
|
||||
- memory.forget
|
||||
- memory.write
|
||||
- memory.forget
|
||||
context:
|
||||
source: latest_chat
|
||||
window_mode: day
|
||||
window_count: 2
|
||||
source: latest_chat
|
||||
window_mode: day
|
||||
window_count: 2
|
||||
schedule:
|
||||
type: daily
|
||||
run_at:
|
||||
hour: 8
|
||||
minute: 0
|
||||
weekdays: null
|
||||
type: daily
|
||||
run_at:
|
||||
hour: 8
|
||||
minute: 0
|
||||
weekdays: null
|
||||
|
||||
@@ -5,7 +5,7 @@ import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from core.automation.scheduler import run_automation_scheduler_scan
|
||||
from core.config.initial.init_data import initialize_data
|
||||
@@ -118,22 +118,30 @@ async def run_automation_scheduler_forever() -> None:
|
||||
batch_limit=batch_limit,
|
||||
)
|
||||
|
||||
def scan_job() -> None:
|
||||
async def scan_job() -> None:
|
||||
try:
|
||||
asyncio.run(run_automation_scheduler_scan(limit=batch_limit))
|
||||
await run_automation_scheduler_scan(limit=batch_limit)
|
||||
except Exception as exc:
|
||||
logger.exception("Automation scheduler scan failed", error=str(exc))
|
||||
|
||||
scheduler = BlockingScheduler()
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(
|
||||
scan_job,
|
||||
trigger=IntervalTrigger(seconds=interval_seconds),
|
||||
id="automation_scheduler_scan",
|
||||
name="Automation scheduler scan",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
)
|
||||
scheduler.start()
|
||||
|
||||
stop_event = asyncio.Event()
|
||||
try:
|
||||
await stop_event.wait()
|
||||
finally:
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""CLI entry point."""
|
||||
|
||||
@@ -100,6 +100,13 @@ class ConstraintItem(BaseModel):
|
||||
value: str
|
||||
required: bool = True
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def normalize_value(cls, value: object) -> object:
|
||||
if isinstance(value, bool | int | float):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
|
||||
class NormalizedTaskInput(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@@ -211,6 +211,20 @@ class UiHintListItem(UiHintBaseModel):
|
||||
status: UiHintStatus | None = Field(default=None)
|
||||
actions: list[UiHintAction] = Field(default_factory=list)
|
||||
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def normalize_status(cls, value: object) -> object:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, dict):
|
||||
status_type = value.get("type")
|
||||
if isinstance(status_type, str):
|
||||
return status_type
|
||||
status_value = value.get("status")
|
||||
if isinstance(status_value, str):
|
||||
return status_value
|
||||
return value
|
||||
|
||||
|
||||
class UiHintSection(UiHintBaseModel):
|
||||
title: str | None = Field(default=None, description="Section title.")
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db import get_db
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.friendships.repository import SQLAlchemyFriendshipRepository
|
||||
from v1.friendships.service import FriendshipService
|
||||
from v1.users.dependencies import get_current_user
|
||||
@@ -25,9 +26,11 @@ async def get_friendship_service(
|
||||
) -> FriendshipService:
|
||||
friendship_repository = SQLAlchemyFriendshipRepository(session)
|
||||
user_repository = SQLAlchemyUserRepository(session)
|
||||
auth_gateway = SupabaseAuthGateway()
|
||||
return FriendshipService(
|
||||
repository=friendship_repository,
|
||||
user_repository=user_repository,
|
||||
session=session,
|
||||
current_user=current_user,
|
||||
auth_gateway=auth_gateway,
|
||||
)
|
||||
|
||||
@@ -9,10 +9,17 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from v1.inbox_messages.realtime import (
|
||||
InboxMessageEventSnapshot,
|
||||
publish_inbox_message_created,
|
||||
publish_inbox_message_status_changed,
|
||||
snapshot_from_inbox_message,
|
||||
)
|
||||
from core.logging import get_logger
|
||||
from models.friendships import Friendship
|
||||
from models.inbox_messages import InboxMessage
|
||||
from schemas.enums import FriendshipStatus, InboxMessageStatus, InboxMessageType
|
||||
from v1.auth.gateway import SupabaseAuthGateway
|
||||
from v1.friendships.repository import FriendshipRepository
|
||||
from v1.friendships.schemas import (
|
||||
FriendRequestCreate,
|
||||
@@ -42,6 +49,7 @@ class FriendshipService(BaseService):
|
||||
_repository: FriendshipRepository
|
||||
_user_repository: UserRepository
|
||||
_session: AsyncSession
|
||||
_auth_gateway: SupabaseAuthGateway | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -49,11 +57,13 @@ class FriendshipService(BaseService):
|
||||
user_repository: UserRepository,
|
||||
session: AsyncSession,
|
||||
current_user: CurrentUser | None,
|
||||
auth_gateway: SupabaseAuthGateway | None = None,
|
||||
) -> None:
|
||||
super().__init__(current_user=current_user)
|
||||
self._repository = repository
|
||||
self._user_repository = user_repository
|
||||
self._session = session
|
||||
self._auth_gateway = auth_gateway
|
||||
|
||||
async def send_request(self, request: FriendRequestCreate) -> FriendRequestResponse:
|
||||
user_id = self.require_user_id()
|
||||
@@ -103,6 +113,8 @@ class FriendshipService(BaseService):
|
||||
friendship, inbox = await self._repository.reactivate_request(
|
||||
existing, user_id, request.content
|
||||
)
|
||||
await self._session.flush()
|
||||
inbox_event = snapshot_from_inbox_message(inbox)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
@@ -121,6 +133,7 @@ class FriendshipService(BaseService):
|
||||
"target_id": str(target_user_id),
|
||||
},
|
||||
)
|
||||
await self._publish_created_events([inbox_event])
|
||||
return await self._build_friend_request_response(
|
||||
friendship, inbox, user_id, target_user_id
|
||||
)
|
||||
@@ -129,6 +142,8 @@ class FriendshipService(BaseService):
|
||||
friendship, inbox = await self._repository.create_request(
|
||||
user_id, target_user_id, request.content
|
||||
)
|
||||
await self._session.flush()
|
||||
inbox_event = snapshot_from_inbox_message(inbox)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
@@ -144,6 +159,7 @@ class FriendshipService(BaseService):
|
||||
"friend_request_sent",
|
||||
extra={"initiator_id": str(user_id), "target_id": str(target_user_id)},
|
||||
)
|
||||
await self._publish_created_events([inbox_event])
|
||||
|
||||
return await self._build_friend_request_response(
|
||||
friendship, inbox, user_id, target_user_id
|
||||
@@ -172,11 +188,7 @@ class FriendshipService(BaseService):
|
||||
),
|
||||
)
|
||||
|
||||
recipient_id = (
|
||||
friendship.user_low_id
|
||||
if friendship.initiator_id == friendship.user_high_id
|
||||
else friendship.user_high_id
|
||||
)
|
||||
recipient_id = self._get_recipient_id(friendship)
|
||||
|
||||
if recipient_id != user_id:
|
||||
logger.warning(
|
||||
@@ -218,6 +230,7 @@ class FriendshipService(BaseService):
|
||||
|
||||
friendship.status = FriendshipStatus.ACCEPTED
|
||||
inbox.status = InboxMessageStatus.ACCEPTED
|
||||
inbox_event = snapshot_from_inbox_message(inbox)
|
||||
|
||||
try:
|
||||
await self._session.commit()
|
||||
@@ -249,6 +262,7 @@ class FriendshipService(BaseService):
|
||||
"initiator_id": str(sender_id),
|
||||
},
|
||||
)
|
||||
await self._publish_status_events([inbox_event])
|
||||
|
||||
sender = await self._user_repository.get_by_user_id(sender_id)
|
||||
recipient = await self._user_repository.get_by_user_id(user_id)
|
||||
@@ -285,11 +299,7 @@ class FriendshipService(BaseService):
|
||||
),
|
||||
)
|
||||
|
||||
recipient_id = (
|
||||
friendship.user_low_id
|
||||
if friendship.initiator_id == friendship.user_high_id
|
||||
else friendship.user_high_id
|
||||
)
|
||||
recipient_id = self._get_recipient_id(friendship)
|
||||
|
||||
if recipient_id != user_id:
|
||||
logger.warning(
|
||||
@@ -322,8 +332,10 @@ class FriendshipService(BaseService):
|
||||
)
|
||||
|
||||
friendship.status = FriendshipStatus.DECLINED
|
||||
inbox_event: InboxMessageEventSnapshot | None = None
|
||||
if inbox:
|
||||
inbox.status = InboxMessageStatus.REJECTED
|
||||
inbox_event = snapshot_from_inbox_message(inbox)
|
||||
|
||||
try:
|
||||
await self._session.commit()
|
||||
@@ -355,6 +367,8 @@ class FriendshipService(BaseService):
|
||||
"initiator_id": str(sender_id),
|
||||
},
|
||||
)
|
||||
if inbox_event:
|
||||
await self._publish_status_events([inbox_event])
|
||||
|
||||
sender = await self._user_repository.get_by_user_id(sender_id)
|
||||
recipient = await self._user_repository.get_by_user_id(user_id)
|
||||
@@ -422,8 +436,10 @@ class FriendshipService(BaseService):
|
||||
)
|
||||
|
||||
friendship.status = FriendshipStatus.CANCELED
|
||||
inbox_event: InboxMessageEventSnapshot | None = None
|
||||
if inbox:
|
||||
inbox.status = InboxMessageStatus.DISMISSED
|
||||
inbox_event = snapshot_from_inbox_message(inbox)
|
||||
|
||||
try:
|
||||
await self._session.commit()
|
||||
@@ -457,6 +473,8 @@ class FriendshipService(BaseService):
|
||||
"target_id": str(recipient_id),
|
||||
},
|
||||
)
|
||||
if inbox_event:
|
||||
await self._publish_status_events([inbox_event])
|
||||
|
||||
return FriendRequestResponse(
|
||||
id=friendship.id,
|
||||
@@ -583,11 +601,7 @@ class FriendshipService(BaseService):
|
||||
)
|
||||
|
||||
sender = await self._user_repository.get_by_user_id(initiator_id)
|
||||
recipient_id = (
|
||||
friendship.user_low_id
|
||||
if friendship.user_low_id != initiator_id
|
||||
else friendship.user_high_id
|
||||
)
|
||||
recipient_id = self._get_recipient_id(friendship)
|
||||
recipient = await self._user_repository.get_by_user_id(recipient_id)
|
||||
|
||||
# Map FriendshipStatus to response status
|
||||
@@ -676,15 +690,24 @@ class FriendshipService(BaseService):
|
||||
]
|
||||
profiles_by_id = await self._user_repository.get_by_user_ids(friend_ids)
|
||||
|
||||
auth_users_by_id: dict[str, str | None] = {}
|
||||
if self._auth_gateway is not None:
|
||||
auth_users = await self._auth_gateway.get_users_by_ids(
|
||||
[str(fid) for fid in friend_ids]
|
||||
)
|
||||
for uid, auth_user in auth_users.items():
|
||||
auth_users_by_id[uid] = auth_user.phone
|
||||
|
||||
result: list[FriendResponse] = []
|
||||
for friendship in friendships:
|
||||
friend_id = self._get_other_user_id(friendship, user_id)
|
||||
friend = profiles_by_id.get(friend_id)
|
||||
phone = auth_users_by_id.get(str(friend_id))
|
||||
|
||||
result.append(
|
||||
FriendResponse(
|
||||
id=friendship.id,
|
||||
friend=self._build_user_basic_info(friend),
|
||||
friend=self._build_user_basic_info(friend, phone),
|
||||
status="active",
|
||||
created_at=friendship.created_at,
|
||||
accepted_at=friendship.updated_at,
|
||||
@@ -760,7 +783,9 @@ class FriendshipService(BaseService):
|
||||
accepted_at=friendship.updated_at,
|
||||
)
|
||||
|
||||
def _build_user_basic_info(self, profile: Any) -> "UserContext":
|
||||
def _build_user_basic_info(
|
||||
self, profile: Any, phone: str | None = None
|
||||
) -> "UserContext":
|
||||
from schemas.shared.user import UserContext
|
||||
|
||||
if profile is None:
|
||||
@@ -770,7 +795,9 @@ class FriendshipService(BaseService):
|
||||
return UserContext(
|
||||
id=str(p.id),
|
||||
username=p.username,
|
||||
phone=phone,
|
||||
avatar_url=p.avatar_url if hasattr(p, "avatar_url") else None,
|
||||
bio=p.bio if hasattr(p, "bio") else None,
|
||||
)
|
||||
|
||||
async def _build_friend_request_response(
|
||||
@@ -800,3 +827,50 @@ class FriendshipService(BaseService):
|
||||
if friendship.user_low_id == current_user_id
|
||||
else friendship.user_low_id
|
||||
)
|
||||
|
||||
def _get_recipient_id(self, friendship: Friendship) -> UUID:
|
||||
return (
|
||||
friendship.user_low_id
|
||||
if friendship.initiator_id == friendship.user_high_id
|
||||
else friendship.user_high_id
|
||||
)
|
||||
|
||||
async def _publish_created_events(
|
||||
self, messages: list[InboxMessageEventSnapshot]
|
||||
) -> None:
|
||||
for message in messages:
|
||||
try:
|
||||
await publish_inbox_message_created(message)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Failed to publish inbox created event",
|
||||
message_id=str(message.message_id),
|
||||
recipient_id=str(message.recipient_id),
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="INBOX_EVENT_STREAM_UNAVAILABLE",
|
||||
detail="Inbox event stream unavailable",
|
||||
),
|
||||
) from exc
|
||||
|
||||
async def _publish_status_events(
|
||||
self, messages: list[InboxMessageEventSnapshot]
|
||||
) -> None:
|
||||
for message in messages:
|
||||
try:
|
||||
await publish_inbox_message_status_changed(message)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Failed to publish inbox status event",
|
||||
message_id=str(message.message_id),
|
||||
recipient_id=str(message.recipient_id),
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="INBOX_EVENT_STREAM_UNAVAILABLE",
|
||||
detail="Inbox event stream unavailable",
|
||||
),
|
||||
) from exc
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from redis.exceptions import TimeoutError as RedisTimeoutError
|
||||
|
||||
from core.config.settings import config
|
||||
from core.logging import get_logger
|
||||
from models.inbox_messages import InboxMessage
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
|
||||
logger = get_logger("v1.inbox_messages.realtime")
|
||||
|
||||
INBOX_STREAM_PREFIX = "inbox:events"
|
||||
|
||||
EVENT_MESSAGE_CREATED = "INBOX_MESSAGE_CREATED"
|
||||
EVENT_MESSAGE_READ_CHANGED = "INBOX_MESSAGE_READ_CHANGED"
|
||||
EVENT_MESSAGE_STATUS_CHANGED = "INBOX_MESSAGE_STATUS_CHANGED"
|
||||
EVENT_SNAPSHOT_REQUIRED = "INBOX_SNAPSHOT_REQUIRED"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InboxMessageEventSnapshot:
|
||||
message_id: UUID
|
||||
recipient_id: UUID
|
||||
sender_id: UUID | None
|
||||
message_type: str
|
||||
schedule_item_id: UUID | None
|
||||
friendship_id: UUID | None
|
||||
content: dict[str, Any] | None
|
||||
is_read: bool
|
||||
status: str
|
||||
created_at: datetime
|
||||
occurred_at: datetime
|
||||
|
||||
|
||||
def snapshot_from_inbox_message(message: InboxMessage) -> InboxMessageEventSnapshot:
|
||||
message_type = (
|
||||
message.message_type.value
|
||||
if hasattr(message.message_type, "value")
|
||||
else str(message.message_type)
|
||||
)
|
||||
status = (
|
||||
message.status.value
|
||||
if hasattr(message.status, "value")
|
||||
else str(message.status)
|
||||
)
|
||||
if status in {"None", ""}:
|
||||
status = "pending"
|
||||
created_at = (
|
||||
message.created_at
|
||||
if isinstance(message.created_at, datetime)
|
||||
else datetime.now(UTC)
|
||||
)
|
||||
occurred_at = (
|
||||
message.updated_at if isinstance(message.updated_at, datetime) else created_at
|
||||
)
|
||||
message_id = message.id if isinstance(message.id, UUID) else uuid4()
|
||||
return InboxMessageEventSnapshot(
|
||||
message_id=message_id,
|
||||
recipient_id=message.recipient_id,
|
||||
sender_id=message.sender_id,
|
||||
message_type=message_type,
|
||||
schedule_item_id=message.schedule_item_id,
|
||||
friendship_id=message.friendship_id,
|
||||
content=message.content,
|
||||
is_read=bool(message.is_read),
|
||||
status=status,
|
||||
created_at=created_at,
|
||||
occurred_at=occurred_at,
|
||||
)
|
||||
|
||||
|
||||
def to_inbox_sse_event(stream_id: str, event_type: str, payload: dict[str, Any]) -> str:
|
||||
safe_stream_id = str(stream_id).replace("\r", "").replace("\n", "")
|
||||
safe_event_type = str(event_type).replace("\r", "").replace("\n", "")
|
||||
data = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
|
||||
return f"id: {safe_stream_id}\nevent: {safe_event_type}\ndata: {data}\n\n"
|
||||
|
||||
|
||||
def _stream_name(recipient_id: UUID) -> str:
|
||||
return f"{INBOX_STREAM_PREFIX}:{recipient_id}"
|
||||
|
||||
|
||||
def _to_epoch_ms(value: datetime) -> int:
|
||||
normalized = value.astimezone(UTC)
|
||||
return int(normalized.timestamp() * 1000)
|
||||
|
||||
|
||||
def _resolve_occurred_at(snapshot: InboxMessageEventSnapshot) -> datetime:
|
||||
if isinstance(snapshot.occurred_at, datetime):
|
||||
return snapshot.occurred_at
|
||||
if isinstance(snapshot.created_at, datetime):
|
||||
return snapshot.created_at
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def _safe_stream_block_ms(requested_ms: int) -> int:
|
||||
try:
|
||||
socket_timeout_ms = max(int(float(config.redis.socket_timeout) * 1000), 1)
|
||||
except (TypeError, ValueError):
|
||||
socket_timeout_ms = 5000
|
||||
safe_max = max(socket_timeout_ms - 100, 1)
|
||||
return max(1, min(int(requested_ms), safe_max))
|
||||
|
||||
|
||||
def _message_to_payload(snapshot: InboxMessageEventSnapshot) -> dict[str, Any]:
|
||||
return {
|
||||
"id": str(snapshot.message_id),
|
||||
"recipient_id": str(snapshot.recipient_id),
|
||||
"sender_id": str(snapshot.sender_id) if snapshot.sender_id else None,
|
||||
"message_type": snapshot.message_type,
|
||||
"schedule_item_id": str(snapshot.schedule_item_id)
|
||||
if snapshot.schedule_item_id
|
||||
else None,
|
||||
"friendship_id": str(snapshot.friendship_id)
|
||||
if snapshot.friendship_id
|
||||
else None,
|
||||
"content": snapshot.content,
|
||||
"is_read": bool(snapshot.is_read),
|
||||
"status": snapshot.status,
|
||||
"created_at": snapshot.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
async def _publish_event(recipient_id: UUID, payload: dict[str, Any]) -> str:
|
||||
redis = await get_or_init_redis_client()
|
||||
stream_name = _stream_name(recipient_id)
|
||||
event_json = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
|
||||
result = redis.xadd(stream_name, {"event": event_json})
|
||||
if inspect.isawaitable(result):
|
||||
return str(await result)
|
||||
return str(result)
|
||||
|
||||
|
||||
async def publish_inbox_message_created(
|
||||
message: InboxMessage | InboxMessageEventSnapshot,
|
||||
) -> str:
|
||||
snapshot = (
|
||||
message
|
||||
if isinstance(message, InboxMessageEventSnapshot)
|
||||
else snapshot_from_inbox_message(message)
|
||||
)
|
||||
occurred_at = _resolve_occurred_at(snapshot)
|
||||
version = _to_epoch_ms(occurred_at)
|
||||
payload = {
|
||||
"event_id": str(uuid4()),
|
||||
"occurred_at": occurred_at.isoformat(),
|
||||
"user_id": str(snapshot.recipient_id),
|
||||
"message_id": str(snapshot.message_id),
|
||||
"event_type": EVENT_MESSAGE_CREATED,
|
||||
"op": "created",
|
||||
"version": version,
|
||||
"data": {"message": _message_to_payload(snapshot)},
|
||||
}
|
||||
return await _publish_event(snapshot.recipient_id, payload)
|
||||
|
||||
|
||||
async def publish_inbox_message_read_changed(
|
||||
message: InboxMessage | InboxMessageEventSnapshot,
|
||||
) -> str:
|
||||
snapshot = (
|
||||
message
|
||||
if isinstance(message, InboxMessageEventSnapshot)
|
||||
else snapshot_from_inbox_message(message)
|
||||
)
|
||||
occurred_at = _resolve_occurred_at(snapshot)
|
||||
payload = {
|
||||
"event_id": str(uuid4()),
|
||||
"occurred_at": occurred_at.isoformat(),
|
||||
"user_id": str(snapshot.recipient_id),
|
||||
"message_id": str(snapshot.message_id),
|
||||
"event_type": EVENT_MESSAGE_READ_CHANGED,
|
||||
"op": "read_changed",
|
||||
"version": _to_epoch_ms(occurred_at),
|
||||
"data": {"is_read": bool(snapshot.is_read)},
|
||||
}
|
||||
return await _publish_event(snapshot.recipient_id, payload)
|
||||
|
||||
|
||||
async def publish_inbox_message_status_changed(
|
||||
message: InboxMessage | InboxMessageEventSnapshot,
|
||||
) -> str:
|
||||
snapshot = (
|
||||
message
|
||||
if isinstance(message, InboxMessageEventSnapshot)
|
||||
else snapshot_from_inbox_message(message)
|
||||
)
|
||||
occurred_at = _resolve_occurred_at(snapshot)
|
||||
payload = {
|
||||
"event_id": str(uuid4()),
|
||||
"occurred_at": occurred_at.isoformat(),
|
||||
"user_id": str(snapshot.recipient_id),
|
||||
"message_id": str(snapshot.message_id),
|
||||
"event_type": EVENT_MESSAGE_STATUS_CHANGED,
|
||||
"op": "status_changed",
|
||||
"version": _to_epoch_ms(occurred_at),
|
||||
"data": {"status": snapshot.status},
|
||||
}
|
||||
return await _publish_event(snapshot.recipient_id, payload)
|
||||
|
||||
|
||||
async def publish_inbox_snapshot_required(
|
||||
*, recipient_id: UUID, message_id: UUID
|
||||
) -> str:
|
||||
now = datetime.now(UTC)
|
||||
payload = {
|
||||
"event_id": str(uuid4()),
|
||||
"occurred_at": now.isoformat(),
|
||||
"user_id": str(recipient_id),
|
||||
"message_id": str(message_id),
|
||||
"event_type": EVENT_SNAPSHOT_REQUIRED,
|
||||
"op": "snapshot_required",
|
||||
"version": _to_epoch_ms(now),
|
||||
"data": {},
|
||||
}
|
||||
return await _publish_event(recipient_id, payload)
|
||||
|
||||
|
||||
async def read_inbox_events(
|
||||
*,
|
||||
recipient_id: UUID,
|
||||
last_event_id: str | None,
|
||||
count: int = 100,
|
||||
block_ms: int = 5000,
|
||||
) -> list[dict[str, Any]]:
|
||||
redis = await get_or_init_redis_client()
|
||||
stream = _stream_name(recipient_id)
|
||||
start_id = "0-0" if not last_event_id else last_event_id
|
||||
safe_block_ms = _safe_stream_block_ms(block_ms)
|
||||
try:
|
||||
raw = redis.xread({stream: start_id}, count=count, block=safe_block_ms)
|
||||
response = await raw if inspect.isawaitable(raw) else raw
|
||||
except (TimeoutError, asyncio.TimeoutError, RedisTimeoutError):
|
||||
return []
|
||||
if not response:
|
||||
return []
|
||||
|
||||
first = response[0]
|
||||
if not isinstance(first, (list, tuple)) or len(first) != 2:
|
||||
return []
|
||||
entries_raw = first[1]
|
||||
if not isinstance(entries_raw, list):
|
||||
return []
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for entry in entries_raw:
|
||||
if not isinstance(entry, (list, tuple)) or len(entry) != 2:
|
||||
continue
|
||||
entry_id_raw, fields = entry
|
||||
if isinstance(entry_id_raw, bytes):
|
||||
stream_id = entry_id_raw.decode("utf-8", errors="replace")
|
||||
elif isinstance(entry_id_raw, str):
|
||||
stream_id = entry_id_raw
|
||||
else:
|
||||
continue
|
||||
if not isinstance(fields, dict):
|
||||
continue
|
||||
payload_raw = fields.get("event")
|
||||
if isinstance(payload_raw, bytes):
|
||||
payload_raw = payload_raw.decode("utf-8", errors="replace")
|
||||
if not isinstance(payload_raw, str):
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(payload_raw)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Discard malformed inbox stream payload", stream_id=stream_id
|
||||
)
|
||||
continue
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
rows.append({"id": stream_id, "event": payload})
|
||||
return rows
|
||||
@@ -1,15 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from services.base.redis import get_or_init_redis_client
|
||||
from v1.inbox_messages.realtime import to_inbox_sse_event
|
||||
|
||||
from v1.inbox_messages.dependencies import get_inbox_message_service
|
||||
from v1.inbox_messages.schemas import InboxMessageResponse
|
||||
from v1.inbox_messages.service import InboxMessageService
|
||||
|
||||
router = APIRouter(prefix="/inbox/messages", tags=["inbox-messages"])
|
||||
_LAST_EVENT_ID_RE = re.compile(r"^\d+-\d+$")
|
||||
_MAX_SSE_CONNECTIONS_PER_USER = 3
|
||||
_SSE_SLOT_TTL_SECONDS = 15 * 60
|
||||
|
||||
|
||||
async def _acquire_sse_slot(*, user_id: str) -> bool:
|
||||
redis = await get_or_init_redis_client()
|
||||
key = f"inbox:sse-active:{user_id}"
|
||||
count = await redis.incr(key)
|
||||
if count == 1:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
elif count > _MAX_SSE_CONNECTIONS_PER_USER:
|
||||
await redis.decr(key)
|
||||
return False
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
return True
|
||||
|
||||
|
||||
async def _release_sse_slot(*, user_id: str) -> None:
|
||||
redis = await get_or_init_redis_client()
|
||||
key = f"inbox:sse-active:{user_id}"
|
||||
count = await redis.decr(key)
|
||||
if count <= 0:
|
||||
await redis.delete(key)
|
||||
else:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl < 0:
|
||||
await redis.expire(key, _SSE_SLOT_TTL_SECONDS)
|
||||
|
||||
|
||||
@router.get("", response_model=list[InboxMessageResponse])
|
||||
@@ -26,3 +65,64 @@ async def mark_as_read(
|
||||
service: Annotated[InboxMessageService, Depends(get_inbox_message_service)],
|
||||
) -> InboxMessageResponse:
|
||||
return await service.mark_as_read(message_id)
|
||||
|
||||
|
||||
@router.get("/stream")
|
||||
async def stream_inbox_events(
|
||||
request: Request,
|
||||
service: Annotated[InboxMessageService, Depends(get_inbox_message_service)],
|
||||
last_event_id: str | None = Header(default=None, alias="Last-Event-ID"),
|
||||
idle_limit: int = Query(default=300, ge=1, le=3600),
|
||||
) -> StreamingResponse:
|
||||
if last_event_id is not None and (
|
||||
len(last_event_id) > 32 or _LAST_EVENT_ID_RE.fullmatch(last_event_id) is None
|
||||
):
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
detail=problem_payload(
|
||||
code="INBOX_INVALID_LAST_EVENT_ID",
|
||||
detail="Invalid Last-Event-ID",
|
||||
),
|
||||
)
|
||||
|
||||
user_id = str(service.require_user_id())
|
||||
slot_acquired = await _acquire_sse_slot(user_id=user_id)
|
||||
if not slot_acquired:
|
||||
raise ApiProblemError(
|
||||
status_code=429,
|
||||
detail=problem_payload(
|
||||
code="INBOX_SSE_CONNECTION_LIMIT",
|
||||
detail="Too many SSE connections",
|
||||
),
|
||||
)
|
||||
|
||||
async def _event_iter() -> AsyncIterator[str]:
|
||||
cursor = last_event_id
|
||||
idle_polls = 0
|
||||
try:
|
||||
while not await request.is_disconnected() and idle_polls < idle_limit:
|
||||
rows = await service.stream_events(last_event_id=cursor)
|
||||
if not rows:
|
||||
idle_polls += 1
|
||||
yield ": keep-alive\n\n"
|
||||
await asyncio.sleep(0.2)
|
||||
continue
|
||||
|
||||
idle_polls = 0
|
||||
for row in rows:
|
||||
stream_id = row.get("id")
|
||||
event = row.get("event")
|
||||
if not isinstance(stream_id, str) or not isinstance(event, dict):
|
||||
continue
|
||||
cursor = stream_id
|
||||
event_type = event.get("event_type")
|
||||
if not isinstance(event_type, str) or not event_type:
|
||||
event_type = "INBOX_MESSAGE"
|
||||
yield to_inbox_sse_event(stream_id, event_type, event)
|
||||
finally:
|
||||
await _release_sse_slot(user_id=user_id)
|
||||
|
||||
response = StreamingResponse(_event_iter(), media_type="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-cache"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
return response
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import json
|
||||
@@ -10,6 +11,11 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.http.errors import ApiProblemError
|
||||
from v1.inbox_messages.realtime import (
|
||||
publish_inbox_message_read_changed,
|
||||
read_inbox_events,
|
||||
snapshot_from_inbox_message,
|
||||
)
|
||||
from core.logging import get_logger
|
||||
from models.inbox_messages import InboxMessage
|
||||
from v1.inbox_messages.repository import InboxMessageRepository
|
||||
@@ -71,6 +77,8 @@ class InboxMessageService(BaseService):
|
||||
code="INBOX_MESSAGE_NOT_FOUND",
|
||||
detail="Inbox message not found",
|
||||
)
|
||||
event_snapshot = snapshot_from_inbox_message(updated)
|
||||
response = self._to_response(updated)
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
@@ -85,7 +93,44 @@ class InboxMessageService(BaseService):
|
||||
detail="Inbox message store unavailable",
|
||||
)
|
||||
|
||||
return self._to_response(updated)
|
||||
try:
|
||||
await publish_inbox_message_read_changed(event_snapshot)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Failed to publish inbox read-changed event",
|
||||
message_id=str(event_snapshot.message_id),
|
||||
user_id=str(event_snapshot.recipient_id),
|
||||
)
|
||||
raise _inbox_error(
|
||||
status_code=503,
|
||||
code="INBOX_EVENT_STREAM_UNAVAILABLE",
|
||||
detail="Inbox event stream unavailable",
|
||||
) from exc
|
||||
|
||||
return response
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
user_id = self.require_user_id()
|
||||
try:
|
||||
return await read_inbox_events(
|
||||
recipient_id=user_id,
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Failed to read inbox event stream",
|
||||
user_id=str(user_id),
|
||||
reason=str(exc),
|
||||
)
|
||||
raise _inbox_error(
|
||||
status_code=503,
|
||||
code="INBOX_EVENT_STREAM_UNAVAILABLE",
|
||||
detail="Inbox event stream unavailable",
|
||||
) from exc
|
||||
|
||||
def _to_response(self, message: InboxMessage) -> InboxMessageResponse:
|
||||
status_value = (
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Protocol, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select, update, delete
|
||||
from sqlalchemy import or_, select, update, delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
@@ -28,14 +28,6 @@ class ScheduleItemRepository(Protocol):
|
||||
async def list_by_date_range(
|
||||
self, owner_id: UUID, start_at: datetime, end_at: datetime
|
||||
) -> list[ScheduleItem]: ...
|
||||
async def list_paginated(
|
||||
self,
|
||||
owner_id: UUID,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
query: str | None = None,
|
||||
) -> tuple[list[ScheduleItem], int]: ...
|
||||
async def create_subscription(self, data: dict) -> ScheduleSubscription: ...
|
||||
async def get_subscriptions_by_item_id(
|
||||
self, item_id: UUID
|
||||
@@ -154,62 +146,6 @@ class SQLAlchemyScheduleItemRepository(BaseRepository[ScheduleItem]):
|
||||
logger.exception("Schedule item list failed", owner_id=str(owner_id))
|
||||
raise
|
||||
|
||||
async def list_paginated(
|
||||
self,
|
||||
owner_id: UUID,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
query: str | None = None,
|
||||
) -> tuple[list[ScheduleItem], int]:
|
||||
offset = (page - 1) * page_size
|
||||
normalized_query = (query or "").strip()
|
||||
has_query = bool(normalized_query)
|
||||
query_like = f"%{normalized_query}%"
|
||||
try:
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(ScheduleItem)
|
||||
.where(ScheduleItem.owner_id == owner_id)
|
||||
.where(ScheduleItem.deleted_at.is_(None))
|
||||
)
|
||||
if has_query:
|
||||
count_stmt = count_stmt.where(
|
||||
or_(
|
||||
ScheduleItem.title.ilike(query_like),
|
||||
ScheduleItem.description.ilike(query_like),
|
||||
)
|
||||
)
|
||||
count_result = await self._session.execute(count_stmt)
|
||||
total = int(count_result.scalar_one() or 0)
|
||||
|
||||
items_stmt = (
|
||||
select(ScheduleItem)
|
||||
.where(ScheduleItem.owner_id == owner_id)
|
||||
.where(ScheduleItem.deleted_at.is_(None))
|
||||
.order_by(ScheduleItem.start_at.asc(), ScheduleItem.id.asc())
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
if has_query:
|
||||
items_stmt = items_stmt.where(
|
||||
or_(
|
||||
ScheduleItem.title.ilike(query_like),
|
||||
ScheduleItem.description.ilike(query_like),
|
||||
)
|
||||
)
|
||||
items_result = await self._session.execute(items_stmt)
|
||||
items = list(items_result.scalars().all())
|
||||
return items, total
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Schedule item paginated list failed",
|
||||
owner_id=str(owner_id),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_subscription(self, data: dict) -> ScheduleSubscription:
|
||||
sub = ScheduleSubscription(**data)
|
||||
self._session.add(sub)
|
||||
|
||||
@@ -21,6 +21,7 @@ from schemas.domain.schedule import (
|
||||
ScheduleItemSourceType,
|
||||
ScheduleItemStatus,
|
||||
)
|
||||
from schemas.enums import SubscriptionPermission
|
||||
|
||||
__all__ = [
|
||||
"AttachmentType",
|
||||
@@ -41,9 +42,20 @@ __all__ = [
|
||||
"ScheduleItemShareRequest",
|
||||
"ScheduleItemShareResponse",
|
||||
"SubscriberInfo",
|
||||
"parse_permission",
|
||||
]
|
||||
|
||||
|
||||
def parse_permission(permission_int: int) -> dict[str, bool]:
|
||||
return {
|
||||
"can_view": bool(permission_int & SubscriptionPermission.VIEW.value),
|
||||
"can_invite": bool(permission_int & SubscriptionPermission.INVITE.value),
|
||||
"can_edit": bool(permission_int & SubscriptionPermission.EDIT.value),
|
||||
"can_delete": bool(permission_int & SubscriptionPermission.DELETE.value),
|
||||
"is_owner": permission_int == SubscriptionPermission.OWNER.value,
|
||||
}
|
||||
|
||||
|
||||
class ScheduleItemCreateRequest(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@@ -160,11 +172,6 @@ class ScheduleItemListRequest(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
_PERMISSION_VIEW = 1
|
||||
_PERMISSION_INVITE = 2
|
||||
_PERMISSION_EDIT = 4
|
||||
|
||||
|
||||
class ScheduleItemShareRequest(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@@ -180,11 +187,11 @@ class ScheduleItemShareRequest(BaseModel):
|
||||
def _permission_value(self) -> int:
|
||||
value = 0
|
||||
if self.permission_view:
|
||||
value |= _PERMISSION_VIEW
|
||||
value |= SubscriptionPermission.VIEW.value
|
||||
if self.permission_edit:
|
||||
value |= _PERMISSION_EDIT
|
||||
value |= SubscriptionPermission.EDIT.value
|
||||
if self.permission_invite:
|
||||
value |= _PERMISSION_INVITE
|
||||
value |= SubscriptionPermission.INVITE.value
|
||||
return value
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Protocol, Literal
|
||||
from typing import TYPE_CHECKING, Any, Protocol, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -9,6 +9,12 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from core.auth.models import CurrentUser
|
||||
from core.db.base_service import BaseService
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from v1.inbox_messages.realtime import (
|
||||
InboxMessageEventSnapshot,
|
||||
publish_inbox_message_created,
|
||||
publish_inbox_message_status_changed,
|
||||
snapshot_from_inbox_message,
|
||||
)
|
||||
from core.logging import get_logger
|
||||
from models.inbox_messages import InboxMessage
|
||||
from models.profile import Profile
|
||||
@@ -196,7 +202,14 @@ class ScheduleItemService(BaseService):
|
||||
subscriber_ids
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
logger.exception("Failed to get subscriber profiles")
|
||||
logger.exception("Failed to fetch subscriber profiles")
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_STORE_UNAVAILABLE",
|
||||
detail="Schedule item store unavailable",
|
||||
),
|
||||
)
|
||||
resolved_contacts = await resolve_contacts_by_user_ids(
|
||||
user_ids=subscriber_ids,
|
||||
profiles_by_id=profiles,
|
||||
@@ -302,10 +315,30 @@ class ScheduleItemService(BaseService):
|
||||
if not update_data:
|
||||
return self._to_response(existing, is_owner=is_owner)
|
||||
|
||||
before_state = self._capture_calendar_state(existing)
|
||||
item = await self._repository.update_item(item_id, update_data)
|
||||
|
||||
await self._notify_subscribers(item_id, existing.title, "updated")
|
||||
if item is None:
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_NOT_FOUND",
|
||||
detail="Schedule item not found",
|
||||
),
|
||||
)
|
||||
changes = self._build_calendar_changes(before_state, item, update_data)
|
||||
created_messages = await self._notify_subscribers(
|
||||
item_id,
|
||||
item.title,
|
||||
"updated",
|
||||
changes=changes,
|
||||
)
|
||||
if created_messages:
|
||||
await self._session.flush()
|
||||
created_snapshots = [
|
||||
snapshot_from_inbox_message(message) for message in created_messages
|
||||
]
|
||||
await self._session.commit()
|
||||
await self._publish_created_events(created_snapshots)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to update schedule item", item_id=str(item_id))
|
||||
@@ -317,15 +350,6 @@ class ScheduleItemService(BaseService):
|
||||
),
|
||||
)
|
||||
|
||||
if item is None:
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_NOT_FOUND",
|
||||
detail="Schedule item not found",
|
||||
),
|
||||
)
|
||||
|
||||
return self._to_response(item, is_owner=is_owner)
|
||||
|
||||
async def delete(self, item_id: UUID) -> None:
|
||||
@@ -359,9 +383,20 @@ class ScheduleItemService(BaseService):
|
||||
|
||||
title = existing.title
|
||||
await self._repository.delete_subscriptions_by_item_id(item_id)
|
||||
await self._notify_subscribers(item_id, title, "deleted")
|
||||
created_messages = await self._notify_subscribers(
|
||||
item_id,
|
||||
title,
|
||||
"deleted",
|
||||
changes=[],
|
||||
)
|
||||
await self._repository.delete_item(item_id)
|
||||
if created_messages:
|
||||
await self._session.flush()
|
||||
created_snapshots = [
|
||||
snapshot_from_inbox_message(message) for message in created_messages
|
||||
]
|
||||
await self._session.commit()
|
||||
await self._publish_created_events(created_snapshots)
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
logger.exception("Failed to delete schedule item", item_id=str(item_id))
|
||||
@@ -427,54 +462,6 @@ class ScheduleItemService(BaseService):
|
||||
),
|
||||
)
|
||||
|
||||
async def list_paginated(
|
||||
self,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
query: str | None = None,
|
||||
) -> tuple[list[ScheduleItemResponse], int]:
|
||||
user_id = self.require_user_id()
|
||||
if page < 1:
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_PAGE_INVALID",
|
||||
detail="page must be >= 1",
|
||||
params={"min": 1, "page": page},
|
||||
),
|
||||
)
|
||||
if page_size < 1 or page_size > 100:
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_PAGE_SIZE_INVALID",
|
||||
detail="page_size must be 1..100",
|
||||
params={"min": 1, "max": 100, "page_size": page_size},
|
||||
),
|
||||
)
|
||||
try:
|
||||
items, total = await self._repository.list_paginated(
|
||||
user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
query=query,
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
logger.exception(
|
||||
"Failed to list schedule items with pagination",
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_STORE_UNAVAILABLE",
|
||||
detail="Schedule item store unavailable",
|
||||
),
|
||||
)
|
||||
return [self._to_response(item) for item in items], total
|
||||
|
||||
async def share(
|
||||
self, item_id: UUID, request: ScheduleItemShareRequest
|
||||
) -> ScheduleItemShareResponse:
|
||||
@@ -518,6 +505,10 @@ class ScheduleItemService(BaseService):
|
||||
),
|
||||
)
|
||||
|
||||
actor_username = await self._resolve_actor_username(user_id)
|
||||
actor = await self._auth_gateway.get_user_by_id(str(user_id))
|
||||
actor_phone = actor.phone
|
||||
|
||||
target_user = await self._auth_gateway.get_user_by_phone(request.phone)
|
||||
recipient_id = UUID(target_user.id)
|
||||
|
||||
@@ -563,6 +554,8 @@ class ScheduleItemService(BaseService):
|
||||
existing_msg = await self._inbox_repository.get_calendar_invite(
|
||||
item.id, recipient_id
|
||||
)
|
||||
event_target_message: InboxMessage | None = None
|
||||
event_is_created = False
|
||||
if existing_msg:
|
||||
if existing_msg.status == InboxMessageStatus.ACCEPTED:
|
||||
raise ApiProblemError(
|
||||
@@ -584,9 +577,25 @@ class ScheduleItemService(BaseService):
|
||||
existing_msg.status = InboxMessageStatus.PENDING
|
||||
existing_msg.content = {
|
||||
"type": "invite",
|
||||
"schema_version": 2,
|
||||
"item": {
|
||||
"id": str(item.id),
|
||||
"title": item.title,
|
||||
"description": item.description,
|
||||
"start_at": item.start_at.isoformat(),
|
||||
"end_at": item.end_at.isoformat() if item.end_at else None,
|
||||
"timezone": item.timezone,
|
||||
},
|
||||
"actor": {
|
||||
"user_id": str(user_id),
|
||||
"username": actor_username,
|
||||
"phone": actor_phone,
|
||||
},
|
||||
"summary": f"{item.title} 邀请您加入日历",
|
||||
"permission": request_permission,
|
||||
"action": "pending",
|
||||
}
|
||||
event_target_message = existing_msg
|
||||
else:
|
||||
message = InboxMessage(
|
||||
recipient_id=recipient_id,
|
||||
@@ -595,14 +604,44 @@ class ScheduleItemService(BaseService):
|
||||
schedule_item_id=item.id,
|
||||
content={
|
||||
"type": "invite",
|
||||
"schema_version": 2,
|
||||
"item": {
|
||||
"id": str(item.id),
|
||||
"title": item.title,
|
||||
"description": item.description,
|
||||
"start_at": item.start_at.isoformat(),
|
||||
"end_at": item.end_at.isoformat() if item.end_at else None,
|
||||
"timezone": item.timezone,
|
||||
},
|
||||
"actor": {
|
||||
"user_id": str(user_id),
|
||||
"username": actor_username,
|
||||
"phone": actor_phone,
|
||||
},
|
||||
"summary": f"{item.title} 邀请您加入日历",
|
||||
"permission": request_permission,
|
||||
"action": "pending",
|
||||
},
|
||||
created_by=user_id,
|
||||
)
|
||||
self._session.add(message)
|
||||
event_target_message = message
|
||||
event_is_created = True
|
||||
|
||||
if event_target_message is not None and event_is_created:
|
||||
await self._session.flush()
|
||||
event_snapshot = (
|
||||
snapshot_from_inbox_message(event_target_message)
|
||||
if event_target_message is not None
|
||||
else None
|
||||
)
|
||||
await self._session.commit()
|
||||
if event_target_message is not None:
|
||||
assert event_snapshot is not None
|
||||
if event_is_created:
|
||||
await self._publish_created_events([event_snapshot])
|
||||
else:
|
||||
await self._publish_status_events([event_snapshot])
|
||||
except ApiProblemError:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
@@ -703,7 +742,9 @@ class ScheduleItemService(BaseService):
|
||||
)
|
||||
|
||||
inbox.status = InboxMessageStatus.ACCEPTED
|
||||
event_snapshot = snapshot_from_inbox_message(inbox)
|
||||
await self._session.commit()
|
||||
await self._publish_status_events([event_snapshot])
|
||||
|
||||
return {"message": "Subscription accepted"}
|
||||
except ApiProblemError:
|
||||
@@ -742,7 +783,9 @@ class ScheduleItemService(BaseService):
|
||||
)
|
||||
|
||||
inbox.status = InboxMessageStatus.REJECTED
|
||||
event_snapshot = snapshot_from_inbox_message(inbox)
|
||||
await self._session.commit()
|
||||
await self._publish_status_events([event_snapshot])
|
||||
|
||||
return {"message": "Subscription rejected"}
|
||||
except ApiProblemError:
|
||||
@@ -763,18 +806,36 @@ class ScheduleItemService(BaseService):
|
||||
item_id: UUID,
|
||||
title: str,
|
||||
action_type: Literal["updated", "deleted"],
|
||||
):
|
||||
*,
|
||||
changes: list[dict[str, Any]],
|
||||
) -> list[InboxMessage]:
|
||||
user_id = self.require_user_id()
|
||||
|
||||
subscriptions = await self._repository.get_subscriptions_by_item_id(item_id)
|
||||
if not subscriptions:
|
||||
return []
|
||||
actor_username = await self._resolve_actor_username(user_id)
|
||||
|
||||
created_messages: list[InboxMessage] = []
|
||||
for sub in subscriptions:
|
||||
if sub.subscriber_id == user_id:
|
||||
continue
|
||||
|
||||
content = {
|
||||
"type": action_type,
|
||||
"title": title,
|
||||
"schema_version": 2,
|
||||
"item": {
|
||||
"id": str(item_id),
|
||||
"title": title,
|
||||
},
|
||||
"actor": {
|
||||
"user_id": str(user_id),
|
||||
"username": actor_username,
|
||||
},
|
||||
"summary": f"{actor_username} 更新了日历 {title}"
|
||||
if action_type == "updated"
|
||||
else f"{actor_username} 删除了日历 {title}",
|
||||
"changes": changes,
|
||||
"action": action_type,
|
||||
}
|
||||
|
||||
@@ -787,9 +848,126 @@ class ScheduleItemService(BaseService):
|
||||
created_by=user_id,
|
||||
)
|
||||
self._session.add(message)
|
||||
created_messages.append(message)
|
||||
return created_messages
|
||||
|
||||
if subscriptions:
|
||||
await self._session.commit()
|
||||
async def _resolve_actor_username(self, user_id: UUID) -> str:
|
||||
if self._user_repository is None:
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_ACTOR_LOOKUP_UNAVAILABLE",
|
||||
detail="Actor lookup unavailable",
|
||||
),
|
||||
)
|
||||
profile = await self._user_repository.get_by_user_id(user_id)
|
||||
if profile is not None and profile.username:
|
||||
return profile.username
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="SCHEDULE_ITEM_ACTOR_LOOKUP_UNAVAILABLE",
|
||||
detail="Actor lookup unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
def _capture_calendar_state(self, item: ScheduleItem) -> dict[str, Any]:
|
||||
return {
|
||||
"title": item.title,
|
||||
"description": item.description,
|
||||
"start_at": item.start_at,
|
||||
"end_at": item.end_at,
|
||||
"timezone": item.timezone,
|
||||
"status": item.status,
|
||||
}
|
||||
|
||||
def _build_calendar_changes(
|
||||
self,
|
||||
before: dict[str, Any],
|
||||
after: ScheduleItem,
|
||||
update_data: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
label_map = {
|
||||
"title": "标题",
|
||||
"description": "描述",
|
||||
"start_at": "开始时间",
|
||||
"end_at": "结束时间",
|
||||
"timezone": "时区",
|
||||
"status": "状态",
|
||||
}
|
||||
changes: list[dict[str, Any]] = []
|
||||
for field in label_map:
|
||||
if field not in update_data:
|
||||
continue
|
||||
before_value = before.get(field)
|
||||
after_value = getattr(after, field)
|
||||
if before_value == after_value:
|
||||
continue
|
||||
change_type = "modified"
|
||||
if before_value is None and after_value is not None:
|
||||
change_type = "added"
|
||||
elif before_value is not None and after_value is None:
|
||||
change_type = "removed"
|
||||
changes.append(
|
||||
{
|
||||
"field": field,
|
||||
"label": label_map[field],
|
||||
"before": before_value.isoformat()
|
||||
if isinstance(before_value, datetime)
|
||||
else before_value,
|
||||
"after": after_value.isoformat()
|
||||
if isinstance(after_value, datetime)
|
||||
else after_value,
|
||||
"display_before": str(before_value)
|
||||
if before_value is not None
|
||||
else None,
|
||||
"display_after": str(after_value)
|
||||
if after_value is not None
|
||||
else None,
|
||||
"change_type": change_type,
|
||||
}
|
||||
)
|
||||
return changes
|
||||
|
||||
async def _publish_created_events(
|
||||
self, messages: list[InboxMessageEventSnapshot]
|
||||
) -> None:
|
||||
for message in messages:
|
||||
try:
|
||||
await publish_inbox_message_created(message)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Failed to publish inbox created event",
|
||||
message_id=str(message.message_id),
|
||||
recipient_id=str(message.recipient_id),
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="INBOX_EVENT_STREAM_UNAVAILABLE",
|
||||
detail="Inbox event stream unavailable",
|
||||
),
|
||||
) from exc
|
||||
|
||||
async def _publish_status_events(
|
||||
self, messages: list[InboxMessageEventSnapshot]
|
||||
) -> None:
|
||||
for message in messages:
|
||||
try:
|
||||
await publish_inbox_message_status_changed(message)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception(
|
||||
"Failed to publish inbox status event",
|
||||
message_id=str(message.message_id),
|
||||
recipient_id=str(message.recipient_id),
|
||||
)
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
detail=problem_payload(
|
||||
code="INBOX_EVENT_STREAM_UNAVAILABLE",
|
||||
detail="Inbox event stream unavailable",
|
||||
),
|
||||
) from exc
|
||||
|
||||
def _to_utc(self, dt: datetime | None) -> datetime | None:
|
||||
if dt is None:
|
||||
|
||||
@@ -19,7 +19,7 @@ class UserSearchRequest(BaseModel):
|
||||
class UserUpdateRequest(BaseModel):
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
username: str | None = Field(default=None, min_length=3, max_length=30)
|
||||
username: str | None = Field(default=None, max_length=30)
|
||||
avatar_url: str | None = Field(default=None)
|
||||
bio: str | None = Field(default=None, max_length=200)
|
||||
|
||||
|
||||
+118
-81
@@ -12,7 +12,7 @@ from core.agentscope.caches.user_context_cache import (
|
||||
from core.auth.models import CurrentUser
|
||||
from core.config.settings import config
|
||||
from core.db.base_service import BaseService
|
||||
from core.http.errors import ApiProblemError
|
||||
from core.http.errors import ApiProblemError, problem_payload
|
||||
from core.logging import get_logger
|
||||
from schemas.shared.user import UserContext, parse_profile_settings
|
||||
from services.base.supabase import supabase_service
|
||||
@@ -23,27 +23,13 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from schemas.shared.user import UserContext
|
||||
from v1.auth.schemas import UserByIdResponse
|
||||
|
||||
logger = get_logger("v1.users.service")
|
||||
|
||||
_PHONE_QUERY_PATTERN = re.compile(r"^[+()\-\s\d]{4,32}$")
|
||||
|
||||
|
||||
def _user_error(
|
||||
*,
|
||||
status_code: int,
|
||||
code: str,
|
||||
detail: str,
|
||||
params: dict[str, object] | None = None,
|
||||
) -> ApiProblemError:
|
||||
return ApiProblemError(
|
||||
status_code=status_code,
|
||||
code=code,
|
||||
detail=detail,
|
||||
params=params,
|
||||
)
|
||||
|
||||
|
||||
def _mime_to_suffix(mime_type: str) -> str:
|
||||
"""Convert MIME type to file suffix."""
|
||||
mapping = {
|
||||
@@ -59,11 +45,7 @@ class AuthLookupGateway(Protocol):
|
||||
self, query: str, limit: int = 20
|
||||
) -> list[str]: ...
|
||||
|
||||
|
||||
class AuthByPhoneGateway(Protocol):
|
||||
async def search_user_ids_by_phone(
|
||||
self, query: str, limit: int = 20
|
||||
) -> list[str]: ...
|
||||
async def get_user_by_id(self, user_id: str) -> "UserByIdResponse": ...
|
||||
|
||||
|
||||
class UserContextInvalidator(Protocol):
|
||||
@@ -71,7 +53,7 @@ class UserContextInvalidator(Protocol):
|
||||
|
||||
|
||||
class AuthLookupAdapter:
|
||||
def __init__(self, gateway: AuthByPhoneGateway) -> None:
|
||||
def __init__(self, gateway: AuthLookupGateway) -> None:
|
||||
self._gateway = gateway
|
||||
|
||||
async def search_user_ids_by_phone(self, query: str, limit: int = 20) -> list[str]:
|
||||
@@ -80,6 +62,12 @@ class AuthLookupAdapter:
|
||||
except ApiProblemError:
|
||||
return []
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> "UserByIdResponse | None":
|
||||
try:
|
||||
return await self._gateway.get_user_by_id(user_id)
|
||||
except ApiProblemError:
|
||||
return None
|
||||
|
||||
|
||||
class UserService(BaseService):
|
||||
"""User service handling business logic and transactions.
|
||||
@@ -117,17 +105,21 @@ class UserService(BaseService):
|
||||
try:
|
||||
user = await self._repository.get_by_user_id(user_id)
|
||||
except SQLAlchemyError:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
detail=problem_payload(
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
detail=problem_payload(
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
),
|
||||
)
|
||||
phone = self._current_user.phone if self._current_user else None
|
||||
return UserContext(
|
||||
@@ -145,22 +137,38 @@ class UserService(BaseService):
|
||||
try:
|
||||
profile = await self._repository.get_by_user_id(user_id)
|
||||
except SQLAlchemyError:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
detail=problem_payload(
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
if profile is None:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
detail=problem_payload(
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
),
|
||||
)
|
||||
phone: str | None = None
|
||||
if self._auth_gateway is not None:
|
||||
try:
|
||||
auth_user = await self._auth_gateway.get_user_by_id(str(user_id))
|
||||
phone = auth_user.phone
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to resolve auth phone",
|
||||
user_id=str(user_id),
|
||||
)
|
||||
return UserContext(
|
||||
id=str(profile.id),
|
||||
username=profile.username,
|
||||
avatar_url=profile.avatar_url,
|
||||
phone=phone,
|
||||
bio=profile.bio,
|
||||
)
|
||||
|
||||
async def update_me(self, update: UserUpdateRequest) -> UserContext:
|
||||
@@ -176,10 +184,12 @@ class UserService(BaseService):
|
||||
}
|
||||
|
||||
if not update_data:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=400,
|
||||
code="USER_UPDATE_FIELDS_EMPTY",
|
||||
detail="No fields to update",
|
||||
detail=problem_payload(
|
||||
code="USER_UPDATE_FIELDS_EMPTY",
|
||||
detail="No fields to update",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -187,17 +197,21 @@ class UserService(BaseService):
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
detail=problem_payload(
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
detail=problem_payload(
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -229,38 +243,46 @@ class UserService(BaseService):
|
||||
user_id = self.require_user_id()
|
||||
|
||||
if not isinstance(content_type, str):
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="USER_AVATAR_UNSUPPORTED_TYPE",
|
||||
detail="Unsupported image type",
|
||||
detail=problem_payload(
|
||||
code="USER_AVATAR_UNSUPPORTED_TYPE",
|
||||
detail="Unsupported image type",
|
||||
),
|
||||
)
|
||||
|
||||
mime_type = content_type.lower()
|
||||
allowed_types = {"image/jpeg", "image/png", "image/webp"}
|
||||
if mime_type not in allowed_types:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="USER_AVATAR_UNSUPPORTED_TYPE",
|
||||
detail="Unsupported image type. Allowed: JPEG, PNG, WebP",
|
||||
params={"allowed": ["image/jpeg", "image/png", "image/webp"]},
|
||||
detail=problem_payload(
|
||||
code="USER_AVATAR_UNSUPPORTED_TYPE",
|
||||
detail="Unsupported image type. Allowed: JPEG, PNG, WebP",
|
||||
params={"allowed": ["image/jpeg", "image/png", "image/webp"]},
|
||||
),
|
||||
)
|
||||
|
||||
max_size_bytes = config.storage.avatar.max_size_mb * 1024 * 1024
|
||||
if len(payload) > max_size_bytes:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=413,
|
||||
code="USER_AVATAR_TOO_LARGE",
|
||||
detail=(
|
||||
f"Image too large. Maximum size: {config.storage.avatar.max_size_mb}MB"
|
||||
detail=problem_payload(
|
||||
code="USER_AVATAR_TOO_LARGE",
|
||||
detail=(
|
||||
f"Image too large. Maximum size: {config.storage.avatar.max_size_mb}MB"
|
||||
),
|
||||
params={"max_size_mb": config.storage.avatar.max_size_mb},
|
||||
),
|
||||
params={"max_size_mb": config.storage.avatar.max_size_mb},
|
||||
)
|
||||
|
||||
if not payload:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=422,
|
||||
code="USER_AVATAR_EMPTY",
|
||||
detail="Empty image",
|
||||
detail=problem_payload(
|
||||
code="USER_AVATAR_EMPTY",
|
||||
detail="Empty image",
|
||||
),
|
||||
)
|
||||
|
||||
suffix = _mime_to_suffix(mime_type)
|
||||
@@ -284,13 +306,16 @@ class UserService(BaseService):
|
||||
"user_id": str(user_id),
|
||||
},
|
||||
)
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=502,
|
||||
code="USER_AVATAR_UPLOAD_FAILED",
|
||||
detail="Failed to upload avatar",
|
||||
detail=problem_payload(
|
||||
code="USER_AVATAR_UPLOAD_FAILED",
|
||||
detail="Failed to upload avatar",
|
||||
),
|
||||
)
|
||||
|
||||
public_url = f"{config.supabase.public_url}/storage/v1/object/public/{bucket_name}/{stored_path}"
|
||||
base_url = str(config.supabase.public_url).rstrip("/")
|
||||
public_url = f"{base_url}/storage/v1/object/public/{bucket_name}/{stored_path}"
|
||||
|
||||
update_data: dict[str, str | None] = {"avatar_url": public_url}
|
||||
try:
|
||||
@@ -298,17 +323,21 @@ class UserService(BaseService):
|
||||
await self._session.commit()
|
||||
except SQLAlchemyError:
|
||||
await self._session.rollback()
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
detail=problem_payload(
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
detail=problem_payload(
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -326,17 +355,19 @@ class UserService(BaseService):
|
||||
try:
|
||||
user = await self._repository.get_by_username(username)
|
||||
except SQLAlchemyError:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=404,
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
detail=problem_payload(
|
||||
code="USER_NOT_FOUND",
|
||||
detail="User not found",
|
||||
),
|
||||
)
|
||||
return UserContext(
|
||||
id=str(user.id),
|
||||
@@ -365,10 +396,12 @@ class UserService(BaseService):
|
||||
|
||||
async def _search_by_phone(self, phone: str) -> list[UserContext]:
|
||||
if self._auth_gateway is None:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_AUTH_LOOKUP_UNAVAILABLE",
|
||||
detail="Auth lookup unavailable",
|
||||
detail=problem_payload(
|
||||
code="USER_AUTH_LOOKUP_UNAVAILABLE",
|
||||
detail="Auth lookup unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
user_id_values = await self._auth_gateway.search_user_ids_by_phone(
|
||||
@@ -389,10 +422,12 @@ class UserService(BaseService):
|
||||
try:
|
||||
users_by_id = await self._repository.get_by_user_ids(user_ids)
|
||||
except SQLAlchemyError:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
detail=problem_payload(
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
results: list[UserContext] = []
|
||||
@@ -415,10 +450,12 @@ class UserService(BaseService):
|
||||
try:
|
||||
users = await self._repository.search_users(query, limit=20)
|
||||
except SQLAlchemyError:
|
||||
raise _user_error(
|
||||
raise ApiProblemError(
|
||||
status_code=503,
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
detail=problem_payload(
|
||||
code="USER_STORE_UNAVAILABLE",
|
||||
detail="User store unavailable",
|
||||
),
|
||||
)
|
||||
|
||||
return [
|
||||
|
||||
@@ -25,6 +25,13 @@ class FakeInboxMessageService:
|
||||
) -> None:
|
||||
self._messages = messages
|
||||
self._read_message = read_message
|
||||
self._stream_rows: list[dict[str, object]] = []
|
||||
|
||||
def set_stream_rows(self, rows: list[dict[str, object]]) -> None:
|
||||
self._stream_rows = rows
|
||||
|
||||
def require_user_id(self) -> UUID:
|
||||
return self._read_message.recipient_id
|
||||
|
||||
async def list_messages(
|
||||
self, is_read: bool | None = None
|
||||
@@ -38,6 +45,16 @@ class FakeInboxMessageService:
|
||||
raise HTTPException(status_code=404, detail="Inbox message not found")
|
||||
return self._read_message
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
*,
|
||||
last_event_id: str | None,
|
||||
) -> list[dict[str, object]]:
|
||||
del last_event_id
|
||||
rows = self._stream_rows
|
||||
self._stream_rows = []
|
||||
return rows
|
||||
|
||||
|
||||
def _override_inbox_message_service(
|
||||
service: FakeInboxMessageService,
|
||||
@@ -58,7 +75,7 @@ def _build_message(
|
||||
sender_id=uuid4(),
|
||||
message_type=InboxMessageType.CALENDAR,
|
||||
schedule_item_id=uuid4(),
|
||||
content='{"permission": 1}',
|
||||
content={"permission": 1},
|
||||
is_read=False,
|
||||
status=status,
|
||||
created_at=datetime(2026, 2, 28, 9, 0, 0, tzinfo=timezone.utc),
|
||||
@@ -108,3 +125,62 @@ def test_mark_as_read_returns_200() -> None:
|
||||
assert body["is_read"] is True
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_inbox_events_returns_sse_payload() -> None:
|
||||
read_message = _build_message(uuid4(), InboxMessageStatus.PENDING)
|
||||
service = FakeInboxMessageService(
|
||||
messages=[read_message], read_message=read_message
|
||||
)
|
||||
service.set_stream_rows(
|
||||
[
|
||||
{
|
||||
"id": "1743313300000-0",
|
||||
"event": {
|
||||
"event_id": str(uuid4()),
|
||||
"occurred_at": "2026-03-30T07:00:00+00:00",
|
||||
"user_id": str(read_message.recipient_id),
|
||||
"message_id": str(read_message.id),
|
||||
"event_type": "INBOX_MESSAGE_CREATED",
|
||||
"op": "created",
|
||||
"version": 1743313300000,
|
||||
"data": {"message": {"id": str(read_message.id)}},
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
app.dependency_overrides[get_inbox_message_service] = (
|
||||
_override_inbox_message_service(service)
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get("/api/v1/inbox/messages/stream?idle_limit=1")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
payload = response.text
|
||||
assert "event: INBOX_MESSAGE_CREATED" in payload
|
||||
assert '"op":"created"' in payload
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_stream_inbox_events_rejects_invalid_last_event_id() -> None:
|
||||
read_message = _build_message(uuid4(), InboxMessageStatus.PENDING)
|
||||
service = FakeInboxMessageService(
|
||||
messages=[read_message], read_message=read_message
|
||||
)
|
||||
app.dependency_overrides[get_inbox_message_service] = (
|
||||
_override_inbox_message_service(service)
|
||||
)
|
||||
client = TestClient(app)
|
||||
try:
|
||||
response = client.get(
|
||||
"/api/v1/inbox/messages/stream",
|
||||
headers={"Last-Event-ID": "not-a-stream-id"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
body = response.json()
|
||||
assert body.get("code") == "INBOX_INVALID_LAST_EVENT_ID"
|
||||
finally:
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
@@ -27,6 +27,7 @@ class _FakeService:
|
||||
created_request: Any = None
|
||||
created_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
list_calls: list[dict[str, Any]] = field(default_factory=list)
|
||||
range_calls: list[dict[str, Any]] = field(default_factory=list)
|
||||
deleted_ids: list[str] = field(default_factory=list)
|
||||
|
||||
async def list_paginated(
|
||||
@@ -47,6 +48,29 @@ class _FakeService:
|
||||
)
|
||||
return [item], 1
|
||||
|
||||
async def list_by_date_range(self, request: Any):
|
||||
self.range_calls.append(
|
||||
{
|
||||
"start_at": request.start_at,
|
||||
"end_at": request.end_at,
|
||||
}
|
||||
)
|
||||
return [
|
||||
SimpleNamespace(
|
||||
id=UUID(self.created_id),
|
||||
owner_id=uuid4(),
|
||||
title="会议",
|
||||
description="今天下午五点的会议",
|
||||
start_at=datetime(2026, 3, 17, 9, 0, tzinfo=timezone.utc),
|
||||
end_at=datetime(2026, 3, 17, 9, 30, tzinfo=timezone.utc),
|
||||
timezone="Asia/Shanghai",
|
||||
status="active",
|
||||
source_type="manual",
|
||||
metadata=None,
|
||||
subscribers=[],
|
||||
)
|
||||
]
|
||||
|
||||
async def create_agent_generated(self, request):
|
||||
self.created_request = request
|
||||
return SimpleNamespace(
|
||||
@@ -235,22 +259,48 @@ async def test_calendar_read_returns_structured_result_with_ids(
|
||||
)
|
||||
|
||||
result = await calendar_module.calendar_read(
|
||||
query="会议",
|
||||
page=1,
|
||||
page_size=20,
|
||||
start_at="2026-03-17T00:00:00+08:00",
|
||||
end_at="2026-03-18T00:00:00+08:00",
|
||||
session=SimpleNamespace(),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
payload = _decode_tool_response(result)
|
||||
result_data = json.loads(payload["result"])
|
||||
|
||||
assert payload["status"] == "success"
|
||||
assert result_data["total"] == 1
|
||||
assert result_data["items"][0]["id"] == fake_service.created_id
|
||||
assert result_data["items"][0]["timezone"] == "Asia/Shanghai"
|
||||
assert result_data["items"][0]["description"] == "今天下午五点的会议"
|
||||
assert result_data["items"][0]["status"] == "active"
|
||||
assert fake_service.range_calls == [
|
||||
{
|
||||
"start_at": datetime(2026, 3, 16, 16, 0, tzinfo=timezone.utc),
|
||||
"end_at": datetime(2026, 3, 17, 16, 0, tzinfo=timezone.utc),
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calendar_read_rejects_naive_datetime_string(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_service = _FakeService()
|
||||
monkeypatch.setattr(
|
||||
calendar_module, "create_schedule_service", lambda *_: fake_service
|
||||
)
|
||||
|
||||
result = await calendar_module.calendar_read(
|
||||
start_at="2026-03-17T00:00:00",
|
||||
end_at="2026-03-18T00:00:00+08:00",
|
||||
session=SimpleNamespace(),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
payload = _decode_tool_response(result)
|
||||
|
||||
assert payload["status"] == "success"
|
||||
assert payload["result"].startswith("status=success")
|
||||
assert "total=1" in payload["result"]
|
||||
assert "timezone=Asia/Shanghai" in payload["result"]
|
||||
assert "description=今天下午五点的会议" in payload["result"]
|
||||
assert "status=active" in payload["result"]
|
||||
assert fake_service.created_id in payload["result"]
|
||||
assert fake_service.list_calls == [{"page": 1, "page_size": 20, "query": "会议"}]
|
||||
assert payload["status"] == "failure"
|
||||
assert payload["error"]["code"] == "INVALID_ARGUMENT"
|
||||
assert "时区" in payload["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -312,3 +362,39 @@ async def test_calendar_share_rejects_invalid_phone(
|
||||
|
||||
assert payload["status"] == "failure"
|
||||
assert payload["error"]["code"] == "INVALID_ARGUMENT"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calendar_share_accepts_json_invitee_payload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_service = _FakeService()
|
||||
monkeypatch.setattr(
|
||||
calendar_module, "create_schedule_service", lambda *_: fake_service
|
||||
)
|
||||
event_id = str(uuid4())
|
||||
|
||||
result = await calendar_module.calendar_share(
|
||||
event_id=event_id,
|
||||
invitees=cast(
|
||||
Any,
|
||||
[
|
||||
{
|
||||
"phone": "8613900001234",
|
||||
"permissionView": True,
|
||||
"permissionEdit": False,
|
||||
"permissionInvite": False,
|
||||
}
|
||||
],
|
||||
),
|
||||
session=SimpleNamespace(),
|
||||
owner_id=uuid4(),
|
||||
)
|
||||
payload = _decode_tool_response(result)
|
||||
|
||||
assert payload["status"] == "success"
|
||||
assert payload["result"].startswith("status=success success=1 failed=0")
|
||||
assert len(fake_service.share_calls) == 1
|
||||
share_call = fake_service.share_calls[0]
|
||||
assert share_call["item_id"] == event_id
|
||||
assert share_call["request"].phone == "+8613900001234"
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.runtime import cli
|
||||
|
||||
|
||||
class _FakeScheduler:
|
||||
def __init__(self) -> None:
|
||||
self.started = False
|
||||
self.shutdown_called = False
|
||||
self.jobs: list[dict[str, Any]] = []
|
||||
|
||||
def add_job(self, func: Any, **kwargs: Any) -> None:
|
||||
self.jobs.append({"func": func, **kwargs})
|
||||
|
||||
def start(self) -> None:
|
||||
self.started = True
|
||||
|
||||
def shutdown(self, *, wait: bool) -> None:
|
||||
self.shutdown_called = True
|
||||
self.shutdown_wait = wait
|
||||
|
||||
|
||||
class _StopEvent:
|
||||
async def wait(self) -> None:
|
||||
raise asyncio.CancelledError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_automation_scheduler_forever_uses_async_scheduler(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_scheduler = _FakeScheduler()
|
||||
dispatch_limits: list[int] = []
|
||||
|
||||
async def _fake_scan(*, limit: int) -> None:
|
||||
dispatch_limits.append(limit)
|
||||
|
||||
monkeypatch.setattr(cli, "AsyncIOScheduler", lambda: fake_scheduler)
|
||||
monkeypatch.setattr(cli, "run_automation_scheduler_scan", _fake_scan)
|
||||
monkeypatch.setattr(cli.asyncio, "Event", lambda: _StopEvent())
|
||||
|
||||
settings = cli.config.automation_scheduler
|
||||
old_enabled = settings.enabled
|
||||
old_interval = settings.interval_seconds
|
||||
old_limit = settings.batch_limit
|
||||
settings.enabled = True
|
||||
settings.interval_seconds = 9
|
||||
settings.batch_limit = 7
|
||||
|
||||
try:
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await cli.run_automation_scheduler_forever()
|
||||
finally:
|
||||
settings.enabled = old_enabled
|
||||
settings.interval_seconds = old_interval
|
||||
settings.batch_limit = old_limit
|
||||
|
||||
assert fake_scheduler.started is True
|
||||
assert fake_scheduler.shutdown_called is True
|
||||
assert len(fake_scheduler.jobs) == 1
|
||||
assert fake_scheduler.jobs[0]["max_instances"] == 1
|
||||
assert fake_scheduler.jobs[0]["coalesce"] is True
|
||||
|
||||
scan_job = fake_scheduler.jobs[0]["func"]
|
||||
await scan_job()
|
||||
assert dispatch_limits == [7]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_automation_scheduler_forever_disabled_noop(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
settings = cli.config.automation_scheduler
|
||||
old_enabled = settings.enabled
|
||||
settings.enabled = False
|
||||
|
||||
called = False
|
||||
|
||||
def _unexpected_scheduler() -> _FakeScheduler:
|
||||
nonlocal called
|
||||
called = True
|
||||
return _FakeScheduler()
|
||||
|
||||
monkeypatch.setattr(cli, "AsyncIOScheduler", _unexpected_scheduler)
|
||||
|
||||
try:
|
||||
await cli.run_automation_scheduler_forever()
|
||||
finally:
|
||||
settings.enabled = old_enabled
|
||||
|
||||
assert called is False
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from schemas.agent.runtime_models import RouterAgentOutput
|
||||
from schemas.agent.runtime_models import RouterAgentOutput, WorkerAgentOutputRich
|
||||
|
||||
|
||||
def test_router_agent_output_coerces_key_entity_value_to_string() -> None:
|
||||
@@ -32,3 +32,59 @@ def test_router_agent_output_coerces_key_entity_value_to_string() -> None:
|
||||
model = RouterAgentOutput.model_validate(payload)
|
||||
|
||||
assert model.key_entities[0].value == "8"
|
||||
|
||||
|
||||
def test_router_agent_output_coerces_constraint_value_to_string() -> None:
|
||||
payload = {
|
||||
"normalized_task_input": {
|
||||
"user_text": "test",
|
||||
"multimodal_summary": [],
|
||||
"context_summary": "",
|
||||
},
|
||||
"key_entities": [],
|
||||
"constraints": [
|
||||
{
|
||||
"key": "strict_mode",
|
||||
"value": True,
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
"task_typing": {
|
||||
"primary": "planning",
|
||||
"secondary": [],
|
||||
},
|
||||
"execution_mode": "onestep",
|
||||
"result_typing": {
|
||||
"primary": "summary",
|
||||
"secondary": [],
|
||||
},
|
||||
}
|
||||
|
||||
model = RouterAgentOutput.model_validate(payload)
|
||||
|
||||
assert model.constraints[0].value == "True"
|
||||
|
||||
|
||||
def test_worker_agent_output_rich_accepts_list_item_status_object() -> None:
|
||||
payload = {
|
||||
"status": "success",
|
||||
"answer": "done",
|
||||
"result_type": "summary",
|
||||
"ui_hints": {
|
||||
"intent": "status",
|
||||
"status": "info",
|
||||
"title": "状态",
|
||||
"listItems": [
|
||||
{
|
||||
"title": "任务A",
|
||||
"status": {"type": "info", "value": "已归档"},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
model = WorkerAgentOutputRich.model_validate(payload)
|
||||
|
||||
assert model.ui_hints is not None
|
||||
assert model.ui_hints.list_items[0].status is not None
|
||||
assert model.ui_hints.list_items[0].status.value == "info"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
@@ -64,10 +64,14 @@ class FakeFriendshipRepo:
|
||||
inbox.id = uuid4()
|
||||
inbox.recipient_id = recipient_id
|
||||
inbox.sender_id = initiator_id
|
||||
inbox.schedule_item_id = None
|
||||
inbox.status = InboxMessageStatus.PENDING
|
||||
inbox.message_type = InboxMessageType.FRIEND_REQUEST
|
||||
inbox.friendship_id = friendship.id
|
||||
inbox.content = {"type": "request", "message": content}
|
||||
inbox.is_read = False
|
||||
inbox.created_at = datetime.now(timezone.utc)
|
||||
inbox.updated_at = datetime.now(timezone.utc)
|
||||
self._inbox_messages.append(inbox)
|
||||
|
||||
return friendship, inbox
|
||||
@@ -91,10 +95,14 @@ class FakeFriendshipRepo:
|
||||
inbox.id = uuid4()
|
||||
inbox.recipient_id = recipient_id
|
||||
inbox.sender_id = initiator_id
|
||||
inbox.schedule_item_id = None
|
||||
inbox.status = InboxMessageStatus.PENDING
|
||||
inbox.message_type = InboxMessageType.FRIEND_REQUEST
|
||||
inbox.friendship_id = friendship.id
|
||||
inbox.content = {"type": "request", "message": content}
|
||||
inbox.is_read = False
|
||||
inbox.created_at = datetime.now(timezone.utc)
|
||||
inbox.updated_at = datetime.now(timezone.utc)
|
||||
self._inbox_messages.append(inbox)
|
||||
|
||||
return friendship, inbox
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.inbox_messages import InboxMessage
|
||||
from schemas.enums import InboxMessageStatus, InboxMessageType
|
||||
from v1.inbox_messages import realtime
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self.last_stream: str | None = None
|
||||
self.last_payload: str | None = None
|
||||
self.last_block: int | None = None
|
||||
|
||||
async def xadd(self, stream: str, fields: dict[str, str]) -> str:
|
||||
self.last_stream = stream
|
||||
self.last_payload = fields.get("event")
|
||||
return "1743313300000-0"
|
||||
|
||||
async def xread(self, _streams: dict[str, str], count: int, block: int):
|
||||
del count
|
||||
self.last_block = block
|
||||
return [
|
||||
(
|
||||
"inbox:events:test",
|
||||
[
|
||||
(
|
||||
"1743313300000-0",
|
||||
{
|
||||
"event": '{"event_id":"e1","event_type":"INBOX_MESSAGE_CREATED","op":"created"}',
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_inbox_message_created_writes_stream(monkeypatch) -> None:
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_redis():
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(realtime, "get_or_init_redis_client", _fake_get_redis)
|
||||
message = InboxMessage(
|
||||
id=uuid4(),
|
||||
recipient_id=uuid4(),
|
||||
sender_id=uuid4(),
|
||||
message_type=InboxMessageType.CALENDAR,
|
||||
friendship_id=None,
|
||||
schedule_item_id=uuid4(),
|
||||
group_id=None,
|
||||
content={"type": "invite"},
|
||||
is_read=False,
|
||||
status=InboxMessageStatus.PENDING,
|
||||
created_by=uuid4(),
|
||||
)
|
||||
message.created_at = datetime(2026, 3, 30, 7, 0, tzinfo=UTC)
|
||||
message.updated_at = datetime(2026, 3, 30, 7, 0, tzinfo=UTC)
|
||||
|
||||
stream_id = await realtime.publish_inbox_message_created(message)
|
||||
|
||||
assert stream_id == "1743313300000-0"
|
||||
assert fake_redis.last_stream == f"inbox:events:{message.recipient_id}"
|
||||
assert fake_redis.last_payload is not None
|
||||
assert '"op":"created"' in fake_redis.last_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_inbox_events_decodes_rows(monkeypatch) -> None:
|
||||
fake_redis = _FakeRedis()
|
||||
|
||||
async def _fake_get_redis():
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(realtime, "get_or_init_redis_client", _fake_get_redis)
|
||||
|
||||
rows = await realtime.read_inbox_events(
|
||||
recipient_id=uuid4(),
|
||||
last_event_id=None,
|
||||
)
|
||||
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["id"] == "1743313300000-0"
|
||||
assert rows[0]["event"]["event_type"] == "INBOX_MESSAGE_CREATED"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_inbox_events_handles_redis_timeout(monkeypatch) -> None:
|
||||
class _TimeoutRedis(_FakeRedis):
|
||||
async def xread(self, _streams: dict[str, str], count: int, block: int):
|
||||
del _streams, count, block
|
||||
raise TimeoutError("read timeout")
|
||||
|
||||
fake_redis = _TimeoutRedis()
|
||||
|
||||
async def _fake_get_redis():
|
||||
return fake_redis
|
||||
|
||||
monkeypatch.setattr(realtime, "get_or_init_redis_client", _fake_get_redis)
|
||||
|
||||
rows = await realtime.read_inbox_events(recipient_id=uuid4(), last_event_id=None)
|
||||
|
||||
assert rows == []
|
||||
@@ -59,6 +59,9 @@ class FakeRepo:
|
||||
return self._item
|
||||
return None
|
||||
|
||||
async def get_item(self, item_id: UUID) -> ScheduleItem | None:
|
||||
return await self.get_by_id(item_id)
|
||||
|
||||
async def create(self, data: dict) -> ScheduleItem:
|
||||
return _create_mock_schedule_item(
|
||||
owner_id=data["owner_id"],
|
||||
@@ -74,6 +77,23 @@ class FakeRepo:
|
||||
self._item.title = data["title"]
|
||||
return self._item
|
||||
|
||||
async def update_item(self, item_id: UUID, data: dict) -> ScheduleItem | None:
|
||||
if self._item is None:
|
||||
return None
|
||||
if "title" in data:
|
||||
self._item.title = data["title"]
|
||||
if "description" in data:
|
||||
self._item.description = data["description"]
|
||||
if "start_at" in data:
|
||||
self._item.start_at = data["start_at"]
|
||||
if "end_at" in data:
|
||||
self._item.end_at = data["end_at"]
|
||||
if "timezone" in data:
|
||||
self._item.timezone = data["timezone"]
|
||||
if "extra_metadata" in data:
|
||||
self._item.extra_metadata = data["extra_metadata"]
|
||||
return self._item
|
||||
|
||||
async def delete_by_item_id(
|
||||
self, item_id: UUID, owner_id: UUID
|
||||
) -> ScheduleItem | None:
|
||||
@@ -81,6 +101,9 @@ class FakeRepo:
|
||||
return None
|
||||
return self._item
|
||||
|
||||
async def delete_item(self, item_id: UUID) -> None:
|
||||
del item_id
|
||||
|
||||
async def list_by_date_range(
|
||||
self, owner_id: UUID, start_at: datetime, end_at: datetime
|
||||
) -> list[ScheduleItem]:
|
||||
@@ -327,12 +350,11 @@ async def test_update_maps_metadata_to_extra_metadata(
|
||||
captured: dict | None = None
|
||||
|
||||
class CaptureRepo(FakeRepo):
|
||||
async def update_by_item_id(
|
||||
self, item_id: UUID, owner_id: UUID, data: dict
|
||||
) -> ScheduleItem | None:
|
||||
async def update_item(self, item_id: UUID, data: dict) -> ScheduleItem | None:
|
||||
nonlocal captured
|
||||
del item_id
|
||||
captured = data
|
||||
return await super().update_by_item_id(item_id, owner_id, data)
|
||||
return await super().update_item(item.id, data)
|
||||
|
||||
service = ScheduleItemService(
|
||||
repository=CaptureRepo(item),
|
||||
@@ -370,12 +392,11 @@ async def test_update_maps_null_metadata_to_extra_metadata_null(
|
||||
captured: dict | None = None
|
||||
|
||||
class CaptureRepo(FakeRepo):
|
||||
async def update_by_item_id(
|
||||
self, item_id: UUID, owner_id: UUID, data: dict
|
||||
) -> ScheduleItem | None:
|
||||
async def update_item(self, item_id: UUID, data: dict) -> ScheduleItem | None:
|
||||
nonlocal captured
|
||||
del item_id
|
||||
captured = data
|
||||
return await super().update_by_item_id(item_id, owner_id, data)
|
||||
return await super().update_item(item.id, data)
|
||||
|
||||
service = ScheduleItemService(
|
||||
repository=CaptureRepo(item),
|
||||
|
||||
@@ -157,6 +157,14 @@ class FriendshipRepoStub:
|
||||
return friendship
|
||||
|
||||
|
||||
class UserRepoStub:
|
||||
async def get_by_user_id(self, user_id: UUID):
|
||||
profile = MagicMock()
|
||||
profile.id = user_id
|
||||
profile.username = "owner"
|
||||
return profile
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_share_forbidden_when_not_owner() -> None:
|
||||
owner_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
@@ -172,6 +180,7 @@ async def test_share_forbidden_when_not_owner() -> None:
|
||||
auth_gateway=cast(Any, AuthGatewayStub()),
|
||||
inbox_repository=InboxRepoStub(),
|
||||
friendship_repository=cast(Any, FriendshipRepoStub()),
|
||||
user_repository=cast(Any, UserRepoStub()),
|
||||
)
|
||||
|
||||
with pytest.raises(ApiProblemError) as exc_info:
|
||||
@@ -204,6 +213,7 @@ async def test_share_success_creates_calendar_invitation_message() -> None:
|
||||
auth_gateway=cast(Any, AuthGatewayStub()),
|
||||
inbox_repository=InboxRepoStub(),
|
||||
friendship_repository=cast(Any, FriendshipRepoStub()),
|
||||
user_repository=cast(Any, UserRepoStub()),
|
||||
)
|
||||
|
||||
result = await service.share(
|
||||
@@ -223,7 +233,17 @@ async def test_share_success_creates_calendar_invitation_message() -> None:
|
||||
assert message.sender_id == owner_id
|
||||
assert message.schedule_item_id == item_id
|
||||
assert message.message_type == InboxMessageType.CALENDAR
|
||||
assert message.content == {"type": "invite", "permission": 5, "action": "pending"}
|
||||
assert message.content is not None
|
||||
assert message.content["type"] == "invite"
|
||||
assert message.content["schema_version"] == 2
|
||||
assert message.content["permission"] == 5
|
||||
assert message.content["item"]["id"] == str(item_id)
|
||||
assert message.content["item"]["title"] == "test"
|
||||
assert message.content["item"]["start_at"] == "2026-02-28T16:00:00+00:00"
|
||||
assert message.content["item"]["end_at"] is None
|
||||
assert message.content["item"]["timezone"] == "UTC"
|
||||
assert message.content["actor"]["username"] == "owner"
|
||||
assert message.content["actor"]["phone"] == "+8613810000000"
|
||||
session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
@@ -237,6 +257,7 @@ async def test_share_returns_not_found_when_item_missing() -> None:
|
||||
auth_gateway=cast(Any, AuthGatewayStub()),
|
||||
inbox_repository=InboxRepoStub(),
|
||||
friendship_repository=cast(Any, FriendshipRepoStub()),
|
||||
user_repository=cast(Any, UserRepoStub()),
|
||||
)
|
||||
|
||||
with pytest.raises(ApiProblemError) as exc_info:
|
||||
@@ -268,6 +289,7 @@ async def test_share_invalid_auth_user_id_returns_503() -> None:
|
||||
auth_gateway=cast(Any, AuthGatewayInvalidIdStub()),
|
||||
inbox_repository=InboxRepoStub(),
|
||||
friendship_repository=cast(Any, FriendshipRepoStub()),
|
||||
user_repository=cast(Any, UserRepoStub()),
|
||||
)
|
||||
|
||||
with pytest.raises(ApiProblemError) as exc_info:
|
||||
@@ -302,6 +324,7 @@ async def test_share_sqlalchemy_error_rolls_back() -> None:
|
||||
auth_gateway=cast(Any, AuthGatewayStub()),
|
||||
inbox_repository=InboxRepoStub(),
|
||||
friendship_repository=cast(Any, FriendshipRepoStub()),
|
||||
user_repository=cast(Any, UserRepoStub()),
|
||||
)
|
||||
|
||||
with pytest.raises(ApiProblemError) as exc_info:
|
||||
@@ -334,6 +357,7 @@ async def test_share_returns_forbidden_when_target_is_not_friend() -> None:
|
||||
auth_gateway=cast(Any, AuthGatewayStub()),
|
||||
inbox_repository=InboxRepoStub(),
|
||||
friendship_repository=cast(Any, FriendshipRepoStub(accepted=False)),
|
||||
user_repository=cast(Any, UserRepoStub()),
|
||||
)
|
||||
|
||||
with pytest.raises(ApiProblemError) as exc_info:
|
||||
|
||||
Reference in New Issue
Block a user