120 lines
3.2 KiB
Python
120 lines
3.2 KiB
Python
from __future__ import annotations
|
|
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from sqlalchemy import Column, String, Table, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from core.db.base import Base
|
|
from models.agent_chat_message import AgentChatMessage
|
|
from models.agent_chat_session import AgentChatSession, AgentChatSessionStatus
|
|
from models.llm import Llm
|
|
from models.llm_factory import LlmFactory
|
|
|
|
|
|
@pytest.fixture
async def db_engine():
    """Yield an in-memory async SQLite engine with every ``Base`` table created.

    A minimal ``auth.users`` table (a single string ``id`` primary key) is
    registered on ``Base.metadata`` for the lifetime of the fixture, and an
    extra in-memory database is attached under the name ``auth`` so that the
    schema-qualified table can be created on SQLite.
    NOTE(review): presumably the stub exists so that foreign keys pointing at
    ``auth.users`` resolve during ``create_all`` -- confirm against the models.

    Yields:
        The configured async engine.
    """
    users_table = Table(
        "users",
        Base.metadata,
        Column("id", String, primary_key=True),
        schema="auth",
        extend_existing=True,
    )
    engine = None
    try:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
        async with engine.begin() as conn:
            # SQLite has no schemas; attaching a database named "auth" lets
            # the "auth.users" table name resolve.
            await conn.exec_driver_sql("ATTACH DATABASE ':memory:' AS auth")
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Runs even when setup fails before the yield, so the stub table never
        # leaks into the shared metadata and the engine is always released.
        Base.metadata.remove(users_table)
        if engine is not None:
            await engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine``.

    The session is created with ``expire_on_commit=False`` and is rolled back
    once the consuming test has finished with it.
    """
    session_factory = async_sessionmaker(
        bind=db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as active_session:
        yield active_session
        # Discard whatever the test left uncommitted.
        await active_session.rollback()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_factory_and_llm_relationship(db_session: AsyncSession) -> None:
    """An ``Llm`` row keeps the ``factory_id`` of the ``LlmFactory`` it was created for."""
    factory_fields = {
        "name": "qwen",
        "request_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "avatar": "https://cdn.example.com/qwen.png",
    }
    provider = LlmFactory(**factory_fields)
    db_session.add(provider)
    # Flush first so provider.id is populated before the Llm references it.
    await db_session.flush()

    model_row = Llm(factory_id=provider.id, model_code="qwen3.5-flash")
    db_session.add(model_row)
    await db_session.commit()

    # Look the row up again by its id and check the link to the factory.
    fetched = await db_session.get(Llm, model_row.id)
    assert fetched is not None
    assert fetched.factory_id == provider.id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_status_supports_required_values(
    db_session: AsyncSession,
) -> None:
    """Each of PENDING, RUNNING, COMPLETED and FAILED survives a commit/refresh cycle."""
    chat_session = AgentChatSession(user_id=uuid4(), title="test", status="pending")
    db_session.add(chat_session)
    await db_session.commit()

    required_statuses = (
        AgentChatSessionStatus.PENDING,
        AgentChatSessionStatus.RUNNING,
        AgentChatSessionStatus.COMPLETED,
        AgentChatSessionStatus.FAILED,
    )
    for expected in required_statuses:
        chat_session.status = expected
        await db_session.commit()
        # Reload from the database so the assertion sees the stored value.
        await db_session.refresh(chat_session)
        assert chat_session.status == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_messages_role_supports_tool(db_session: AsyncSession) -> None:
    """A message stored with ``role="tool"`` is read back with that same role."""
    chat_session = AgentChatSession(
        user_id=uuid4(), title="tool test", status="pending"
    )
    db_session.add(chat_session)
    # Flush so chat_session.id is available for the message below.
    await db_session.flush()

    tool_message = AgentChatMessage(
        session_id=chat_session.id,
        seq=1,
        role="tool",
        content="tool output",
        cost=0,
    )
    db_session.add(tool_message)
    await db_session.commit()

    # Exactly one message is expected for this session; scalar_one enforces it.
    query = select(AgentChatMessage).where(
        AgentChatMessage.session_id == chat_session.id
    )
    rows = await db_session.execute(query)
    stored = rows.scalar_one()
    assert stored.role == "tool"
|