# fastapi-proxy-test / tests/test_crud.py
# Last commit: b7791c2 by tanbushi (Sun Jun 8 15:02:12 CST 2025)
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session # Import Session
from sqlalchemy.pool import StaticPool

from app.database import Base, get_db, KeyCategory, APIKey # Import Base, get_db, and models from database
from app import crud # Import crud from app.crud
# Use an in-memory SQLite database for testing.
# Every new SQLite connection to ":memory:" opens its own, empty database.
# SQLAlchemy's default pool for in-memory SQLite keeps one connection per
# thread. A session opened from another thread (e.g. via override_get_db
# inside a test client's worker thread) would therefore not see the tables
# created by Base.metadata.create_all. StaticPool hands out the same single
# connection everywhere, so all sessions share one database.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    # The shared connection may be used from threads other than its creator.
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Override the get_db dependency for testing
def override_get_db():
    """Yield a session bound to the test engine and close it afterwards.

    Meant as a drop-in replacement for the application's ``get_db``
    dependency (no override is registered in the code shown here).

    Yields:
        A session created by ``TestingSessionLocal``.
    """
    # Create the session before entering ``try``. If construction fails there
    # is nothing to close. Otherwise ``finally`` would hit an unbound ``db``
    # and hide the original exception behind an UnboundLocalError.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
