|
""" |
|
Test the parquet module. |
|
|
|
Mostly auto-generated by Cursor + GPT-5. |
|
""" |
|
|
|
import os |
|
import tempfile |
|
from typing import Any |
|
|
|
import pandas as pd |
|
import pytest |
|
from sqlalchemy import create_engine, text |
|
from sqlalchemy.engine import Engine |
|
from sqlmodel import Field, Session, SQLModel |
|
|
|
from parquet import export_to_parquet, import_from_parquet |
|
|
|
|
|
|
|
class DummyUser(SQLModel, table=True):
    """Minimal user table model used only by these tests.

    Because it is declared with ``table=True`` on ``SQLModel``, the
    ``temp_db_engine`` fixture's ``SQLModel.metadata.create_all`` call creates
    its table. The tests address that table as ``dummyuser`` in raw SQL and
    expect the export to produce ``dummyuser.parquet``.

    Declaration order sets the table's column order (id, name, email, age),
    which raw ``SELECT *`` queries observe, so reorder fields with care.
    """

    # No default: every row must supply its id explicitly (sample_data does).
    id: int = Field(primary_key=True)
    name: str = Field(max_length=100)
    email: str = Field(max_length=255)
    age: int = Field()
|
|
|
|
|
class DummyProduct(SQLModel, table=True):
    """Minimal product table model used only by these tests.

    Because it is declared with ``table=True`` on ``SQLModel``, the
    ``temp_db_engine`` fixture's ``SQLModel.metadata.create_all`` call creates
    its table. The tests address that table as ``dummyproduct`` in raw SQL and
    expect the export to produce ``dummyproduct.parquet``.

    Declaration order sets the table's column order (id, name, price,
    category), which raw ``SELECT *`` queries observe, so reorder fields with
    care.
    """

    # No default: every row must supply its id explicitly (sample_data does).
    id: int = Field(primary_key=True)
    name: str = Field(max_length=200)
    # Float column; the round-trip tests compare it for exact equality.
    price: float = Field()
    category: str = Field(max_length=100)
|
|
|
|
|
@pytest.fixture
def temp_db_engine():
    """Yield an engine bound to a throwaway on-disk SQLite database.

    Every table registered on ``SQLModel.metadata`` is created before the
    engine is handed to the test. The engine is disposed and the database
    file deleted afterwards, even if engine creation or schema setup fails
    before the engine is ever yielded.
    """
    # delete=False + close(): the fixture only needs a unique path. The
    # handle is closed so SQLite can open the file itself; reopening a
    # still-open NamedTemporaryFile is not possible on every platform.
    temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db.close()

    engine = None
    try:
        engine = create_engine(f"sqlite:///{temp_db.name}")
        SQLModel.metadata.create_all(engine)
        yield engine
    finally:
        # Release pooled connections before deleting the file; an open
        # connection keeps the file locked on Windows.
        if engine is not None:
            engine.dispose()
        os.unlink(temp_db.name)
|
|
|
|
|
@pytest.fixture
def sample_data() -> dict[str, list[dict[str, Any]]]:
    """Return the canonical rows used to seed and verify the test database.

    The ``"users"`` entry holds dicts whose keys match ``DummyUser``'s fields.
    The ``"products"`` entry holds dicts whose keys match ``DummyProduct``'s
    fields. Both lists are in ascending ``id`` order.
    """
    users = [
        {"id": 1, "name": "Alice", "email": "[email protected]", "age": 30},
        {"id": 2, "name": "Bob", "email": "[email protected]", "age": 25},
        {"id": 3, "name": "Charlie", "email": "[email protected]", "age": 35},
    ]
    products = [
        {"id": 1, "name": "Laptop", "price": 999.99, "category": "Electronics"},
        {"id": 2, "name": "Book", "price": 19.99, "category": "Education"},
        {"id": 3, "name": "Coffee", "price": 4.99, "category": "Food"},
    ]
    return dict(users=users, products=products)
