# Page header captured with the file (not part of the module):
# author: Wauplin (HF Staff) -- commit "Parquet export" (21a86d3, verified)
"""
Test the parquet module.
Mostly auto-generated by Cursor + GPT-5.
"""
import os
import tempfile
from typing import Any
import pandas as pd
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlmodel import Field, Session, SQLModel
from parquet import export_to_parquet, import_from_parquet
# Test model for creating temporary tables
class DummyUser(SQLModel, table=True):
    """Throwaway user table used only by the tests in this module.

    The tests address it through raw SQL as the ``dummyuser`` table and
    expect the export to write it to ``dummyuser.parquet``.
    """

    # NOTE(review): test_export_import_cycle reads ``SELECT *`` rows by
    # position (id, name, email, age); presumably the column order follows
    # this declaration order, so do not reorder the fields.
    id: int = Field(primary_key=True)
    name: str = Field(max_length=100)
    email: str = Field(max_length=255)
    age: int = Field()
class DummyProduct(SQLModel, table=True):
    """Throwaway product table used only by the tests in this module.

    The tests address it through raw SQL as the ``dummyproduct`` table and
    expect the export to write it to ``dummyproduct.parquet``.
    """

    # NOTE(review): test_export_import_cycle reads ``SELECT *`` rows by
    # position (id, name, price, category); presumably the column order
    # follows this declaration order, so do not reorder the fields.
    id: int = Field(primary_key=True)
    name: str = Field(max_length=200)
    price: float = Field()
    category: str = Field(max_length=100)
@pytest.fixture
def temp_db_engine():
    """Yield an engine bound to a fresh on-disk SQLite database.

    Every table registered on ``SQLModel.metadata`` is created before the
    engine is handed to the test. On teardown the engine is disposed and the
    database file is removed.

    Yields:
        The SQLAlchemy engine connected to the temporary database file.
    """
    # Only the file *name* is needed; close the handle right away so SQLite
    # can open the path itself.
    temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db.close()
    try:
        engine = create_engine(f"sqlite:///{temp_db.name}")
        try:
            SQLModel.metadata.create_all(engine)
            yield engine
        finally:
            # Release pooled connections before the file is deleted.
            engine.dispose()
    finally:
        # Runs even when setup fails before the yield, so a failed
        # create_engine/create_all no longer leaves the .db file behind.
        os.unlink(temp_db.name)
@pytest.fixture
def sample_data() -> dict[str, list[dict[str, Any]]]:
    """Return the rows inserted into the dummy tables.

    Returns:
        A dict with two keys: ``"users"`` (rows for ``DummyUser``) and
        ``"products"`` (rows for ``DummyProduct``). Each row is a dict keyed
        by the model's field names, listed in ascending ``id`` order.
    """
    # Each user gets a distinct, well-formed address. The previous values were
    # one identical placeholder for all three rows, so the round-trip test
    # could not notice emails being swapped between rows.
    users_data = [
        {"id": 1, "name": "Alice", "email": "alice@example.com", "age": 30},
        {"id": 2, "name": "Bob", "email": "bob@example.com", "age": 25},
        {"id": 3, "name": "Charlie", "email": "charlie@example.com", "age": 35},
    ]
    products_data = [
        {"id": 1, "name": "Laptop", "price": 999.99, "category": "Electronics"},
        {"id": 2, "name": "Book", "price": 19.99, "category": "Education"},
        {"id": 3, "name": "Coffee", "price": 4.99, "category": "Food"},
    ]
    return {"users": users_data, "products": products_data}
@pytest.fixture
def populated_db(temp_db_engine: Engine, sample_data: dict[str, list[dict[str, Any]]]):
    """Insert the sample rows into the temporary database.

    Args:
        temp_db_engine: Engine for the temporary database.
        sample_data: Rows to insert, keyed by ``"users"`` and ``"products"``.

    Returns:
        The same engine, now backed by a database holding the sample rows.
    """
    user_rows = [DummyUser(**fields) for fields in sample_data["users"]]
    product_rows = [DummyProduct(**fields) for fields in sample_data["products"]]
    with Session(temp_db_engine) as db:
        # Users first, then products, all in one transaction.
        db.add_all(user_rows)
        db.add_all(product_rows)
        db.commit()
    return temp_db_engine
