File size: 7,471 Bytes
b7791c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session # Import Session
from app.database import Base, get_db, KeyCategory, APIKey # Import Base, get_db, and models from database
from app import crud # Import crud from app.crud

# Tests talk to a throwaway in-memory SQLite database rather than the
# application's real database.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

# sqlite3 refuses to use a connection from a thread other than the one that
# created it; check_same_thread=False lifts that restriction.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args=dict(check_same_thread=False),
)
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)

# Override the get_db dependency for testing
def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()

# Per-test fixture: fresh schema in, schema torn down out.
@pytest.fixture(scope="function")
def db_session():
    """Provide a session bound to freshly created tables for a single test.

    Every table in ``Base.metadata`` is created before the test body runs.
    After the test the session is closed and the tables are dropped, so each
    test starts from an empty schema.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    yield session
    session.close()
    metadata.drop_all(bind=engine)

# --- Test KeyCategory CRUD ---

def test_create_key_category(db_session: Session):
    """Creating a category assigns an id and echoes back every field."""
    created = crud.create_key_category(
        db_session,
        name="test_llm",
        type="llm",
        tags=["llm", "test"],
    )
    assert created.id is not None
    assert (created.name, created.type) == ("test_llm", "llm")
    assert created.tags == ["llm", "test"]

def test_get_key_category(db_session: Session):
    """A category can be fetched back by its id."""
    category_id = crud.create_key_category(db_session, name="get_test", type="test").id
    fetched = crud.get_key_category(db_session, category_id)
    assert fetched is not None
    assert fetched.name == "get_test"

def test_get_key_category_by_name(db_session: Session):
    """A category can be looked up by its name."""
    lookup_name = "get_by_name_test"
    crud.create_key_category(db_session, name=lookup_name, type="test")
    fetched = crud.get_key_category_by_name(db_session, lookup_name)
    assert fetched is not None
    assert fetched.name == lookup_name

def test_get_key_categories(db_session: Session):
    """Listing categories includes every category created in this test."""
    crud.create_key_category(db_session, name="list_test_1", type="test")
    crud.create_key_category(db_session, name="list_test_2", type="test")
    categories = crud.get_key_categories(db_session)
    # A bare length check passes even when the wrong rows come back, so also
    # verify the created names are present. A subset check (rather than
    # equality) still tolerates any other categories in the database.
    assert len(categories) >= 2
    assert {"list_test_1", "list_test_2"} <= {category.name for category in categories}

def test_update_key_category(db_session: Session):
    """Updating a category overwrites its name, type and tags."""
    original = crud.create_key_category(db_session, name="update_test", type="old_type")
    result = crud.update_key_category(
        db_session,
        original.id,
        name="updated_test",
        type="new_type",
        tags=["updated"],
    )
    assert (result.name, result.type, result.tags) == ("updated_test", "new_type", ["updated"])

def test_delete_key_category(db_session: Session):
    """Deleting a category returns it and leaves nothing to fetch by id."""
    victim = crud.create_key_category(db_session, name="delete_test", type="test")
    removed = crud.delete_key_category(db_session, victim.id)
    assert removed is not None
    assert crud.get_key_category(db_session, victim.id) is None

# --- Test APIKey CRUD ---

def test_create_api_key(db_session: Session):
    """A new API key is linked to its category and defaults to "active"."""
    parent = crud.create_key_category(db_session, name="key_cat_for_key", type="test")
    key = crud.create_api_key(db_session, value="test_key_value", category_id=parent.id)
    assert key.id is not None
    assert key.value == "test_key_value"
    assert key.category_id == parent.id
    assert key.status == "active"

def test_get_api_key(db_session: Session):
    """An API key can be fetched back by its id."""
    owner = crud.create_key_category(db_session, name="get_key_cat", type="test")
    stored = crud.create_api_key(db_session, value="get_test_key", category_id=owner.id)
    fetched = crud.get_api_key(db_session, stored.id)
    assert fetched is not None
    assert fetched.value == "get_test_key"

def test_get_api_keys(db_session: Session):
    """Listing API keys honours the optional category and status filters."""
    category1 = crud.create_key_category(db_session, name="list_key_cat_1", type="test")
    category2 = crud.create_key_category(db_session, name="list_key_cat_2", type="test")
    crud.create_api_key(db_session, value="list_key_1", category_id=category1.id)
    crud.create_api_key(db_session, value="list_key_2", category_id=category1.id)
    crud.create_api_key(db_session, value="list_key_3", category_id=category2.id)
    created_values = {"list_key_1", "list_key_2", "list_key_3"}

    # Unfiltered: every key created here must be present. The subset check
    # keeps the original ">=" tolerance for any other rows.
    all_keys = crud.get_api_keys(db_session)
    assert len(all_keys) >= 3
    assert created_values <= {key.value for key in all_keys}

    # Category filter: exactly the two keys attached to category1. A count
    # alone would not catch the filter returning the wrong two keys.
    category1_keys = crud.get_api_keys(db_session, category_id=category1.id)
    assert len(category1_keys) == 2
    assert {key.value for key in category1_keys} == {"list_key_1", "list_key_2"}

    # Status filter: keys are created without an explicit status, so all
    # three should come back as active.
    active_keys = crud.get_api_keys(db_session, status="active")
    assert len(active_keys) >= 3
    assert created_values <= {key.value for key in active_keys}
    assert all(key.status == "active" for key in active_keys)

    inactive_keys = crud.get_api_keys(db_session, status="inactive")
    assert len(inactive_keys) == 0

def test_update_api_key(db_session: Session):
    """Updating a key can change its value, category and status in one call."""
    source_cat = crud.create_key_category(db_session, name="update_key_cat_1", type="test")
    target_cat = crud.create_key_category(db_session, name="update_key_cat_2", type="test")
    key = crud.create_api_key(
        db_session, value="old_key_value", category_id=source_cat.id, status="active"
    )
    changes = {"value": "new_key_value", "category_id": target_cat.id, "status": "inactive"}
    updated = crud.update_api_key(db_session, key.id, **changes)
    for field, expected in changes.items():
        assert getattr(updated, field) == expected

def test_delete_api_key(db_session: Session):
    """Deleting a key returns it and leaves nothing to fetch by id."""
    owner = crud.create_key_category(db_session, name="delete_key_cat", type="test")
    victim = crud.create_api_key(db_session, value="delete_test_key", category_id=owner.id)
    assert crud.delete_api_key(db_session, victim.id) is not None
    assert crud.get_api_key(db_session, victim.id) is None

# --- Test Key Selection Logic Placeholder ---

def test_get_available_keys_for_category(db_session: Session):
    """Only the category's active keys are reported as available."""
    pool = crud.create_key_category(db_session, name="available_key_cat", type="test")
    for value, status in (("key1", "active"), ("key2", "inactive"), ("key3", "active")):
        crud.create_api_key(db_session, value=value, category_id=pool.id, status=status)

    available = crud.get_available_keys_for_category(db_session, pool.id)
    assert len(available) == 2
    assert all(k.status == "active" for k in available)
    assert {k.value for k in available} == {"key1", "key3"}

def test_select_key_from_pool_no_keys(db_session: Session):
    """Selecting from a category that has no keys yields ``None``."""
    empty = crud.create_key_category(db_session, name="empty_pool_cat", type="test")
    assert crud.select_key_from_pool(db_session, empty.id) is None

def test_select_key_from_pool_basic(db_session: Session):
    """Selecting from a pool of active keys returns one of those keys."""
    category = crud.create_key_category(db_session, name="basic_pool_cat", type="test")
    key1 = crud.create_api_key(db_session, value="key1", category_id=category.id, status="active")
    key2 = crud.create_api_key(db_session, value="key2", category_id=category.id, status="active")

    selected_key = crud.select_key_from_pool(db_session, category.id)
    assert selected_key is not None
    # Which key comes back may depend on DB ordering, so accept either one.
    # It must still be one of the keys created above, in this category.
    assert selected_key.id in {key1.id, key2.id}
    assert selected_key.value in ["key1", "key2"]
    assert selected_key.category_id == category.id