refactor: 重构 AgentScope ReAct Runner 与事件处理

- 重构 runtime/runner.py 实现 ReAct Agent 核心逻辑
- 更新事件编码器与存储机制
- 优化 prompt 系统与 tool 调用
- 调整 agent service 与 repository 配合
This commit is contained in:
qzl
2026-03-16 16:10:39 +08:00
parent ab073c88ed
commit 36b104fa37
22 changed files with 1288 additions and 319 deletions
+29 -23
View File
@@ -30,8 +30,8 @@ class AgentRepository:
*,
tool_result_storage: ToolResultPayloadStorage | None = None,
) -> None:
self._session = session
self._tool_result_storage = tool_result_storage
self._session: AsyncSession = session
self._tool_result_storage: ToolResultPayloadStorage | None = tool_result_storage
async def get_session_owner(self, *, session_id: str) -> str:
try:
@@ -138,34 +138,31 @@ class AgentRepository:
except ValueError as exc:
raise HTTPException(status_code=422, detail="Invalid session_id") from exc
timestamp_stmt = (
before_start = (
datetime.combine(before, time.min, tzinfo=timezone.utc)
if before is not None
else None
)
target_created_at_stmt = (
select(AgentChatMessage.created_at)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.order_by(AgentChatMessage.created_at.desc())
.limit(1)
)
rows = (await self._session.execute(timestamp_stmt)).scalars().all()
unique_days: list[date] = []
for created_at in rows:
if created_at is None:
continue
day = created_at.astimezone(timezone.utc).date()
if day not in unique_days:
unique_days.append(day)
if before_start is not None:
target_created_at_stmt = target_created_at_stmt.where(
AgentChatMessage.created_at < before_start
)
target_created_at = (
await self._session.execute(target_created_at_stmt)
).scalar_one_or_none()
if not unique_days:
if target_created_at is None:
return None
target_day: date | None = None
if before is None:
target_day = unique_days[0]
else:
for day in unique_days:
if day < before:
target_day = day
break
if target_day is None:
return None
target_day = target_created_at.astimezone(timezone.utc).date()
start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
end = start + timedelta(days=1)
@@ -178,7 +175,16 @@ class AgentRepository:
.order_by(AgentChatMessage.seq.asc())
)
messages = (await self._session.execute(message_stmt)).scalars().all()
has_more = any(day < target_day for day in unique_days)
has_more_stmt = (
select(AgentChatMessage.id)
.where(AgentChatMessage.session_id == session_uuid)
.where(AgentChatMessage.deleted_at.is_(None))
.where(AgentChatMessage.created_at < start)
.limit(1)
)
has_more = (
await self._session.execute(has_more_stmt)
).scalar_one_or_none() is not None
snapshot_messages: list[dict[str, object]] = []
for message in messages:
snapshot_messages.append(await self._to_snapshot_message(message))
+4
View File
@@ -128,6 +128,10 @@ async def enqueue_run(
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
try:
request = parse_run_input(request.model_dump(by_alias=True, exclude_none=True))
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
try:
validate_run_request_messages_contract(request)
except ValueError as exc:
+14 -7
View File
@@ -170,12 +170,9 @@ class AgentService:
command={
"command": "run",
"owner_id": str(current_user.id),
"run_input": {
"messages": [
msg.model_dump(mode="json", exclude_none=True)
for msg in run_input.messages
],
},
"run_input": run_input.model_dump(
mode="json", by_alias=True, exclude_none=True
),
},
dedup_key=None,
)
@@ -204,7 +201,7 @@ class AgentService:
yesterday = await self._repository.get_history_day(
session_id=thread_id,
before=today.get("day"), # type: ignore
before=self._parse_history_day(today.get("day")),
)
messages: list[dict[str, object]] = []
@@ -215,6 +212,16 @@ class AgentService:
return {"messages": messages}
def _parse_history_day(self, value: object) -> date | None:
if isinstance(value, date):
return value
if isinstance(value, str):
try:
return date.fromisoformat(value)
except ValueError:
return None
return None
async def _prepare_user_message(
self,
*,
+25 -5
View File
@@ -17,7 +17,7 @@ from schemas.messages.chat_message import (
def convert_message_to_history(
message: AgentChatMessage,
get_signed_url_fn: Callable[[str, str], str] | None = None,
get_signed_url_fn: Callable[[dict[str, str]], str] | None = None,
) -> dict[str, Any]:
"""
将 AgentChatMessage 转换为 HistoryMessage 格式
@@ -55,14 +55,14 @@ def convert_message_to_history(
result["url"] = url
if ui_schema:
result["uiSchema"] = ui_schema
result["ui_schema"] = ui_schema
return result
def _convert_user_attachments(
metadata: AgentChatMessageMetadata | dict[str, Any] | None,
get_signed_url_fn: Callable[[str, str], str] | None,
get_signed_url_fn: Callable[[dict[str, str]], str] | None,
) -> str | None:
"""转换用户附件为临时访问 URL"""
if not metadata:
@@ -100,9 +100,19 @@ def _compile_tool_ui_hints(
tool_output_data = metadata.get("tool_agent_output")
if not tool_output_data:
return None
if isinstance(tool_output_data, dict):
raw_ui_schema = tool_output_data.get("ui_schema")
if isinstance(raw_ui_schema, dict):
return raw_ui_schema
legacy_ui_schema = tool_output_data.get("uiSchema")
if isinstance(legacy_ui_schema, dict):
return legacy_ui_schema
from schemas.agent.runtime_models import ToolAgentOutput
tool_output = ToolAgentOutput.model_validate(tool_output_data)
try:
tool_output = ToolAgentOutput.model_validate(tool_output_data)
except Exception:
return None
if not tool_output:
return None
@@ -131,9 +141,19 @@ def _compile_worker_ui_hints(
worker_output_data = metadata.get("worker_agent_output")
if not worker_output_data:
return None
if isinstance(worker_output_data, dict):
raw_ui_schema = worker_output_data.get("ui_schema")
if isinstance(raw_ui_schema, dict):
return raw_ui_schema
legacy_ui_schema = worker_output_data.get("uiSchema")
if isinstance(legacy_ui_schema, dict):
return legacy_ui_schema
from schemas.agent.runtime_models import WorkerAgentOutputRich
worker_output = WorkerAgentOutputRich.model_validate(worker_output_data)
try:
worker_output = WorkerAgentOutputRich.model_validate(worker_output_data)
except Exception:
return None
if not worker_output:
return None
+21 -9
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import time
from collections.abc import Mapping
from typing import Any, cast
from urllib.parse import urlparse
@@ -32,6 +33,11 @@ AUTH_UNAVAILABLE_DETAIL = "Auth service temporarily unavailable"
class SupabaseAuthGateway(AuthServiceGateway):
def __init__(self) -> None:
self._user_lookup_cache_ttl_seconds: int = 60
self._user_lookup_cache_expires_at: float = 0.0
self._users_by_email: dict[str, Any] = {}
def _get_client(self) -> Any:
return supabase_service.get_client()
@@ -185,16 +191,22 @@ class SupabaseAuthGateway(AuthServiceGateway):
async def get_user_by_email(self, email: str) -> UserByEmailResponse:
admin_client = self._get_admin_client()
users = await asyncio.to_thread(_list_auth_users, admin_client)
normalized_email = email.lower()
user = next(
(
candidate
for candidate in users
if str(getattr(candidate, "email", "")).lower() == normalized_email
),
None,
)
now = time.monotonic()
if now >= self._user_lookup_cache_expires_at:
users = await asyncio.to_thread(_list_auth_users, admin_client)
users_by_email: dict[str, Any] = {}
for candidate in users:
candidate_email = str(getattr(candidate, "email", "")).lower()
if candidate_email:
users_by_email[candidate_email] = candidate
self._users_by_email = users_by_email
self._user_lookup_cache_expires_at = (
now + self._user_lookup_cache_ttl_seconds
)
user = self._users_by_email.get(normalized_email)
if user is None:
raise HTTPException(status_code=404, detail="User not found")
+28
View File
@@ -53,6 +53,12 @@ class FriendshipRepository(Protocol):
"""Get friendship by ID."""
...
async def get_friendships_by_ids(
self, friendship_ids: list[UUID]
) -> dict[UUID, Friendship]:
"""Batch get friendships by IDs."""
...
async def get_inbox_messages_for_user(
self, user_id: UUID, status: InboxMessageStatus | None = None
) -> list[InboxMessage]:
@@ -214,6 +220,28 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
)
raise
async def get_friendships_by_ids(
self, friendship_ids: list[UUID]
) -> dict[UUID, Friendship]:
if not friendship_ids:
return {}
try:
unique_ids = list(dict.fromkeys(friendship_ids))
stmt = (
select(Friendship)
.where(Friendship.id.in_(unique_ids))
.where(Friendship.deleted_at.is_(None))
)
result = await self._session.execute(stmt)
friendships = list(result.scalars().all())
return {friendship.id: friendship for friendship in friendships}
except SQLAlchemyError:
logger.exception(
"Failed to get friendships by ids",
friendship_ids=[str(i) for i in friendship_ids],
)
raise
async def get_inbox_messages_for_user(
self, user_id: UUID, status: InboxMessageStatus | None = None
) -> list[InboxMessage]:
+43 -6
View File
@@ -362,6 +362,28 @@ class FriendshipService(BaseService):
status_code=503, detail="Friendship service unavailable"
)
candidate_inbox = [
inbox
for inbox in inbox_messages
if inbox.message_type == InboxMessageType.FRIEND_REQUEST
and inbox.friendship_id is not None
and inbox.sender_id is not None
]
if not candidate_inbox:
return []
friendship_ids = [inbox.friendship_id for inbox in candidate_inbox]
friendships_by_id = await self._repository.get_friendships_by_ids(
cast(list[UUID], friendship_ids)
)
profile_ids = {user_id}
for inbox in candidate_inbox:
sender_id = cast(UUID, inbox.sender_id)
profile_ids.add(sender_id)
profiles_by_id = await self._user_repository.get_by_user_ids(list(profile_ids))
recipient = profiles_by_id.get(user_id)
result: list[FriendRequestResponse] = []
for inbox in inbox_messages:
if inbox.message_type != InboxMessageType.FRIEND_REQUEST:
@@ -371,7 +393,7 @@ class FriendshipService(BaseService):
if friendship_id is None:
continue
friendship = await self._repository.get_friendship_by_id(friendship_id)
friendship = friendships_by_id.get(friendship_id)
if friendship is None or friendship.status != FriendshipStatus.PENDING:
continue
@@ -379,8 +401,7 @@ class FriendshipService(BaseService):
if sender_id is None:
continue
sender = await self._user_repository.get_by_user_id(sender_id)
recipient = await self._user_repository.get_by_user_id(user_id)
sender = profiles_by_id.get(sender_id)
result.append(
FriendRequestResponse(
@@ -460,11 +481,19 @@ class FriendshipService(BaseService):
status_code=503, detail="Friendship service unavailable"
)
if not outgoing:
return []
user_ids = {user_id}
for friendship in outgoing:
user_ids.add(self._get_other_user_id(friendship, user_id))
profiles_by_id = await self._user_repository.get_by_user_ids(list(user_ids))
sender = profiles_by_id.get(user_id)
result: list[FriendRequestResponse] = []
for friendship in outgoing:
other_user_id = self._get_other_user_id(friendship, user_id)
sender = await self._user_repository.get_by_user_id(user_id)
recipient = await self._user_repository.get_by_user_id(other_user_id)
recipient = profiles_by_id.get(other_user_id)
result.append(
FriendRequestResponse(
@@ -489,10 +518,18 @@ class FriendshipService(BaseService):
status_code=503, detail="Friendship service unavailable"
)
if not friendships:
return []
friend_ids = [
self._get_other_user_id(friendship, user_id) for friendship in friendships
]
profiles_by_id = await self._user_repository.get_by_user_ids(friend_ids)
result: list[FriendResponse] = []
for friendship in friendships:
friend_id = self._get_other_user_id(friendship, user_id)
friend = await self._user_repository.get_by_user_id(friend_id)
friend = profiles_by_id.get(friend_id)
result.append(
FriendResponse(
+23
View File
@@ -23,6 +23,10 @@ class UserRepository(Protocol):
"""Get user by user ID."""
...
async def get_by_user_ids(self, user_ids: list[UUID]) -> dict[UUID, Profile]:
"""Batch get users by user IDs."""
...
async def get_by_username(self, username: str) -> Profile | None:
"""Get user by username."""
...
@@ -57,6 +61,25 @@ class SQLAlchemyUserRepository(BaseRepository[Profile]):
logger.exception("User lookup failed", user_id=str(user_id))
raise
async def get_by_user_ids(self, user_ids: list[UUID]) -> dict[UUID, Profile]:
if not user_ids:
return {}
try:
unique_ids = list(dict.fromkeys(user_ids))
stmt = (
select(Profile)
.where(Profile.id.in_(unique_ids))
.where(Profile.deleted_at.is_(None))
)
result = await self._session.execute(stmt)
profiles = list(result.scalars().all())
return {profile.id: profile for profile in profiles}
except SQLAlchemyError:
logger.exception(
"Batch user lookup failed", user_ids=[str(i) for i in user_ids]
)
raise
async def get_by_username(self, username: str) -> Profile | None:
try:
stmt = (