refactor: 简化 AgentScope 运行时模块与 prompt 系统

This commit is contained in:
zl-q
2026-03-15 17:14:15 +08:00
parent 61997f3613
commit 072c09d99d
32 changed files with 750 additions and 1863 deletions
+15 -11
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone
from typing import Protocol
from uuid import UUID
from uuid import UUID, uuid4
from fastapi import HTTPException
from sqlalchemy import select
@@ -10,7 +10,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from models.agent_chat_message import AgentChatMessage, AgentChatMessageRole
from models.agent_chat_session import AgentChatSession
from schemas.messages.chat_message import AgentChatMessage as AgentChatMessageSchema
from schemas.messages.chat_message import (
AgentChatMessage as AgentChatMessageSchema,
AgentChatMessageMetadata,
)
class ToolResultPayloadStorage(Protocol):
@@ -88,10 +91,11 @@ class AgentRepository:
self,
*,
session_id: str,
run_id: str,
content_text: str,
metadata: dict[str, object] | None,
content: str,
metadata: AgentChatMessageMetadata | None,
) -> None:
from models.agent_chat_message import AgentChatMessage as OrmAgentChatMessage
try:
session_uuid = UUID(session_id)
except ValueError as exc:
@@ -108,17 +112,17 @@ class AgentRepository:
next_seq = int(session_row.message_count or 0) + 1
if not _has_title(session_row.title):
session_title = _derive_session_title(content_text)
session_title = _derive_session_title(content)
if session_title is not None:
session_row.title = session_title
payload_metadata = dict(metadata or {})
payload_metadata["run_id"] = run_id
message = AgentChatMessage(
message = OrmAgentChatMessage(
id=uuid4(),
session_id=session_uuid,
seq=next_seq,
role=AgentChatMessageRole.USER,
content=content_text,
metadata_json=payload_metadata,
content=content,
metadata_json=metadata.model_dump(by_alias=True) if metadata else None,
)
self._session.add(message)
session_row.message_count = next_seq
+5 -41
View File
@@ -11,6 +11,10 @@ from typing import Annotated, Union
from ag_ui.core import RunAgentInput
from core.agentscope.events import to_sse_event
from core.agentscope.schemas.agui_input import (
parse_run_input,
validate_run_request_messages_contract,
)
from core.auth.models import CurrentUser
from core.logging import get_logger
from fastapi import (
@@ -26,11 +30,6 @@ from fastapi import (
status,
)
from fastapi.responses import JSONResponse, StreamingResponse
from core.agentscope.schemas.agui_input import (
extract_latest_tool_result,
parse_run_input,
validate_run_request_messages_contract,
)
from services.base.redis import get_or_init_redis_client
from v1.agent.dependencies import get_agent_service
from v1.agent.schemas import (
@@ -129,8 +128,7 @@ async def enqueue_run(
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
try:
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
validate_run_request_messages_contract(normalized)
validate_run_request_messages_contract(request)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
allowed = await _allow_run_request(user_id=str(current_user.id))
@@ -149,40 +147,6 @@ async def enqueue_run(
)
@router.post(
"/runs/{thread_id}/resume",
response_model=TaskAcceptedResponse,
status_code=status.HTTP_202_ACCEPTED,
)
async def enqueue_resume(
thread_id: str,
request: RunAgentInput,
service: Annotated[AgentService, Depends(get_agent_service)],
current_user: Annotated[CurrentUser, Depends(get_current_user)],
) -> TaskAcceptedResponse:
if request.thread_id != thread_id:
raise HTTPException(status_code=422, detail="thread_id path/body mismatch")
try:
normalized = parse_run_input(request.model_dump(mode="json", by_alias=True))
extract_latest_tool_result(normalized)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
allowed = await _allow_run_request(user_id=str(current_user.id))
if not allowed:
raise HTTPException(status_code=429, detail="Too many run requests")
task = await service.enqueue_resume(
thread_id=thread_id,
run_input=request,
current_user=current_user,
)
return TaskAcceptedResponse(
taskId=task.task_id,
threadId=task.thread_id,
runId=task.run_id,
created=task.created,
)
@router.get("/runs/{thread_id}/events")
async def stream_events(
request: Request,
+48 -40
View File
@@ -17,6 +17,10 @@ from core.auth.models import CurrentUser
from core.agentscope.schemas.agui_input import extract_latest_user_payload
from core.config.settings import config
from core.logging import get_logger
from schemas.messages.chat_message import (
AgentChatMessageMetadata,
UserMessageAttachments,
)
logger = get_logger(__name__)
_ALLOWED_ATTACHMENT_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}
@@ -53,9 +57,8 @@ class AgentRepositoryLike(Protocol):
self,
*,
session_id: str,
run_id: str,
content_text: str,
metadata: dict[str, object] | None,
content: str,
metadata: AgentChatMessageMetadata | None,
) -> None: ...
@@ -157,8 +160,7 @@ class AgentService:
)
await self._repository.persist_user_message(
session_id=thread_id,
run_id=run_id,
content_text=user_message_text,
content=user_message_text,
metadata=user_message_metadata,
)
await self._repository.commit()
@@ -167,7 +169,12 @@ class AgentService:
command={
"command": "run",
"owner_id": str(current_user.id),
"run_input": run_input.model_dump(mode="json", by_alias=True),
"run_input": {
"messages": [
msg.model_dump(mode="json", exclude_none=True)
for msg in run_input.messages
],
},
},
dedup_key=None,
)
@@ -178,14 +185,41 @@ class AgentService:
created=created,
)
async def load_agent_input_messages(
self,
*,
thread_id: str,
) -> dict[str, object] | None:
"""Load recent messages for runtime agent input.
Returns messages from today and yesterday (if exists).
"""
today = await self._repository.get_history_day(
session_id=thread_id,
before=None,
)
if not today:
return None
yesterday = await self._repository.get_history_day(
session_id=thread_id,
before=today.get("day"), # type: ignore
)
messages: list[dict[str, object]] = []
if yesterday and yesterday.get("messages"):
messages.extend(yesterday["messages"]) # type: ignore
if today.get("messages"):
messages.extend(today["messages"]) # type: ignore
return {"messages": messages}
async def _prepare_user_message(
self,
*,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> tuple[str, dict[str, object] | None]:
from schemas.messages.chat_message import UserMessageAttachments
) -> tuple[str, AgentChatMessageMetadata | None]:
text, content_blocks = extract_latest_user_payload(run_input)
user_attachments: UserMessageAttachments | None = None
@@ -227,11 +261,12 @@ class AgentService:
logger.warning("Failed to parse signed URL", url=url, error=str(exc))
raise HTTPException(status_code=422, detail="Invalid signed image url")
metadata: dict[str, object] | None = None
metadata: AgentChatMessageMetadata | None = None
if user_attachments is not None:
metadata = {
"user_message_attachments": user_attachments.model_dump(by_alias=True),
}
metadata = AgentChatMessageMetadata(
run_id=run_input.run_id,
user_message_attachments=user_attachments,
)
return text, metadata
@@ -361,33 +396,6 @@ class AgentService:
"url": signed_url,
}
async def enqueue_resume(
self,
*,
thread_id: str,
run_input: RunAgentInput,
current_user: CurrentUser,
) -> TaskAccepted:
owner = await self._repository.get_session_owner(session_id=thread_id)
ensure_session_owner(owner_id=owner, current_user=current_user)
dedup_key = f"resume:{thread_id}:{run_input.run_id}"
task_id = await self._queue.enqueue(
command={
"command": "resume",
"owner_id": str(current_user.id),
"run_input": run_input.model_dump(mode="json", by_alias=True),
},
dedup_key=dedup_key,
)
return TaskAccepted(
task_id=task_id,
thread_id=thread_id,
run_id=run_input.run_id,
created=False,
)
async def stream_events(
self,
*,
+1 -1
View File
@@ -61,7 +61,7 @@ async def _enforce_rate_limit_with_redis(
window_seconds: int,
) -> None:
client = await get_or_init_redis_client()
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds)
current = await client.eval(_REDIS_LIMIT_SCRIPT, 1, key, window_seconds) # type: ignore[await]
if int(current) > limit:
raise HTTPException(status_code=429, detail="Too many requests")
+4 -4
View File
@@ -81,7 +81,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
super().__init__(session, Friendship)
async def create_request(
self, initiator_id: UUID, recipient_id: UUID, message: str | None = None
self, initiator_id: UUID, recipient_id: UUID, content: str | None = None
) -> tuple[Friendship, InboxMessage]:
try:
user_low_id = min(initiator_id, recipient_id)
@@ -100,7 +100,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
self._session.add(friendship)
await self._session.flush()
inbox_content = FriendshipContent(type="request", message=message)
inbox_content = FriendshipContent(type="request", message=content)
inbox = InboxMessage(
recipient_id=recipient_id,
sender_id=initiator_id,
@@ -126,7 +126,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
self,
friendship: Friendship,
initiator_id: UUID,
message: str | None = None,
content: str | None = None,
) -> tuple[Friendship, InboxMessage]:
try:
now = datetime.now(timezone.utc)
@@ -135,7 +135,7 @@ class SQLAlchemyFriendshipRepository(BaseRepository[Friendship]):
friendship.initiator_id = initiator_id
friendship.updated_by = initiator_id
inbox_content = FriendshipContent(type="request", message=message)
inbox_content = FriendshipContent(type="request", message=content)
inbox = InboxMessage(
recipient_id=(
friendship.user_low_id
+1 -1
View File
@@ -18,7 +18,7 @@ class InboxMessageResponse(BaseModel):
message_type: InboxMessageType
schedule_item_id: UUID | None = None
friendship_id: UUID | None = None
content: str | None = None
content: dict | None = None
is_read: bool = False
status: InboxMessageStatus = InboxMessageStatus.PENDING
created_at: datetime
+7 -7
View File
@@ -7,33 +7,33 @@ from fastapi import APIRouter, Depends
from schemas.user.context import UserContext
from v1.users.dependencies import get_user_service
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
from v1.users.service import UserService
router = APIRouter(prefix="/users", tags=["users"])
@router.get("/me", response_model=UserResponse)
@router.get("/me", response_model=UserContext)
async def get_me(
service: Annotated[UserService, Depends(get_user_service)],
) -> UserResponse:
) -> UserContext:
return await service.get_me()
@router.patch("/me", response_model=UserResponse)
@router.patch("/me", response_model=UserContext)
async def update_me(
payload: UserUpdateRequest,
service: Annotated[UserService, Depends(get_user_service)],
) -> UserResponse:
) -> UserContext:
return await service.update_me(payload)
@router.post("/search", response_model=list[UserResponse])
@router.post("/search", response_model=list[UserContext])
async def search_users(
payload: UserSearchRequest,
service: Annotated[UserService, Depends(get_user_service)],
) -> list[UserResponse]:
) -> list[UserContext]:
return await service.search_users(payload)
-8
View File
@@ -11,14 +11,6 @@ from pydantic import (
model_validator,
)
from schemas.user.context import UserContext
class UserResponse(UserContext):
"""当前用户,含 email,无 settings"""
settings: None = Field(default=None, exclude=True) # type: ignore[assignment]
class UserSearchRequest(BaseModel):
query: str = Field(min_length=1, max_length=100)
+18 -12
View File
@@ -13,8 +13,9 @@ from core.agentscope.persistence.user_context_cache import (
)
from core.db.base_service import BaseService
from core.logging import get_logger
from schemas.user.context import UserContext, parse_profile_settings
from v1.users.repository import UserRepository
from v1.users.schemas import UserResponse, UserSearchRequest, UserUpdateRequest
from v1.users.schemas import UserSearchRequest, UserUpdateRequest
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
@@ -82,7 +83,7 @@ class UserService(BaseService):
user_context_cache or create_user_context_cache(),
)
async def get_me(self) -> UserResponse:
async def get_me(self) -> UserContext:
user_id = self.require_user_id()
try:
user = await self._repository.get_by_user_id(user_id)
@@ -92,12 +93,13 @@ class UserService(BaseService):
if user is None:
raise HTTPException(status_code=404, detail="User not found")
email = self._current_user.email if self._current_user else None
return UserResponse(
return UserContext(
id=str(user.id),
username=user.username,
email=email,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
)
async def get_user_by_id(self, user_id: UUID) -> "UserContext":
@@ -116,7 +118,7 @@ class UserService(BaseService):
avatar_url=profile.avatar_url,
)
async def update_me(self, update: UserUpdateRequest) -> UserResponse:
async def update_me(self, update: UserUpdateRequest) -> UserContext:
user_id = self.require_user_id()
update_data: dict[str, str | None] = {
key: value
@@ -151,15 +153,16 @@ class UserService(BaseService):
)
email = self._current_user.email if self._current_user else None
return UserResponse(
return UserContext(
id=str(user.id),
username=user.username,
email=email,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
)
async def get_by_username(self, username: str) -> UserResponse:
async def get_by_username(self, username: str) -> UserContext:
try:
user = await self._repository.get_by_username(username)
except SQLAlchemyError:
@@ -167,14 +170,15 @@ class UserService(BaseService):
if user is None:
raise HTTPException(status_code=404, detail="User not found")
return UserResponse(
return UserContext(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
)
async def search_users(self, request: UserSearchRequest) -> list[UserResponse]:
async def search_users(self, request: UserSearchRequest) -> list[UserContext]:
query = request.query.strip()
if _EMAIL_PATTERN.match(query):
@@ -182,7 +186,7 @@ class UserService(BaseService):
return await self._search_by_username(query)
async def _search_by_email(self, email: str) -> list[UserResponse]:
async def _search_by_email(self, email: str) -> list[UserContext]:
if self._auth_gateway is None:
raise HTTPException(status_code=503, detail="Auth lookup unavailable")
@@ -199,26 +203,28 @@ class UserService(BaseService):
return []
return [
UserResponse(
UserContext(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
)
]
async def _search_by_username(self, query: str) -> list[UserResponse]:
async def _search_by_username(self, query: str) -> list[UserContext]:
try:
users = await self._repository.search_users(query, limit=20)
except SQLAlchemyError:
raise HTTPException(status_code=503, detail="User store unavailable")
return [
UserResponse(
UserContext(
id=str(user.id),
username=user.username,
avatar_url=user.avatar_url,
bio=user.bio,
settings=parse_profile_settings(user.settings),
)
for user in users
]