|
|
|
|
|
@pytest.fixture
def populated_db(
    temp_db_engine: Engine, sample_data: dict[str, list[dict[str, Any]]]
) -> Engine:
    """Seed the temporary database with ``sample_data`` and return its engine.

    All users are added first, then all products, and everything is
    committed in one transaction before the engine is returned.
    """
    rows = [DummyUser(**record) for record in sample_data["users"]]
    rows += [DummyProduct(**record) for record in sample_data["products"]]

    with Session(temp_db_engine) as session:
        session.add_all(rows)
        session.commit()

    return temp_db_engine
|
|
|
|
|
def test_export_to_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Test successful export of tables to parquet files.

    Each table must be written to ``<backup_dir>/<tablename>.parquet``. Its
    rows must already be sorted by its columns, and its values must match the
    rows that were inserted.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        export_to_parquet(populated_db, temp_dir)

        users_path = os.path.join(temp_dir, "dummyuser.parquet")
        products_path = os.path.join(temp_dir, "dummyproduct.parquet")
        assert os.path.exists(users_path)
        assert os.path.exists(products_path)

        users_df = pd.read_parquet(users_path)
        products_df = pd.read_parquet(products_path)

        assert len(users_df) == len(sample_data["users"])
        assert len(products_df) == len(sample_data["products"])

        # The export is expected to be deterministic: rows come out sorted.
        assert users_df.equals(
            users_df.sort_values(by=list(users_df.columns)).reset_index(drop=True)
        )
        assert products_df.equals(
            products_df.sort_values(by=list(products_df.columns)).reset_index(
                drop=True
            )
        )

        # Counts and ordering alone cannot detect corrupted or shuffled
        # values, so compare every exported record with the inserted one.
        def by_id(record: dict[str, Any]) -> Any:
            return record["id"]

        assert users_df.sort_values("id").to_dict("records") == sorted(
            sample_data["users"], key=by_id
        )
        assert products_df.sort_values("id").to_dict("records") == sorted(
            sample_data["products"], key=by_id
        )
|
|
|
|
|
def test_export_to_parquet_empty_table(temp_db_engine: Engine):
    """Test export with empty tables.

    A table without rows must still produce a parquet file, and that file
    must read back with zero rows.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        export_to_parquet(temp_db_engine, temp_dir)

        for table in ("dummyuser", "dummyproduct"):
            path = os.path.join(temp_dir, f"{table}.parquet")
            assert os.path.exists(path)
            # Existence alone would also accept a file containing stray rows.
            assert len(pd.read_parquet(path)) == 0
|
|
|
|
|
def test_export_to_parquet_creates_directory(populated_db: Engine):
    """Test that export creates the backup directory if it doesn't exist.

    The target directory is placed inside a fresh private temporary
    directory. That guarantees it does not exist beforehand and that it is
    removed afterwards. A fixed path under the shared system temp dir could
    be left over from an earlier run, which would make this test pass without
    exercising directory creation. It could also collide with a concurrent
    test run.
    """
    with tempfile.TemporaryDirectory() as parent_dir:
        backup_dir = os.path.join(parent_dir, "test_backup_dir")
        assert not os.path.exists(backup_dir)

        export_to_parquet(populated_db, backup_dir)

        assert os.path.exists(backup_dir)
        assert os.path.isdir(backup_dir)
|
|
|
|
|
def test_import_from_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Test successful import from parquet files.

    The test exports the populated tables, empties them, confirms they are
    empty, imports the export back and checks that each table's row count
    matches the seed data again.
    """
    tables = ("dummyuser", "dummyproduct")

    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(populated_db, backup_dir)

        with Session(populated_db) as session:
            for table in tables:
                session.exec(text(f"DELETE FROM {table}"))
            session.commit()

        with Session(populated_db) as session:
            for table in tables:
                emptied = session.exec(
                    text(f"SELECT COUNT(*) FROM {table}")
                ).scalar_one()
                assert emptied == 0

        import_from_parquet(populated_db, backup_dir)

        expected_counts = {
            "dummyuser": len(sample_data["users"]),
            "dummyproduct": len(sample_data["products"]),
        }
        with Session(populated_db) as session:
            for table, expected in expected_counts.items():
                restored = session.exec(
                    text(f"SELECT COUNT(*) FROM {table}")
                ).scalar_one()
                assert restored == expected
