refactor(agent): remove memory agent, simplify runtime config system
This commit is contained in:
@@ -18,6 +18,7 @@ from schemas.agent.forwarded_props import (
|
||||
RuntimeMode,
|
||||
)
|
||||
from schemas.agent.visibility import SystemVisibilityBit, bit_mask
|
||||
from schemas.automation import RuntimeConfig
|
||||
from schemas.messages.chat_message import (
|
||||
AgentChatMessageMetadata,
|
||||
UserMessageAttachment,
|
||||
@@ -72,6 +73,7 @@ class AgentService:
|
||||
*,
|
||||
run_input: RunAgentInput,
|
||||
current_user: CurrentUser,
|
||||
runtime_config: RuntimeConfig | None = None,
|
||||
) -> TaskAccepted:
|
||||
created = False
|
||||
thread_id = run_input.thread_id
|
||||
@@ -82,6 +84,13 @@ class AgentService:
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
if runtime_config is None:
|
||||
from v1.agent.system_agents_config import (
|
||||
build_runtime_config_from_system_agents,
|
||||
)
|
||||
|
||||
runtime_config = build_runtime_config_from_system_agents()
|
||||
|
||||
try:
|
||||
owner = await self._repository.get_session_owner(session_id=thread_id)
|
||||
except HTTPException as exc:
|
||||
@@ -124,6 +133,9 @@ class AgentService:
|
||||
"run_input": run_input.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
"runtime_config": runtime_config.model_dump(
|
||||
mode="json", by_alias=True, exclude_none=True
|
||||
),
|
||||
"queue": queue,
|
||||
},
|
||||
dedup_key=None,
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
System agents 配置加载工具
|
||||
|
||||
从 system_agents.yaml 加载配置并构建 RuntimeConfig
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from schemas.agent.system_agent import SystemAgentLLMConfig
|
||||
from schemas.automation import (
|
||||
ContextSource,
|
||||
ContextWindowMode,
|
||||
MemoryContextConfig,
|
||||
RuntimeConfig,
|
||||
)
|
||||
|
||||
|
||||
def _default_system_agents_path() -> Path:
|
||||
return (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "core"
|
||||
/ "config"
|
||||
/ "static"
|
||||
/ "database"
|
||||
/ "system_agents.yaml"
|
||||
)
|
||||
|
||||
|
||||
def _load_system_agents_yaml(path: Path | None = None) -> dict:
|
||||
target_path = path or _default_system_agents_path()
|
||||
with target_path.open("r", encoding="utf-8") as f:
|
||||
loaded = yaml.safe_load(f) or {}
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"Invalid system agents format: {target_path}")
|
||||
return loaded
|
||||
|
||||
|
||||
def _parse_context_messages_config(yaml_config: dict | None) -> MemoryContextConfig:
|
||||
if not yaml_config:
|
||||
return MemoryContextConfig()
|
||||
mode_str = yaml_config.get("mode", "day")
|
||||
count = yaml_config.get("count", 2)
|
||||
try:
|
||||
source = ContextSource.LATEST_CHAT
|
||||
except ValueError:
|
||||
source = ContextSource.LATEST_CHAT
|
||||
try:
|
||||
window_mode = ContextWindowMode(mode_str)
|
||||
except ValueError:
|
||||
window_mode = ContextWindowMode.DAY
|
||||
return MemoryContextConfig(
|
||||
source=source,
|
||||
window_mode=window_mode,
|
||||
window_count=count,
|
||||
)
|
||||
|
||||
|
||||
def build_runtime_config_from_system_agents(
|
||||
yaml_path: Path | None = None,
|
||||
) -> RuntimeConfig:
|
||||
"""
|
||||
从 system_agents.yaml 构建 RuntimeConfig
|
||||
|
||||
chat 模式使用:
|
||||
- router.context_messages 配置 context
|
||||
- worker.enabled_tools 配置 tools
|
||||
"""
|
||||
raw = _load_system_agents_yaml(yaml_path)
|
||||
agents_list = raw.get("agents", [])
|
||||
|
||||
router_config: SystemAgentLLMConfig | None = None
|
||||
worker_config: SystemAgentLLMConfig | None = None
|
||||
|
||||
for agent in agents_list:
|
||||
agent_type = str(agent.get("agent_type", "")).strip().lower()
|
||||
if agent_type == "router":
|
||||
config_dict = agent.get("config") or {}
|
||||
try:
|
||||
router_config = SystemAgentLLMConfig.model_validate(config_dict)
|
||||
except ValidationError:
|
||||
router_config = SystemAgentLLMConfig()
|
||||
elif agent_type == "worker":
|
||||
config_dict = agent.get("config") or {}
|
||||
try:
|
||||
worker_config = SystemAgentLLMConfig.model_validate(config_dict)
|
||||
except ValidationError:
|
||||
worker_config = SystemAgentLLMConfig()
|
||||
|
||||
context_cfg = _parse_context_messages_config(
|
||||
router_config.context_messages.model_dump() if router_config else None
|
||||
)
|
||||
|
||||
enabled_tools: list[str] = []
|
||||
if worker_config and worker_config.enabled_tools:
|
||||
enabled_tools = [str(t) for t in worker_config.enabled_tools]
|
||||
|
||||
return RuntimeConfig(
|
||||
enabled_tools=enabled_tools,
|
||||
context=context_cfg,
|
||||
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from v1.automation_jobs.service import AutomationJobsService
|
||||
|
||||
__all__ = ["AutomationJobsService"]
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.agent_chat_session import AgentChatSession, SessionType
|
||||
from models.automation_jobs import AutomationJob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AutomationJobsRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def list_due_jobs(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[AutomationJob]:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.where(AutomationJob.status == "active")
|
||||
.where(AutomationJob.next_run_at <= now_utc)
|
||||
.order_by(AutomationJob.next_run_at.asc())
|
||||
.limit(max(limit, 1))
|
||||
)
|
||||
rows = (await self._session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
async def update_job_schedule(
|
||||
self,
|
||||
*,
|
||||
job_id: UUID,
|
||||
next_run_at: datetime,
|
||||
last_run_at: datetime,
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
.values(next_run_at=next_run_at, last_run_at=last_run_at)
|
||||
)
|
||||
await self._session.execute(stmt)
|
||||
await self._session.flush()
|
||||
|
||||
async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
stmt = (
|
||||
select(AgentChatSession.id)
|
||||
.where(AgentChatSession.user_id == owner_id)
|
||||
.where(AgentChatSession.deleted_at.is_(None))
|
||||
.where(AgentChatSession.session_type == SessionType.CHAT)
|
||||
.order_by(AgentChatSession.last_activity_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
existing = (await self._session.execute(stmt)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
new_session = AgentChatSession(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
session_type=SessionType.CHAT,
|
||||
)
|
||||
self._session.add(new_session)
|
||||
await self._session.flush()
|
||||
return new_session.id
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from models.automation_jobs import ScheduleType
|
||||
from schemas.automation import AutomationJob as AutomationJobSchema, RuntimeConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from v1.automation_jobs.repository import AutomationJobsRepository
|
||||
|
||||
|
||||
class DispatchFn(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
input_text: str,
|
||||
runtime_config: RuntimeConfig,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def _compute_next_run_at(
|
||||
*,
|
||||
current_next_run_at: datetime,
|
||||
now_utc: datetime,
|
||||
schedule_type: ScheduleType,
|
||||
) -> datetime:
|
||||
delta = timedelta(days=1 if schedule_type == ScheduleType.DAILY else 7)
|
||||
next_run_at = current_next_run_at
|
||||
while next_run_at <= now_utc:
|
||||
next_run_at = next_run_at + delta
|
||||
return next_run_at
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ScanResult:
|
||||
scanned: int
|
||||
dispatched: int
|
||||
|
||||
|
||||
class AutomationJobsService:
|
||||
def __init__(
|
||||
self,
|
||||
repository: "AutomationJobsRepository",
|
||||
session: "AsyncSession",
|
||||
) -> None:
|
||||
self._repository = repository
|
||||
self._session = session
|
||||
|
||||
async def scan_and_dispatch(
|
||||
self,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
limit: int,
|
||||
dispatch_fn: DispatchFn,
|
||||
) -> ScanResult:
|
||||
rows = await self._repository.list_due_jobs(now_utc=now_utc, limit=limit)
|
||||
scanned = len(rows)
|
||||
dispatched = 0
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
job = AutomationJobSchema.from_orm(row)
|
||||
thread_id = await self.get_or_create_chat_session(owner_id=job.owner_id)
|
||||
run_id = f"auto-{job.id}-{int(now_utc.timestamp())}"
|
||||
|
||||
await dispatch_fn(
|
||||
owner_id=job.owner_id,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
input_text=job.config.input_template.strip(),
|
||||
runtime_config=RuntimeConfig(
|
||||
enabled_tools=job.config.enabled_tools,
|
||||
context=job.config.context,
|
||||
),
|
||||
)
|
||||
|
||||
await self._repository.update_job_schedule(
|
||||
job_id=job.id,
|
||||
next_run_at=_compute_next_run_at(
|
||||
current_next_run_at=job.next_run_at,
|
||||
now_utc=now_utc,
|
||||
schedule_type=job.schedule_type,
|
||||
),
|
||||
last_run_at=now_utc,
|
||||
)
|
||||
await self._session.commit()
|
||||
dispatched += 1
|
||||
|
||||
except Exception:
|
||||
await self._session.rollback()
|
||||
raise
|
||||
|
||||
return ScanResult(scanned=scanned, dispatched=dispatched)
|
||||
|
||||
async def get_or_create_chat_session(self, *, owner_id: UUID) -> UUID:
|
||||
return await self._repository.get_or_create_chat_session(owner_id=owner_id)
|
||||
@@ -0,0 +1,3 @@
|
||||
from v1.memories.service import MemoriesService
|
||||
|
||||
__all__ = ["MemoriesService"]
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.memories import Memory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class MemoriesRepositoryLike(Protocol):
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]: ...
|
||||
|
||||
|
||||
class MemoriesRepository(BaseRepository[Memory]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=Memory)
|
||||
|
||||
async def get_active_memories(self, *, owner_id: UUID) -> list[Memory]:
|
||||
stmt = (
|
||||
select(Memory)
|
||||
.where(Memory.owner_id == owner_id)
|
||||
.where(Memory.status == "active")
|
||||
.order_by(Memory.created_at.desc())
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from models.memories import Memory
|
||||
from schemas.memories import MemoryContext, MemoryListResponse, MemorySource, MemoryType
|
||||
from v1.memories.repository import MemoriesRepositoryLike
|
||||
|
||||
|
||||
class MemoriesService:
|
||||
_repository: MemoriesRepositoryLike
|
||||
|
||||
def __init__(self, repository: MemoriesRepositoryLike) -> None:
|
||||
self._repository = repository
|
||||
|
||||
def _to_context(self, memory: Memory) -> MemoryContext:
|
||||
return MemoryContext(
|
||||
memory_type=MemoryType(memory.memory_type.value),
|
||||
source=MemorySource(memory.source.value),
|
||||
title=memory.title,
|
||||
content=memory.content,
|
||||
created_at=memory.created_at,
|
||||
updated_at=memory.updated_at,
|
||||
)
|
||||
|
||||
async def get_user_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
user_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "user"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=user_memories, total=len(user_memories)
|
||||
)
|
||||
|
||||
async def get_agent_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
agent_memories = [
|
||||
self._to_context(memory)
|
||||
for memory in memories
|
||||
if memory.memory_type.value == "work"
|
||||
]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=agent_memories, total=len(agent_memories)
|
||||
)
|
||||
|
||||
async def get_all_memories(self, *, owner_id: UUID) -> MemoryListResponse:
|
||||
memories = await self._repository.get_active_memories(owner_id=owner_id)
|
||||
memory_contexts = [self._to_context(memory) for memory in memories]
|
||||
return MemoryListResponse(
|
||||
owner_id=owner_id, memories=memory_contexts, total=len(memory_contexts)
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
from v1.memory.service import MemoryService
|
||||
|
||||
__all__ = ["MemoryService"]
|
||||
@@ -1,35 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.base_repository import BaseRepository
|
||||
from models.automation_jobs import AutomationJob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class MemoryRepositoryLike(Protocol):
|
||||
async def get_job_by_id_and_owner(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJob | None: ...
|
||||
|
||||
|
||||
class MemoryRepository(BaseRepository[AutomationJob]):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
super().__init__(session=session, model=AutomationJob)
|
||||
|
||||
async def get_job_by_id_and_owner(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJob | None:
|
||||
stmt = (
|
||||
select(AutomationJob)
|
||||
.where(AutomationJob.id == job_id)
|
||||
.where(AutomationJob.owner_id == owner_id)
|
||||
.where(AutomationJob.deleted_at.is_(None))
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
@@ -1,25 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from schemas.automation.config import AutomationJobConfig
|
||||
from v1.memory.repository import MemoryRepositoryLike
|
||||
|
||||
|
||||
class MemoryService:
|
||||
_repository: MemoryRepositoryLike
|
||||
|
||||
def __init__(self, repository: MemoryRepositoryLike) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def get_memory_job_config(
|
||||
self, *, job_id: UUID, owner_id: UUID
|
||||
) -> AutomationJobConfig:
|
||||
job = await self._repository.get_job_by_id_and_owner(
|
||||
job_id=job_id, owner_id=owner_id
|
||||
)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Automation job not found")
|
||||
return AutomationJobConfig.model_validate(job.config or {})
|
||||
Reference in New Issue
Block a user