def test_export_to_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Export a populated database and check each resulting parquet file.

    For both tables the file must exist, hold as many rows as were inserted,
    and already be sorted by all of its columns.
    """
    expected_rows = {
        "dummyuser.parquet": sample_data["users"],
        "dummyproduct.parquet": sample_data["products"],
    }
    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(populated_db, backup_dir)
        for filename, rows in expected_rows.items():
            parquet_path = os.path.join(backup_dir, filename)
            # The file was written ...
            assert os.path.exists(parquet_path)
            frame = pd.read_parquet(parquet_path)
            # ... with the right number of rows ...
            assert len(frame) == len(rows)
            # ... and re-sorting by every column changes nothing.
            resorted = frame.sort_values(by=list(frame.columns)).reset_index(drop=True)
            assert frame.equals(resorted)
def test_export_to_parquet_empty_table(temp_db_engine: Engine):
    """Export a database whose tables exist but contain no rows."""
    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(temp_db_engine, backup_dir)
        # A parquet file is expected for every table even though it is empty.
        for filename in ("dummyuser.parquet", "dummyproduct.parquet"):
            assert os.path.exists(os.path.join(backup_dir, filename))
def test_export_to_parquet_creates_directory(populated_db: Engine):
    """Test that export creates the backup directory if it doesn't exist.

    The target is a child of a fresh ``TemporaryDirectory``, so it is
    guaranteed not to exist beforehand and is cleaned up automatically.
    A fixed path in the shared system temp directory could be left over from
    an earlier run, which would make the test pass without the export
    creating anything.
    """
    with tempfile.TemporaryDirectory() as parent_dir:
        backup_dir = os.path.join(parent_dir, "test_backup_dir")
        # Precondition: the export itself must be what creates the directory.
        assert not os.path.exists(backup_dir)
        export_to_parquet(populated_db, backup_dir)
        assert os.path.exists(backup_dir)
        assert os.path.isdir(backup_dir)
def test_import_from_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Export, empty both tables, import, and check the row counts return."""

    def row_counts():
        """Return ``(user_count, product_count)`` as currently stored."""
        with Session(populated_db) as session:
            user_row = session.exec(text("SELECT COUNT(*) FROM dummyuser")).first()
            product_row = session.exec(text("SELECT COUNT(*) FROM dummyproduct")).first()
            return user_row[0], product_row[0]

    with tempfile.TemporaryDirectory() as backup_dir:
        # Snapshot the populated tables.
        export_to_parquet(populated_db, backup_dir)

        # Wipe both tables so the import has to bring every row back.
        with Session(populated_db) as session:
            for statement in ("DELETE FROM dummyuser", "DELETE FROM dummyproduct"):
                session.exec(text(statement))
            session.commit()
        user_count, product_count = row_counts()
        assert user_count == 0
        assert product_count == 0

        # Restore from the snapshot and compare against the inserted rows.
        import_from_parquet(populated_db, backup_dir)
        user_count, product_count = row_counts()
        assert user_count == len(sample_data["users"])
        assert product_count == len(sample_data["products"])
def test_import_from_parquet_missing_file(populated_db: Engine):
    """Import from a directory that holds no parquet files at all.

    The test passes as long as the import call returns without raising.
    """
    with tempfile.TemporaryDirectory() as empty_dir:
        # Nothing was exported into empty_dir, so every table's file is missing.
        import_from_parquet(populated_db, empty_dir)
def test_import_from_parquet_clears_existing_data(populated_db: Engine):
    """Check that an import overwrites a row changed after the export.

    Only the ``name`` of user 1 is inspected: it is changed after the export
    and must be back to its exported value once the import has run.
    """

    def first_user_name():
        """Return the stored ``name`` of the user with ``id = 1``."""
        with Session(populated_db) as session:
            row = session.exec(
                text("SELECT name FROM dummyuser WHERE id = 1")
            ).first()
            return row[0]

    with tempfile.TemporaryDirectory() as backup_dir:
        # Snapshot the original rows.
        export_to_parquet(populated_db, backup_dir)

        # Diverge the database from the snapshot.
        with Session(populated_db) as session:
            session.exec(text("UPDATE dummyuser SET name = 'Modified' WHERE id = 1"))
            session.commit()
        assert first_user_name() == "Modified"

        # The import must put the exported value back.
        import_from_parquet(populated_db, backup_dir)
        assert first_user_name() == "Alice"
def test_export_import_cycle(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Round-trip both tables through parquet and compare every value.

    After export, wipe and import, each row (ordered by ``id``) must match
    the corresponding sample row, column by column.
    """
    # Field names in the positional order the rows are read in.
    user_fields = ("id", "name", "email", "age")
    product_fields = ("id", "name", "price", "category")

    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(populated_db, backup_dir)

        # Empty both tables so everything read back comes from the import.
        with Session(populated_db) as session:
            for statement in ("DELETE FROM dummyuser", "DELETE FROM dummyproduct"):
                session.exec(text(statement))
            session.commit()

        import_from_parquet(populated_db, backup_dir)

        with Session(populated_db) as session:
            # Users
            user_rows = session.exec(
                text("SELECT * FROM dummyuser ORDER BY id")
            ).fetchall()
            assert len(user_rows) == len(sample_data["users"])
            for row, expected in zip(user_rows, sample_data["users"]):
                for position, field in enumerate(user_fields):
                    assert row[position] == expected[field]

            # Products
            product_rows = session.exec(
                text("SELECT * FROM dummyproduct ORDER BY id")
            ).fetchall()
            assert len(product_rows) == len(sample_data["products"])
            for row, expected in zip(product_rows, sample_data["products"]):
                for position, field in enumerate(product_fields):
                    assert row[position] == expected[field]