File size: 8,860 Bytes
21a86d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""
Test the parquet module.

Mostly auto-generated by Cursor + GPT-5.
"""

import os
import tempfile
from typing import Any

import pandas as pd
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlmodel import Field, Session, SQLModel

from parquet import export_to_parquet, import_from_parquet


# Test models for creating temporary tables. SQLModel derives each table name
# from the lowercased class name (e.g. "dummyuser"), which the tests rely on.
class DummyUser(SQLModel, table=True):
    """Minimal user row registered on ``SQLModel.metadata`` for the tests.

    Field declaration order determines the column order of the created table.
    """

    id: int = Field(primary_key=True)
    name: str = Field(max_length=100)
    email: str = Field(max_length=255)
    age: int = Field()


class DummyProduct(SQLModel, table=True):
    """Minimal product row; stored in the ``dummyproduct`` table by the tests.

    Field declaration order determines the column order of the created table.
    """

    id: int = Field(primary_key=True)
    name: str = Field(max_length=200)
    price: float = Field()
    category: str = Field(max_length=100)


@pytest.fixture
def temp_db_engine():
    """Yield an engine bound to a fresh, file-backed SQLite database.

    Every table registered on ``SQLModel.metadata`` is created before the
    engine is handed to the test. Afterwards the engine is disposed and the
    database file is deleted; the file is removed even when engine creation,
    table creation or disposal raises.
    """
    # delete=False keeps the file on disk after close() so SQLite can reopen
    # it by path; we remove it ourselves during teardown.
    temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db.close()

    try:
        engine = create_engine(f"sqlite:///{temp_db.name}")
        try:
            SQLModel.metadata.create_all(engine)
            yield engine
        finally:
            # Release pooled connections before deleting the file they use.
            engine.dispose()
    finally:
        os.unlink(temp_db.name)


@pytest.fixture
def sample_data() -> dict[str, list[dict[str, Any]]]:
    """Sample rows for the ``dummyuser`` and ``dummyproduct`` tables.

    Values are distinct per row (including emails), so a round-trip test can
    tell if rows or columns get mixed up. The ``example.com`` domain is
    reserved for documentation use (RFC 2606).
    """
    users_data = [
        {"id": 1, "name": "Alice", "email": "alice@example.com", "age": 30},
        {"id": 2, "name": "Bob", "email": "bob@example.com", "age": 25},
        {"id": 3, "name": "Charlie", "email": "charlie@example.com", "age": 35},
    ]

    products_data = [
        {"id": 1, "name": "Laptop", "price": 999.99, "category": "Electronics"},
        {"id": 2, "name": "Book", "price": 19.99, "category": "Education"},
        {"id": 3, "name": "Coffee", "price": 4.99, "category": "Food"},
    ]

    return {"users": users_data, "products": products_data}


@pytest.fixture
def populated_db(temp_db_engine: Engine, sample_data: dict[str, list[dict[str, Any]]]):
    """Seed the temporary database with ``sample_data`` and return its engine.

    Each user and product record becomes a model instance; all of them are
    persisted in a single committed session.
    """
    with Session(temp_db_engine) as session:
        session.add_all(DummyUser(**row) for row in sample_data["users"])
        session.add_all(DummyProduct(**row) for row in sample_data["products"])
        session.commit()

    return temp_db_engine


def test_export_to_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Exporting writes one parquet file per table, complete and sorted."""
    # (parquet file name, key into sample_data) for every exported table.
    layout = (("dummyuser.parquet", "users"), ("dummyproduct.parquet", "products"))

    with tempfile.TemporaryDirectory() as out_dir:
        export_to_parquet(populated_db, out_dir)

        paths = [os.path.join(out_dir, filename) for filename, _ in layout]

        # Every table must have produced a file before any is read back.
        for path in paths:
            assert os.path.exists(path)

        frames = [pd.read_parquet(path) for path in paths]

        # No rows were lost or duplicated.
        for frame, (_, key) in zip(frames, layout):
            assert len(frame) == len(sample_data[key])

        # Rows come out already ordered by all columns, left to right.
        for frame in frames:
            ordered = frame.sort_values(by=list(frame.columns)).reset_index(drop=True)
            assert frame.equals(ordered)


def test_export_to_parquet_empty_table(temp_db_engine: Engine):
    """Exporting never-populated tables still yields a file per table."""
    with tempfile.TemporaryDirectory() as out_dir:
        export_to_parquet(temp_db_engine, out_dir)

        # The tables hold no rows, yet each one is still expected on disk.
        for filename in ("dummyuser.parquet", "dummyproduct.parquet"):
            assert os.path.exists(os.path.join(out_dir, filename))


