import pandas as pd
import numpy as np
import json
import os
import argparse
import logging
from tqdm import tqdm
import chardet
import csv
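# Third-party dependencies: pandas, numpy, chardet, and tqdm
# (e.g. pip install pandas numpy chardet tqdm)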
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("dataset_cleaner.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
class SaaSDatasetCleaner:
"""
Class for cleaning and validating the SaaS sales conversation dataset.
Handles issues resulting from interrupted generations.
"""
    def __init__(self, input_file, output_file=None, chunk_size=1000, encoding=None, skip_encoding_check=False):
"""
Initialize the cleaner.
Args:
input_file: Path to the input CSV file
output_file: Path to save cleaned dataset (defaults to 'cleaned_' + input_file)
chunk_size: Number of rows to process at once
            encoding: File encoding (auto-detected if not provided)
skip_encoding_check: Whether to skip encoding detection and line-by-line processing
"""
self.input_file = input_file
self.output_file = output_file or f"cleaned_{os.path.basename(input_file)}"
self.chunk_size = chunk_size
self.encoding = encoding
self.skip_encoding_check = skip_encoding_check
self.stats = {
'total_rows': 0,
'valid_rows': 0,
'invalid_json': 0,
'missing_values': 0,
'invalid_embeddings': 0,
'duplicates': 0,
'encoding_errors': 0,
            'recovered_rows': 0
        }
        # Conversation IDs seen so far, so duplicates can be detected across chunks
        self.seen_conversation_ids = set()
        # Detect the encoding unless one was supplied or the check is skipped
if not self.skip_encoding_check and not self.encoding:
self.detect_encoding()
# Get the columns and prepare for processing
self.initialize_columns()
def detect_encoding(self):
"""Detect the file encoding."""
logger.info("Detecting file encoding...")
# Read a sample of the file to detect encoding
with open(self.input_file, 'rb') as f:
sample = f.read(min(10000000, os.path.getsize(self.input_file))) # Read up to 10MB
result = chardet.detect(sample)
self.encoding = result['encoding']
confidence = result['confidence']
logger.info(f"Detected encoding: {self.encoding} with confidence: {confidence:.2f}")
# If confidence is low, try common encodings
if confidence < 0.7:
logger.warning(f"Low confidence in encoding detection. Will try multiple encodings.")
self.encoding = None # Will try multiple encodings later
def initialize_columns(self):
"""Initialize column information."""
        # Use the known/detected encoding if available; otherwise try a few common ones
        encodings_to_try = [self.encoding or 'utf-8'] if (self.skip_encoding_check or self.encoding) else ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']
for enc in encodings_to_try:
try:
# Try to read just the header
with open(self.input_file, 'r', encoding=enc, errors='replace') as f:
reader = csv.reader(f)
self.columns = next(reader)
self.encoding = enc
logger.info(f"Successfully read header with encoding: {enc}")
# Identify embedding columns
self.embedding_cols = [col for col in self.columns if col.startswith('embedding_')]
logger.info(f"Found {len(self.embedding_cols)} embedding columns")
return
except Exception as e:
logger.warning(f"Failed to read header with encoding {enc}: {str(e)}")
# If we get here, all encodings failed
logger.error("Could not read column headers with any encoding")
self.columns = []
self.embedding_cols = []
def process_line_by_line(self):
"""Process the file line by line to handle encoding issues."""
logger.info("Processing file line by line to handle encoding issues...")
# Open the output file
with open(self.output_file, 'w', encoding='utf-8', newline='') as out_file:
writer = None # Will initialize after getting headers
# Process the input file
with open(self.input_file, 'rb') as in_file:
# Process line by line
line_count = 0
valid_count = 0
for line in tqdm(in_file, desc="Reading lines"):
line_count += 1
# Try to decode with multiple encodings
decoded_line = None
for enc in ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']:
try:
decoded_line = line.decode(enc)
break
except UnicodeDecodeError:
continue
if decoded_line is None:
# Could not decode with any encoding, skip line
self.stats['encoding_errors'] += 1
continue
# Parse the CSV line
try:
reader = csv.reader([decoded_line])
row = next(reader)
# Initialize writer with headers if this is the first line
if line_count == 1:
writer = csv.writer(out_file)
writer.writerow(row) # Write headers
continue
# Basic validation - check number of columns
if len(row) != len(self.columns):
logger.debug(f"Line {line_count}: Column count mismatch. Expected {len(self.columns)}, got {len(row)}")
continue
# Write the row
writer.writerow(row)
valid_count += 1
except Exception as e:
logger.debug(f"Error processing line {line_count}: {str(e)}")
self.stats['encoding_errors'] += 1
self.stats['total_rows'] = line_count - 1 # Subtract header
self.stats['recovered_rows'] = valid_count
logger.info(f"Processed {line_count} lines, recovered {valid_count} valid rows")
logger.info(f"Found {self.stats['encoding_errors']} lines with encoding errors")
def _validate_json_fields(self, df):
"""Validate and clean JSON fields."""
# List of columns that should contain JSON
json_columns = ['scenario', 'conversation', 'probability_trajectory']
for col in json_columns:
if col not in df.columns:
continue
# Create a valid indicator
df[f'{col}_valid'] = True
            # Check each value. Iterate over (label, value) pairs so df.at uses the real
            # index label; positional indices from enumerate would not match the labels
            # of chunks after the first one.
            for idx, value in df[col].items():
                try:
                    if pd.isna(value):
                        df.at[idx, f'{col}_valid'] = False
                        self.stats['invalid_json'] += 1
                        continue
                    # Attempt to parse JSON
                    json.loads(value)
                except (TypeError, ValueError):
                    df.at[idx, f'{col}_valid'] = False
                    self.stats['invalid_json'] += 1
# Create an overall valid flag
valid_flags = [f'{col}_valid' for col in json_columns if f'{col}_valid' in df.columns]
if valid_flags:
df['json_valid'] = df[valid_flags].all(axis=1)
else:
df['json_valid'] = True
# Clean up the temporary columns
for col in json_columns:
if f'{col}_valid' in df.columns:
df = df.drop(columns=[f'{col}_valid'])
return df
def _validate_embeddings(self, df):
"""Check if embeddings are valid."""
if not self.embedding_cols:
return df
# Check if the first embedding column has a value as a simple check
if 'embedding_0' in df.columns:
df['embeddings_valid'] = ~df['embedding_0'].isna()
else:
df['embeddings_valid'] = True
# Count invalid embeddings
self.stats['invalid_embeddings'] += (~df['embeddings_valid']).sum()
return df
def _check_missing_values(self, df):
"""Check for missing values in important columns."""
important_cols = [
'company_id', 'company_name', 'product_name', 'conversation_id',
'conversation', 'full_text', 'outcome'
]
# Filter to columns that actually exist
important_cols = [col for col in important_cols if col in df.columns]
if not important_cols:
df['missing_important'] = False
return df
# Create a flag for rows with missing important values
missing_flags = df[important_cols].isna().any(axis=1)
df['missing_important'] = missing_flags
# Count missing values
self.stats['missing_values'] += missing_flags.sum()
return df
def _flag_valid_rows(self, df):
"""Create a single flag for valid rows."""
        # A row is valid if it has valid JSON, valid embeddings, and no missing important values
        row_valid = pd.Series(True, index=df.index)
        if 'json_valid' in df.columns:
            row_valid &= df['json_valid']
        if 'embeddings_valid' in df.columns:
            row_valid &= df['embeddings_valid']
        if 'missing_important' in df.columns:
            row_valid &= ~df['missing_important']
        df['row_valid'] = row_valid
# Update valid rows count
self.stats['valid_rows'] += df['row_valid'].sum()
return df
def _remove_duplicates(self, df):
"""Remove duplicate conversation IDs."""
        if 'conversation_id' in df.columns:
            # Flag IDs repeated within this chunk or already seen in earlier chunks,
            # so duplicates are caught across chunk boundaries as well
            dup_mask = (
                df.duplicated(subset=['conversation_id'], keep='first')
                | df['conversation_id'].isin(self.seen_conversation_ids)
            )
            self.seen_conversation_ids.update(df['conversation_id'].dropna())
            df['is_duplicate'] = dup_mask
            # Count duplicates
            self.stats['duplicates'] += dup_mask.sum()
        else:
            df['is_duplicate'] = False
return df
def clean_dataset(self):
"""
Clean the dataset by first fixing encoding issues, then cleaning the data.
"""
logger.info(f"Starting to clean dataset: {self.input_file}")
# Check if the file exists
if not os.path.exists(self.input_file):
logger.error(f"Input file not found: {self.input_file}")
return
        # If we're not skipping encoding checks, do a first pass that salvages rows with
        # encoding problems into an intermediate UTF-8 file, then validate that file
        # into the requested output path.
        if not self.skip_encoding_check:
            final_output = self.output_file
            self.output_file = f"recovered_{os.path.basename(self.input_file)}"
            self.process_line_by_line()
            intermediate_file = self.output_file
            self.output_file = final_output
        else:
            logger.info("Skipping encoding check as requested")
            # Use the input file directly as the intermediate file
            intermediate_file = self.input_file
            # Count rows in the file for progress tracking
            with open(intermediate_file, 'r', encoding=self.encoding or 'utf-8', errors='replace') as f:
                self.stats['total_rows'] = sum(1 for _ in f) - 1  # Subtract header
            self.stats['recovered_rows'] = self.stats['total_rows']
        # Now that we have a file with usable encoding, validate the data in chunks
        logger.info("Beginning data validation on recovered rows...")
        total_rows = self.stats['recovered_rows']
        logger.info(f"Total rows to validate: {total_rows}")
# Process the dataset in chunks
try:
            # Create a chunked reader. The recovered intermediate file is UTF-8; if the
            # encoding check was skipped we read the input with the supplied encoding.
            read_encoding = 'utf-8' if not self.skip_encoding_check else (self.encoding or 'utf-8')
            reader = pd.read_csv(
                intermediate_file,
                chunksize=self.chunk_size,
                encoding=read_encoding,
                low_memory=False,      # Avoid mixed-dtype warnings
                on_bad_lines='skip'    # Skip malformed lines (pandas >= 1.3)
            )
# Create a header flag for the first chunk
first_chunk = True
# Process each chunk
with tqdm(total=total_rows, desc="Validating data") as pbar:
for chunk_num, chunk in enumerate(reader):
logger.debug(f"Processing chunk {chunk_num+1}")
# Run validation steps
chunk = self._validate_json_fields(chunk)
chunk = self._validate_embeddings(chunk)
chunk = self._check_missing_values(chunk)
chunk = self._remove_duplicates(chunk)
chunk = self._flag_valid_rows(chunk)
# Filter to valid rows only
valid_chunk = chunk[chunk['row_valid'] & ~chunk['is_duplicate']]
# Remove the validation columns
for col in ['json_valid', 'embeddings_valid', 'missing_important', 'row_valid', 'is_duplicate']:
if col in valid_chunk.columns:
valid_chunk = valid_chunk.drop(columns=[col])
# Write the cleaned chunk
valid_chunk.to_csv(
self.output_file,
mode='w' if first_chunk else 'a',
header=first_chunk,
index=False,
encoding='utf-8'
)
# Update the first chunk flag
if first_chunk:
first_chunk = False
# Update progress
pbar.update(len(chunk))
logger.info(f"Dataset cleaning complete. Results saved to {self.output_file}")
# Print statistics
logger.info(f"Cleaning Statistics:")
logger.info(f"- Total rows processed: {self.stats['total_rows']}")
logger.info(f"- Rows recovered from encoding issues: {self.stats['recovered_rows']}")
logger.info(f"- Encoding errors: {self.stats['encoding_errors']}")
logger.info(f"- Valid rows after validation: {self.stats['valid_rows']}")
logger.info(f"- Rows with invalid JSON: {self.stats['invalid_json']}")
logger.info(f"- Rows with missing values: {self.stats['missing_values']}")
logger.info(f"- Rows with invalid embeddings: {self.stats['invalid_embeddings']}")
logger.info(f"- Duplicate rows: {self.stats['duplicates']}")
# Create a summary file
with open(f"{self.output_file}_summary.txt", 'w') as f:
f.write("Dataset Cleaning Summary\n")
f.write("=======================\n\n")
f.write(f"Input file: {self.input_file}\n")
f.write(f"Output file: {self.output_file}\n\n")
f.write(f"Total rows processed: {self.stats['total_rows']}\n")
f.write(f"Rows recovered from encoding issues: {self.stats['recovered_rows']}\n")
f.write(f"Encoding errors: {self.stats['encoding_errors']}\n")
f.write(f"Valid rows after validation: {self.stats['valid_rows']}\n")
f.write(f"Rows with invalid JSON: {self.stats['invalid_json']}\n")
f.write(f"Rows with missing values: {self.stats['missing_values']}\n")
f.write(f"Rows with invalid embeddings: {self.stats['invalid_embeddings']}\n")
f.write(f"Duplicate rows: {self.stats['duplicates']}\n")
return self.stats
        except Exception as e:
            logger.error(f"Error validating dataset: {str(e)}")
            raise
def main():
"""Main function to run the dataset cleaner."""
parser = argparse.ArgumentParser(description="Clean and validate SaaS sales conversation dataset")
parser.add_argument("input_file", type=str, help="Path to the input CSV file")
parser.add_argument("--output_file", type=str, default=None,
help="Path to save cleaned dataset (defaults to 'cleaned_' + input_file)")
parser.add_argument("--chunk_size", type=int, default=1000,
help="Number of rows to process at once")
parser.add_argument("--encoding", type=str, default='utf-8',
help="File encoding (defaults to utf-8)")
parser.add_argument("--skip_encoding_check", action="store_true",
help="Skip encoding detection and line-by-line processing")
args = parser.parse_args()
# Create and run the cleaner
cleaner = SaaSDatasetCleaner(
input_file=args.input_file,
output_file=args.output_file,
chunk_size=args.chunk_size,
encoding=args.encoding,
skip_encoding_check=args.skip_encoding_check
)
cleaner.clean_dataset()
if __name__ == "__main__":
main()
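# Example invocation (the CSV file name below is hypothetical, for illustration only):
#   python clean_dataset.py saas_conversations.csv --chunk_size 5000
#   python clean_dataset.py saas_conversations.csv --skip_encoding_check --encoding latin1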