diff --git a/backend/src/core/config/initial/init_data.py b/backend/src/core/config/initial/init_data.py index bc87829..db088f2 100644 --- a/backend/src/core/config/initial/init_data.py +++ b/backend/src/core/config/initial/init_data.py @@ -13,6 +13,7 @@ from core.db.session import AsyncSessionLocal from core.logging import get_logger from models.llm import Llm from models.llm_factory import LlmFactory +from models.user_agent_catalog import UserAgentCatalog logger = get_logger("core.config.initial.init_data") @@ -33,6 +34,17 @@ class LlmCatalogSeed(BaseModel): llms: list[LlmSeed] +class UserAgentCatalogSeed(BaseModel): + agent_type: str + llm_model_code: str + status: str + config: dict[str, Any] + + +class UserAgentCatalogYaml(BaseModel): + agents: list[UserAgentCatalogSeed] + + def _default_catalog_path() -> Path: return ( Path(__file__).resolve().parents[1] / "static" / "database" / "llm_catalog.yaml" @@ -62,6 +74,32 @@ def load_llm_catalog(catalog_path: Path | None = None) -> dict[str, Any]: return parsed.model_dump() +def _default_user_agent_catalog_path() -> Path: + return ( + Path(__file__).resolve().parents[1] + / "static" + / "database" + / "user_agent_catalog.yaml" + ) + + +def load_user_agent_catalog(catalog_path: Path | None = None) -> dict[str, Any]: + path = catalog_path or _default_user_agent_catalog_path() + with path.open("r", encoding="utf-8") as file: + loaded = yaml.safe_load(file) or {} + if not isinstance(loaded, dict): + raise ValueError(f"Invalid user agent catalog format: {path}") + raw_agents = loaded.get("agents", []) + if not isinstance(raw_agents, list): + raise ValueError(f"Invalid user agent catalog agents section: {path}") + try: + parsed = UserAgentCatalogYaml.model_validate({"agents": list(raw_agents)}) + except ValidationError as exc: + raise ValueError(f"Invalid user agent catalog data: {path}") from exc + + return parsed.model_dump() + + async def _upsert_factory( session: AsyncSession, *, @@ -97,8 +135,63 @@ async def _upsert_llm( llm.factory_id = factory_id -async def initialize_data() -> bool: - """Initialize bootstrap data.""" +async def _upsert_user_agent_catalog( + session: AsyncSession, + *, + agent_type: str, + llm_id: uuid.UUID, + status: str, + config: dict[str, Any], +) -> None: + result = await session.execute( + select(UserAgentCatalog).where(UserAgentCatalog.agent_type == agent_type) + ) + catalog_entry = result.scalar_one_or_none() + + if catalog_entry is None: + session.add( + UserAgentCatalog( + agent_type=agent_type, + llm_id=llm_id, + status=status, + config=config, + ) + ) + else: + catalog_entry.llm_id = llm_id + catalog_entry.status = status + catalog_entry.config = config + + +async def initialize_user_agent_catalog() -> None: + """Initialize user agent catalog from YAML.""" + catalog = load_user_agent_catalog() + + async with AsyncSessionLocal() as session: + async with session.begin(): + for agent in catalog["agents"]: + result = await session.execute( + select(Llm).where(Llm.model_code == agent["llm_model_code"]) + ) + llm = result.scalar_one_or_none() + if llm is None: + raise RuntimeError( + f"LLM model '{agent['llm_model_code']}' not found for agent type '{agent['agent_type']}'" + ) + + await _upsert_user_agent_catalog( + session, + agent_type=agent["agent_type"], + llm_id=llm.id, + status=agent["status"], + config=agent["config"], + ) + + logger.info("Initialized user agent catalog") + + +async def initialize_llm_catalog() -> None: + """Initialize LLM catalog from YAML.""" catalog = load_llm_catalog() async with AsyncSessionLocal() as session: @@ -127,4 +220,11 @@ async def initialize_data() -> bool: ) logger.info("Initialized LLM factory/model seed data") + + +async def initialize_data() -> bool: + """Initialize bootstrap data.""" + await initialize_llm_catalog() + await initialize_user_agent_catalog() + return True diff --git a/backend/src/models/user_agent_catalog.py b/backend/src/models/user_agent_catalog.py index 0791e87..a7bc9a8 100644 --- a/backend/src/models/user_agent_catalog.py +++ b/backend/src/models/user_agent_catalog.py @@ -1,5 +1,7 @@ from __future__ import annotations +import uuid + from sqlalchemy import ForeignKey, String from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column @@ -14,7 +16,7 @@ class UserAgentCatalog(TimestampMixin, Base): String(20), primary_key=True, ) - llm_id: Mapped[str] = mapped_column( + llm_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("llms.id", ondelete="RESTRICT"), nullable=False,