def test_export_to_parquet_creates_directory(populated_db: Engine):
    """Test that export creates the backup directory if it doesn't exist.

    The target lives inside a private temporary directory. That guarantees it
    is absent beforehand; a fixed name under the shared temp dir could be left
    over from an earlier run and make the test pass without exercising
    creation. It also means cleanup only ever removes what this test made.
    """
    with tempfile.TemporaryDirectory() as parent_dir:
        backup_dir = os.path.join(parent_dir, "test_backup_dir")
        assert not os.path.exists(backup_dir)

        export_to_parquet(populated_db, backup_dir)

        assert os.path.isdir(backup_dir)
        # The freshly created directory actually received the export.
        assert os.path.exists(os.path.join(backup_dir, "dummyuser.parquet"))


def test_import_from_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Importing into emptied tables brings back every exported row."""

    def row_counts():
        """Return the (dummyuser, dummyproduct) row counts from a new session."""
        with Session(populated_db) as session:
            users = session.exec(text("SELECT COUNT(*) FROM dummyuser")).first()
            products = session.exec(text("SELECT COUNT(*) FROM dummyproduct")).first()
        return users[0], products[0]

    with tempfile.TemporaryDirectory() as backup_dir:
        # Snapshot the populated database.
        export_to_parquet(populated_db, backup_dir)

        # Wipe both tables so only the import can repopulate them.
        with Session(populated_db) as session:
            for statement in ("DELETE FROM dummyuser", "DELETE FROM dummyproduct"):
                session.exec(text(statement))
            session.commit()

        assert row_counts() == (0, 0)

        import_from_parquet(populated_db, backup_dir)

        assert row_counts() == (
            len(sample_data["users"]),
            len(sample_data["products"]),
        )


def test_import_from_parquet_missing_file(populated_db: Engine):
    """Importing from a directory without parquet files must not raise."""
    # Nothing is exported here, so every table's file is missing; the import
    # is expected to skip them instead of failing.
    with tempfile.TemporaryDirectory() as empty_backup_dir:
        import_from_parquet(populated_db, empty_backup_dir)


def test_import_from_parquet_clears_existing_data(populated_db: Engine):
    """Test that import clears existing data before inserting new data.

    After the export, two kinds of drift are introduced:

    - an existing row is modified;
    - a row that was never exported is inserted.

    A restore that merely upserted the backup would undo the first but keep
    the second, so both are checked after the import.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # First export
        export_to_parquet(populated_db, temp_dir)

        # Drift the database away from the backup
        with Session(populated_db) as session:
            session.exec(text("UPDATE dummyuser SET name = 'Modified' WHERE id = 1"))
            session.exec(
                text(
                    "INSERT INTO dummyuser (id, name, email, age) "
                    "VALUES (99, 'Extra', 'extra@example.com', 40)"
                )
            )
            session.commit()

        # Verify the drift is in place
        with Session(populated_db) as session:
            result = session.exec(
                text("SELECT name FROM dummyuser WHERE id = 1")
            ).first()
            assert result[0] == "Modified"
            extra = session.exec(
                text("SELECT COUNT(*) FROM dummyuser WHERE id = 99")
            ).first()
            assert extra[0] == 1

        # Import should clear and restore original data
        import_from_parquet(populated_db, temp_dir)

        with Session(populated_db) as session:
            # Original name restored
            result = session.exec(
                text("SELECT name FROM dummyuser WHERE id = 1")
            ).first()
            assert result[0] == "Alice"
            # The row that is not in the backup must have been cleared
            extra = session.exec(
                text("SELECT COUNT(*) FROM dummyuser WHERE id = 99")
            ).first()
            assert extra[0] == 0


def test_export_import_cycle(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Test complete export and import cycle maintains data integrity.

    Restored rows are compared to ``sample_data`` by column name rather than
    by position, so the check does not depend on the physical column order
    of the restored tables.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Export
        export_to_parquet(populated_db, temp_dir)

        # Clear database
        with Session(populated_db) as session:
            session.exec(text("DELETE FROM dummyuser"))
            session.exec(text("DELETE FROM dummyproduct"))
            session.commit()

        # Import
        import_from_parquet(populated_db, temp_dir)

        # Verify data integrity, table by table, in primary-key order
        checks = (
            ("SELECT * FROM dummyuser ORDER BY id", sample_data["users"]),
            ("SELECT * FROM dummyproduct ORDER BY id", sample_data["products"]),
        )
        with Session(populated_db) as session:
            for query, expected_rows in checks:
                rows = session.exec(text(query)).fetchall()
                assert len(rows) == len(expected_rows)

                for row, expected in zip(rows, expected_rows):
                    # ``Row._mapping`` gives name-keyed access; project it onto
                    # the expected columns and compare whole records at once.
                    actual = {column: row._mapping[column] for column in expected}
                    assert actual == expected