"""CRUD tests for ``KeyCategory`` and ``APIKey`` backed by an in-memory SQLite DB.

The ``db_session`` fixture creates all tables before each test and drops them
afterwards, so every test starts from an empty database and row counts
asserted below are exact.
"""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session  # Import Session
from sqlalchemy.pool import StaticPool

from app.database import Base, get_db, KeyCategory, APIKey  # Import Base, get_db, and models from database
from app import crud  # Import crud from app.crud

# Use an in-memory SQLite database for testing.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
# Every new SQLite connection to ":memory:" opens its own, empty database.
# StaticPool hands out the same single connection on every checkout, so the
# tables created by the fixture are visible to all sessions, including ones
# opened from another thread (e.g. via ``override_get_db`` in a test client).
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# Override the get_db dependency for testing
def override_get_db():
    """Yield a test session bound to the in-memory engine and close it afterwards."""
    # Create the session before ``try`` so that a failure here propagates
    # as-is instead of surfacing as a NameError on ``db`` in ``finally``.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


# Fixture to create and drop tables for each test
@pytest.fixture(scope="function")
def db_session():
    """Provide a session on freshly created tables; drop all tables on teardown."""
    Base.metadata.create_all(bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        # Drop tables even if closing the session fails, so later tests
        # never see leftover rows.
        try:
            db.close()
        finally:
            Base.metadata.drop_all(bind=engine)


# --- Test KeyCategory CRUD ---

def test_create_key_category(db_session: Session):
    category = crud.create_key_category(db_session, name="test_llm", type="llm", tags=["llm", "test"])
    assert category.id is not None
    assert category.name == "test_llm"
    assert category.type == "llm"
    assert category.tags == ["llm", "test"]


def test_get_key_category(db_session: Session):
    created_category = crud.create_key_category(db_session, name="get_test", type="test")
    retrieved_category = crud.get_key_category(db_session, created_category.id)
    assert retrieved_category is not None
    assert retrieved_category.name == "get_test"


def test_get_key_category_by_name(db_session: Session):
    crud.create_key_category(db_session, name="get_by_name_test", type="test")
    retrieved_category = crud.get_key_category_by_name(db_session, "get_by_name_test")
    assert retrieved_category is not None
    assert retrieved_category.name == "get_by_name_test"


def test_get_key_categories(db_session: Session):
    crud.create_key_category(db_session, name="list_test_1", type="test")
    crud.create_key_category(db_session, name="list_test_2", type="test")
    categories = crud.get_key_categories(db_session)
    # Tables are created fresh for each test, so exactly these two rows exist.
    assert len(categories) == 2
    assert {category.name for category in categories} == {"list_test_1", "list_test_2"}


def test_update_key_category(db_session: Session):
    created_category = crud.create_key_category(db_session, name="update_test", type="old_type")
    updated_category = crud.update_key_category(
        db_session, created_category.id, name="updated_test", type="new_type", tags=["updated"]
    )
    assert updated_category.name == "updated_test"
    assert updated_category.type == "new_type"
    assert updated_category.tags == ["updated"]


def test_delete_key_category(db_session: Session):
    created_category = crud.create_key_category(db_session, name="delete_test", type="test")
    deleted_category = crud.delete_key_category(db_session, created_category.id)
    assert deleted_category is not None
    retrieved_category = crud.get_key_category(db_session, created_category.id)
    assert retrieved_category is None


# --- Test APIKey CRUD ---

def test_create_api_key(db_session: Session):
    category = crud.create_key_category(db_session, name="key_cat_for_key", type="test")
    api_key = crud.create_api_key(db_session, value="test_key_value", category_id=category.id)
    assert api_key.id is not None
    assert api_key.value == "test_key_value"
    assert api_key.category_id == category.id
    assert api_key.status == "active"


def test_get_api_key(db_session: Session):
    category = crud.create_key_category(db_session, name="get_key_cat", type="test")
    created_key = crud.create_api_key(db_session, value="get_test_key", category_id=category.id)
    retrieved_key = crud.get_api_key(db_session, created_key.id)
    assert retrieved_key is not None
    assert retrieved_key.value == "get_test_key"


def test_get_api_keys(db_session: Session):
    category1 = crud.create_key_category(db_session, name="list_key_cat_1", type="test")
    category2 = crud.create_key_category(db_session, name="list_key_cat_2", type="test")
    crud.create_api_key(db_session, value="list_key_1", category_id=category1.id)
    crud.create_api_key(db_session, value="list_key_2", category_id=category1.id)
    crud.create_api_key(db_session, value="list_key_3", category_id=category2.id)

    # Fresh tables per test: exactly the three keys above exist.
    all_keys = crud.get_api_keys(db_session)
    assert len(all_keys) == 3

    category1_keys = crud.get_api_keys(db_session, category_id=category1.id)
    assert len(category1_keys) == 2

    active_keys = crud.get_api_keys(db_session, status="active")
    assert len(active_keys) == 3  # All created keys are active by default

    inactive_keys = crud.get_api_keys(db_session, status="inactive")
    assert len(inactive_keys) == 0


def test_update_api_key(db_session: Session):
    category1 = crud.create_key_category(db_session, name="update_key_cat_1", type="test")
    category2 = crud.create_key_category(db_session, name="update_key_cat_2", type="test")
    created_key = crud.create_api_key(db_session, value="old_key_value", category_id=category1.id, status="active")
    updated_key = crud.update_api_key(
        db_session, created_key.id, value="new_key_value", category_id=category2.id, status="inactive"
    )
    assert updated_key.value == "new_key_value"
    assert updated_key.category_id == category2.id
    assert updated_key.status == "inactive"


def test_delete_api_key(db_session: Session):
    category = crud.create_key_category(db_session, name="delete_key_cat", type="test")
    created_key = crud.create_api_key(db_session, value="delete_test_key", category_id=category.id)
    deleted_key = crud.delete_api_key(db_session, created_key.id)
    assert deleted_key is not None
    retrieved_key = crud.get_api_key(db_session, created_key.id)
    assert retrieved_key is None


# --- Test Key Selection Logic Placeholder ---

def test_get_available_keys_for_category(db_session: Session):
    category = crud.create_key_category(db_session, name="available_key_cat", type="test")
    crud.create_api_key(db_session, value="key1", category_id=category.id, status="active")
    crud.create_api_key(db_session, value="key2", category_id=category.id, status="inactive")
    crud.create_api_key(db_session, value="key3", category_id=category.id, status="active")
    available_keys = crud.get_available_keys_for_category(db_session, category.id)
    assert len(available_keys) == 2
    assert all(key.status == "active" for key in available_keys)
    assert {key.value for key in available_keys} == {"key1", "key3"}


def test_select_key_from_pool_no_keys(db_session: Session):
    category = crud.create_key_category(db_session, name="empty_pool_cat", type="test")
    selected_key = crud.select_key_from_pool(db_session, category.id)
    assert selected_key is None


def test_select_key_from_pool_basic(db_session: Session):
    category = crud.create_key_category(db_session, name="basic_pool_cat", type="test")
    key1 = crud.create_api_key(db_session, value="key1", category_id=category.id, status="active")
    key2 = crud.create_api_key(db_session, value="key2", category_id=category.id, status="active")
    selected_key = crud.select_key_from_pool(db_session, category.id)
    assert selected_key is not None
    # Row order is not guaranteed by the DB, so accept either active key,
    # identified by primary key rather than by value alone.
    assert selected_key.id in {key1.id, key2.id}
    assert selected_key.value in {"key1", "key2"}