|
|
|
|
|
def test_import_from_parquet_missing_file(populated_db: Engine):
    """Test import handles missing parquet files gracefully.

    The backup directory exists but contains no parquet files at all. The
    test passes as long as ``import_from_parquet`` returns without raising.
    """
    with tempfile.TemporaryDirectory() as empty_backup_dir:
        # Nothing was exported into this directory, so every table's file is
        # missing.
        import_from_parquet(populated_db, empty_backup_dir)
|
|
|
|
|
|
|
def test_import_from_parquet_clears_existing_data(populated_db: Engine):
    """Test that import clears existing data before inserting new data.

    After the export, two kinds of drift are introduced:

    * an exported row is modified;
    * a row that was never exported is added.

    After the import, the modified row must hold its exported value again.
    The extra row must be gone, and the row count must equal the exported
    count. In other words, the table contents are replaced rather than
    merged.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        export_to_parquet(populated_db, temp_dir)

        with Session(populated_db) as session:
            exported_count = session.exec(
                text("SELECT COUNT(*) FROM dummyuser")
            ).scalar_one()
            session.exec(text("UPDATE dummyuser SET name = 'Modified' WHERE id = 1"))
            # A row absent from the export: only a real clear removes it. An
            # upsert/merge-style import would restore id 1 but leave this one.
            session.exec(
                text(
                    "INSERT INTO dummyuser (id, name, email, age) "
                    "VALUES (999, 'Stale', 'stale', 99)"
                )
            )
            session.commit()

        with Session(populated_db) as session:
            result = session.exec(
                text("SELECT name FROM dummyuser WHERE id = 1")
            ).first()
            assert result[0] == "Modified"

        import_from_parquet(populated_db, temp_dir)

        with Session(populated_db) as session:
            result = session.exec(
                text("SELECT name FROM dummyuser WHERE id = 1")
            ).first()
            assert result[0] == "Alice"

            stale = session.exec(
                text("SELECT id FROM dummyuser WHERE id = 999")
            ).first()
            assert stale is None

            restored_count = session.exec(
                text("SELECT COUNT(*) FROM dummyuser")
            ).scalar_one()
            assert restored_count == exported_count
|
|
|
|
|
def test_export_import_cycle(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Test complete export and import cycle maintains data integrity.

    The cycle is export -> delete -> import, and every column of every row
    must come back unchanged. Columns are selected and compared by name
    rather than via ``SELECT *`` and positional indexing. The check therefore
    does not depend on the physical column order of the tables.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        export_to_parquet(populated_db, temp_dir)

        with Session(populated_db) as session:
            session.exec(text("DELETE FROM dummyuser"))
            session.exec(text("DELETE FROM dummyproduct"))
            session.commit()

        import_from_parquet(populated_db, temp_dir)

        def by_id(record: dict[str, Any]) -> Any:
            return record["id"]

        with Session(populated_db) as session:
            users_result = session.exec(
                text("SELECT id, name, email, age FROM dummyuser ORDER BY id")
            ).fetchall()
            products_result = session.exec(
                text(
                    "SELECT id, name, price, category FROM dummyproduct ORDER BY id"
                )
            ).fetchall()

        # Row._mapping views each row as column-name -> value. Comparing whole
        # lists checks row count, ordering and every value at once.
        assert [dict(row._mapping) for row in users_result] == sorted(
            sample_data["users"], key=by_id
        )
        assert [dict(row._mapping) for row in products_result] == sorted(
            sample_data["products"], key=by_id
        )
|
|