# Fixture to create and drop tables for each test
@pytest.fixture(scope="function")
def db_session():
    """Provide a session on a freshly created schema, torn down after the test.

    All tables known to ``Base.metadata`` are created before the test runs and
    dropped afterwards, so every test starts from empty tables.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    yield session
    # Teardown: pytest resumes the generator here once the test is done.
    session.close()
    metadata.drop_all(bind=engine)
# --- Test KeyCategory CRUD ---
def test_create_key_category(db_session: Session):
    """A new category receives an id and keeps its name, type and tags."""
    new_category = crud.create_key_category(
        db_session, name="test_llm", type="llm", tags=["llm", "test"]
    )
    assert new_category.id is not None
    assert (new_category.name, new_category.type) == ("test_llm", "llm")
    assert new_category.tags == ["llm", "test"]
def test_get_key_category(db_session: Session):
    """A category can be fetched back by its id."""
    category_id = crud.create_key_category(db_session, name="get_test", type="test").id
    fetched = crud.get_key_category(db_session, category_id)
    assert fetched is not None
    assert fetched.name == "get_test"
def test_get_key_category_by_name(db_session: Session):
    """A category can be looked up by its name."""
    lookup_name = "get_by_name_test"
    crud.create_key_category(db_session, name=lookup_name, type="test")
    fetched = crud.get_key_category_by_name(db_session, lookup_name)
    assert fetched is not None
    assert fetched.name == lookup_name
def test_get_key_categories(db_session: Session):
    """Listing categories returns exactly the categories created in this test.

    The ``db_session`` fixture creates the schema before and drops it after
    each test, so the table starts empty. The result can therefore be checked
    exactly instead of with a lower bound.
    """
    crud.create_key_category(db_session, name="list_test_1", type="test")
    crud.create_key_category(db_session, name="list_test_2", type="test")
    categories = crud.get_key_categories(db_session)
    assert len(categories) == 2
    assert {category.name for category in categories} == {"list_test_1", "list_test_2"}
def test_update_key_category(db_session: Session):
    """Updating a category overwrites its name, type and tags."""
    original = crud.create_key_category(db_session, name="update_test", type="old_type")
    result = crud.update_key_category(
        db_session,
        original.id,
        name="updated_test",
        type="new_type",
        tags=["updated"],
    )
    assert result.name == "updated_test"
    assert result.type == "new_type"
    assert result.tags == ["updated"]
def test_delete_key_category(db_session: Session):
    """Deleting a category returns it, and it can no longer be fetched."""
    doomed = crud.create_key_category(db_session, name="delete_test", type="test")
    assert crud.delete_key_category(db_session, doomed.id) is not None
    assert crud.get_key_category(db_session, doomed.id) is None
# --- Test APIKey CRUD ---
def test_create_api_key(db_session: Session):
    """A new key is linked to its category and is "active" when no status is given."""
    parent = crud.create_key_category(db_session, name="key_cat_for_key", type="test")
    key = crud.create_api_key(db_session, value="test_key_value", category_id=parent.id)
    assert key.id is not None
    assert key.value == "test_key_value"
    assert key.category_id == parent.id
    assert key.status == "active"
def test_get_api_key(db_session: Session):
    """An API key can be fetched back by its id."""
    parent = crud.create_key_category(db_session, name="get_key_cat", type="test")
    stored = crud.create_api_key(db_session, value="get_test_key", category_id=parent.id)
    fetched = crud.get_api_key(db_session, stored.id)
    assert fetched is not None
    assert fetched.value == "get_test_key"
def test_get_api_keys(db_session: Session):
    """Listing API keys honours the optional category and status filters.

    The ``db_session`` fixture recreates the schema for every test, so the
    only keys in the table are the three created here. That makes every
    count exact rather than a lower bound.
    """
    category1 = crud.create_key_category(db_session, name="list_key_cat_1", type="test")
    category2 = crud.create_key_category(db_session, name="list_key_cat_2", type="test")
    crud.create_api_key(db_session, value="list_key_1", category_id=category1.id)
    crud.create_api_key(db_session, value="list_key_2", category_id=category1.id)
    crud.create_api_key(db_session, value="list_key_3", category_id=category2.id)

    all_keys = crud.get_api_keys(db_session)
    assert len(all_keys) == 3
    assert {key.value for key in all_keys} == {"list_key_1", "list_key_2", "list_key_3"}

    category1_keys = crud.get_api_keys(db_session, category_id=category1.id)
    assert len(category1_keys) == 2
    assert {key.value for key in category1_keys} == {"list_key_1", "list_key_2"}

    # No status was passed on creation, and created keys default to "active".
    active_keys = crud.get_api_keys(db_session, status="active")
    assert len(active_keys) == 3

    inactive_keys = crud.get_api_keys(db_session, status="inactive")
    assert len(inactive_keys) == 0
def test_update_api_key(db_session: Session):
    """Updating a key can change its value, move it to another category and deactivate it."""
    source_cat = crud.create_key_category(db_session, name="update_key_cat_1", type="test")
    target_cat = crud.create_key_category(db_session, name="update_key_cat_2", type="test")
    key = crud.create_api_key(
        db_session, value="old_key_value", category_id=source_cat.id, status="active"
    )
    changed = crud.update_api_key(
        db_session,
        key.id,
        value="new_key_value",
        category_id=target_cat.id,
        status="inactive",
    )
    assert changed.value == "new_key_value"
    assert changed.category_id == target_cat.id
    assert changed.status == "inactive"
def test_delete_api_key(db_session: Session):
    """Deleting a key returns it, and it can no longer be fetched."""
    parent = crud.create_key_category(db_session, name="delete_key_cat", type="test")
    doomed = crud.create_api_key(db_session, value="delete_test_key", category_id=parent.id)
    assert crud.delete_api_key(db_session, doomed.id) is not None
    assert crud.get_api_key(db_session, doomed.id) is None
# --- Test Key Selection Logic Placeholder ---
def test_get_available_keys_for_category(db_session: Session):
    """Only a category's active keys are reported as available."""
    pool = crud.create_key_category(db_session, name="available_key_cat", type="test")
    seed = (("key1", "active"), ("key2", "inactive"), ("key3", "active"))
    for key_value, key_status in seed:
        crud.create_api_key(db_session, value=key_value, category_id=pool.id, status=key_status)
    available = crud.get_available_keys_for_category(db_session, pool.id)
    assert len(available) == 2
    assert all(key.status == "active" for key in available)
    assert {key.value for key in available} == {"key1", "key3"}
def test_select_key_from_pool_no_keys(db_session: Session):
    """Selecting from a category that has no keys yields None."""
    empty_category = crud.create_key_category(db_session, name="empty_pool_cat", type="test")
    assert crud.select_key_from_pool(db_session, empty_category.id) is None
def test_select_key_from_pool_basic(db_session: Session):
    """Selecting from a pool of active keys returns one of those keys.

    Which key is chosen is not pinned down here (ordering may depend on the
    database). The test therefore only checks that the result is one of the
    keys created below.
    """
    category = crud.create_key_category(db_session, name="basic_pool_cat", type="test")
    key1 = crud.create_api_key(db_session, value="key1", category_id=category.id, status="active")
    key2 = crud.create_api_key(db_session, value="key2", category_id=category.id, status="active")
    selected_key = crud.select_key_from_pool(db_session, category.id)
    assert selected_key is not None
    # Compare by id, so a different key that happens to share a value cannot pass.
    assert selected_key.id in {key1.id, key2.id}
    assert selected_key.value in {"key1", "key2"}