|
""" |
|
Data Processor for RAG System |
|
Processes the WikiSQL dataset and prepares the data for the RAG system.
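
Example (a minimal sketch; assumes the `datasets` and `loguru` packages
are installed):

    processor = DataProcessor(data_dir="./data")
    examples = processor.process_wikisql_dataset(max_examples=1000)
    stats = processor.get_data_statistics()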
|
""" |
|
|
|
import json |
|
from typing import List, Dict, Any, Optional
|
from pathlib import Path |
|
from datasets import load_dataset |
|
from loguru import logger |
|
|
|
class DataProcessor: |
|
"""Processes WikiSQL dataset for RAG system.""" |
|
|
|
def __init__(self, data_dir: str = "./data"): |
|
""" |
|
Initialize the data processor. |
|
|
|
Args: |
|
data_dir: Directory to store processed data |
|
""" |
|
self.data_dir = Path(data_dir) |
|
self.data_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
self.processed_data_path = self.data_dir / "processed_examples.json" |
|
self.vector_store_data_path = self.data_dir / "vector_store_data.json" |
|
self.statistics_path = self.data_dir / "data_statistics.json" |
|
|
|
logger.info(f"Data processor initialized at {self.data_dir}") |
|
|
|
def process_wikisql_dataset(self, |
|
max_examples: Optional[int] = None, |
|
split: str = "train") -> List[Dict[str, Any]]: |
|
""" |
|
Process WikiSQL dataset and prepare examples for RAG system. |
|
|
|
Args: |
|
max_examples: Maximum number of examples to process (None for all) |
|
split: Dataset split to use ('train', 'validation', 'test') |
|
|
|
Returns: |
|
List of processed examples |
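
        Example (a sketch; the first run downloads WikiSQL, so network
        access is required):
            processor = DataProcessor()
            examples = processor.process_wikisql_dataset(max_examples=100, split="train")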
|
""" |
|
try: |
|
logger.info(f"Loading WikiSQL {split} dataset...") |
|
|
|
|
|
            # The first call downloads WikiSQL from the Hugging Face Hub
            # (network access required); later calls read the local cache
            dataset = load_dataset("wikisql", split=split)
|
|
|
            if max_examples is not None:
|
dataset = dataset.select(range(min(max_examples, len(dataset)))) |
|
|
|
logger.info(f"Processing {len(dataset)} examples...") |
|
|
|
|
|
processed_examples = [] |
|
for i, example in enumerate(dataset): |
|
                processed_example = self._process_single_example(example, i, split)
|
if processed_example: |
|
processed_examples.append(processed_example) |
|
|
|
|
|
if (i + 1) % 1000 == 0: |
|
logger.info(f"Processed {i + 1}/{len(dataset)} examples") |
|
|
|
|
|
self._save_processed_data(processed_examples) |
|
|
|
|
|
stats = self._generate_statistics(processed_examples) |
|
self._save_statistics(stats) |
|
|
|
logger.info(f"Successfully processed {len(processed_examples)} examples") |
|
return processed_examples |
|
|
|
except Exception as e: |
|
logger.error(f"Error processing WikiSQL dataset: {e}") |
|
raise |
|
|
|
    def _process_single_example(self, example: Dict[str, Any], index: int, split: str = "train") -> Optional[Dict[str, Any]]:
|
""" |
|
Process a single WikiSQL example. |
|
|
|
Args: |
|
example: Raw example from WikiSQL dataset |
|
            index: Example index
            split: Dataset split the example came from (defaults to "train")
|
|
|
Returns: |
|
Processed example or None if invalid |
|
""" |
|
try: |
|
|
|
question = example.get("question", "").strip() |
|
table_headers = example.get("table", {}).get("header", []) |
|
sql_query = example.get("sql", {}).get("human_readable", "") |
|
|
|
|
|
if not question or not table_headers or not sql_query: |
|
return None |
|
|
|
|
|
question = self._clean_text(question) |
|
table_headers = [self._clean_text(h) for h in table_headers] |
|
sql_query = self._clean_sql(sql_query) |
|
|
|
|
|
complexity = self._assess_example_complexity(question, sql_query) |
|
category = self._categorize_example(question, sql_query) |
|
|
|
|
|
processed_example = { |
|
"example_id": f"wikisql_{index}", |
|
"question": question, |
|
"table_headers": table_headers, |
|
"sql": sql_query, |
|
"difficulty": complexity, |
|
"category": category, |
|
"metadata": { |
|
"source": "wikisql", |
|
"split": "train", |
|
"original_index": index, |
|
"table_name": example.get("table", {}).get("name", "unknown"), |
|
"question_type": self._classify_question_type(question), |
|
"sql_features": self._extract_sql_features(sql_query) |
|
} |
|
} |
|
|
|
return processed_example |
|
|
|
except Exception as e: |
|
logger.warning(f"Error processing example {index}: {e}") |
|
return None |
|
|
|
def _clean_text(self, text: str) -> str: |
|
"""Clean and normalize text.""" |
|
if not text: |
|
return "" |
|
|
|
|
|
text = " ".join(text.split()) |
|
|
|
|
|
        # Normalize curly double quotes (common in scraped text) to straight quotes
        text = text.replace('\u201c', "'").replace('\u201d', "'")
|
|
|
return text.strip() |
|
|
|
def _clean_sql(self, sql: str) -> str: |
|
"""Clean and normalize SQL query.""" |
|
if not sql: |
|
return "" |
|
|
|
|
|
sql = " ".join(sql.split()) |
|
|
|
|
|
sql = sql.replace(" ,", ",").replace(", ", ",") |
|
sql = sql.replace(" (", "(").replace("( ", "(") |
|
sql = sql.replace(" )", ")").replace(") ", ")") |
|
|
|
|
|
if not sql.endswith(';'): |
|
sql += ';' |
|
|
|
return sql.strip() |
|
|
|
def _assess_example_complexity(self, question: str, sql: str) -> str: |
|
"""Assess the complexity of an example.""" |
|
complexity_score = 0 |
|
|
|
|
|
if len(question.split()) > 15: |
|
complexity_score += 2 |
|
elif len(question.split()) > 10: |
|
complexity_score += 1 |
|
|
|
|
|
sql_lower = sql.lower() |
|
if 'join' in sql_lower: |
|
complexity_score += 2 |
|
if 'group by' in sql_lower: |
|
complexity_score += 2 |
|
if 'having' in sql_lower: |
|
complexity_score += 2 |
|
        # The literal word "subquery" never appears in SQL text, so use
        # parentheses as a rough proxy for subqueries and function calls
        if '(' in sql_lower and ')' in sql_lower:
|
complexity_score += 2 |
|
if 'union' in sql_lower or 'intersect' in sql_lower: |
|
complexity_score += 3 |
|
|
|
|
|
if complexity_score >= 6: |
|
return "hard" |
|
elif complexity_score >= 3: |
|
return "medium" |
|
else: |
|
return "easy" |
|
|
|
def _categorize_example(self, question: str, sql: str) -> str: |
|
"""Categorize the example based on question and SQL.""" |
|
question_lower = question.lower() |
|
sql_lower = sql.lower() |
|
|
|
|
|
        # Natural-language cues are checked against the question; SQL keywords
        # are checked against the query itself, since questions rarely contain
        # them verbatim
        if any(word in question_lower for word in ['count', 'how many', 'number of']):
            return "aggregation"
        elif any(word in question_lower for word in ['average', 'mean', 'sum', 'total']):
            return "aggregation"
        elif 'group by' in sql_lower or any(word in question_lower for word in ['grouped', 'by department', 'by category']):
            return "grouping"
        elif 'join' in sql_lower or any(word in question_lower for word in ['combine', 'merge', 'connect']):
            return "join"
        elif 'order by' in sql_lower or any(word in question_lower for word in ['sort', 'rank', 'top', 'highest', 'lowest']):
            return "sorting"
        elif 'where' in sql_lower or any(word in question_lower for word in ['filter', 'condition']):
            return "filtering"
        else:
            return "simple"
|
|
|
def _classify_question_type(self, question: str) -> str: |
|
"""Classify the type of question.""" |
|
question_lower = question.lower() |
|
|
|
if '?' in question_lower: |
|
return "interrogative" |
|
elif any(word in question_lower for word in ['show', 'display', 'list']): |
|
return "display" |
|
elif any(word in question_lower for word in ['find', 'get', 'retrieve']): |
|
return "retrieval" |
|
else: |
|
return "statement" |
|
|
|
def _extract_sql_features(self, sql: str) -> List[str]: |
|
"""Extract SQL features from the query.""" |
|
features = [] |
|
sql_lower = sql.lower() |
|
|
|
if 'select' in sql_lower: |
|
features.append("select") |
|
if 'from' in sql_lower: |
|
features.append("from") |
|
if 'where' in sql_lower: |
|
features.append("where") |
|
if 'join' in sql_lower: |
|
features.append("join") |
|
if 'group by' in sql_lower: |
|
features.append("group_by") |
|
if 'having' in sql_lower: |
|
features.append("having") |
|
if 'order by' in sql_lower: |
|
features.append("order_by") |
|
if 'limit' in sql_lower: |
|
features.append("limit") |
|
if 'distinct' in sql_lower: |
|
features.append("distinct") |
|
if 'count(' in sql_lower: |
|
features.append("count_aggregation") |
|
if 'avg(' in sql_lower: |
|
features.append("avg_aggregation") |
|
if 'sum(' in sql_lower: |
|
features.append("sum_aggregation") |
|
|
|
return features |
|
|
|
def _save_processed_data(self, examples: List[Dict[str, Any]]) -> None: |
|
"""Save processed examples to file.""" |
|
try: |
|
with open(self.processed_data_path, 'w', encoding='utf-8') as f: |
|
json.dump(examples, f, indent=2, ensure_ascii=False) |
|
logger.info(f"Saved {len(examples)} processed examples to {self.processed_data_path}") |
|
except Exception as e: |
|
logger.error(f"Error saving processed data: {e}") |
|
|
|
def _save_statistics(self, stats: Dict[str, Any]) -> None: |
|
"""Save data statistics to file.""" |
|
try: |
|
with open(self.statistics_path, 'w', encoding='utf-8') as f: |
|
json.dump(stats, f, indent=2, ensure_ascii=False) |
|
logger.info(f"Saved statistics to {self.statistics_path}") |
|
except Exception as e: |
|
logger.error(f"Error saving statistics: {e}") |
|
|
|
def _generate_statistics(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]: |
|
"""Generate comprehensive statistics about the processed data.""" |
|
if not examples: |
|
return {"error": "No examples to analyze"} |
|
|
|
|
|
total_examples = len(examples) |
|
|
|
|
|
difficulty_counts = {} |
|
for example in examples: |
|
difficulty = example.get("difficulty", "unknown") |
|
difficulty_counts[difficulty] = difficulty_counts.get(difficulty, 0) + 1 |
|
|
|
|
|
category_counts = {} |
|
for example in examples: |
|
category = example.get("category", "unknown") |
|
category_counts[category] = category_counts.get(category, 0) + 1 |
|
|
|
|
|
question_type_counts = {} |
|
for example in examples: |
|
question_type = example.get("metadata", {}).get("question_type", "unknown") |
|
question_type_counts[question_type] = question_type_counts.get(question_type, 0) + 1 |
|
|
|
|
|
sql_features_counts = {} |
|
for example in examples: |
|
features = example.get("metadata", {}).get("sql_features", []) |
|
for feature in features: |
|
sql_features_counts[feature] = sql_features_counts.get(feature, 0) + 1 |
|
|
|
|
|
table_sizes = [] |
|
for example in examples: |
|
headers = example.get("table_headers", []) |
|
table_sizes.append(len(headers)) |
|
|
|
avg_table_size = sum(table_sizes) / len(table_sizes) if table_sizes else 0 |
|
|
|
return { |
|
"total_examples": total_examples, |
|
"difficulty_distribution": difficulty_counts, |
|
"category_distribution": category_counts, |
|
"question_type_distribution": question_type_counts, |
|
"sql_features_distribution": sql_features_counts, |
|
"table_schema_stats": { |
|
"average_columns": avg_table_size, |
|
"min_columns": min(table_sizes) if table_sizes else 0, |
|
"max_columns": max(table_sizes) if table_sizes else 0 |
|
}, |
|
"data_quality": { |
|
"examples_with_questions": sum(1 for e in examples if e.get("question")), |
|
"examples_with_sql": sum(1 for e in examples if e.get("sql")), |
|
"examples_with_headers": sum(1 for e in examples if e.get("table_headers")) |
|
} |
|
} |
|
|
|
def load_processed_data(self) -> List[Dict[str, Any]]: |
|
"""Load previously processed data.""" |
|
try: |
|
if self.processed_data_path.exists(): |
|
with open(self.processed_data_path, 'r', encoding='utf-8') as f: |
|
data = json.load(f) |
|
logger.info(f"Loaded {len(data)} processed examples") |
|
return data |
|
else: |
|
logger.warning("No processed data found") |
|
return [] |
|
except Exception as e: |
|
logger.error(f"Error loading processed data: {e}") |
|
return [] |
|
|
|
def get_data_statistics(self) -> Dict[str, Any]: |
|
"""Get current data statistics.""" |
|
try: |
|
if self.statistics_path.exists(): |
|
with open(self.statistics_path, 'r', encoding='utf-8') as f: |
|
stats = json.load(f) |
|
return stats |
|
else: |
|
return {"error": "No statistics available"} |
|
except Exception as e: |
|
logger.error(f"Error loading statistics: {e}") |
|
return {"error": str(e)} |
|
|
|
def create_sample_dataset(self, num_examples: int = 100) -> List[Dict[str, Any]]: |
|
"""Create a small sample dataset for testing.""" |
|
sample_examples = [ |
|
{ |
|
"example_id": "sample_1", |
|
"question": "How many employees are older than 30?", |
|
"table_headers": ["id", "name", "age", "department", "salary"], |
|
"sql": "SELECT COUNT(*) FROM employees WHERE age > 30;", |
|
"difficulty": "easy", |
|
"category": "aggregation", |
|
"metadata": { |
|
"source": "sample", |
|
"question_type": "interrogative", |
|
"sql_features": ["select", "count_aggregation", "where"] |
|
} |
|
}, |
|
{ |
|
"example_id": "sample_2", |
|
"question": "Show all employees in IT department", |
|
"table_headers": ["id", "name", "age", "department", "salary"], |
|
"sql": "SELECT * FROM employees WHERE department = 'IT';", |
|
"difficulty": "easy", |
|
"category": "filtering", |
|
"metadata": { |
|
"source": "sample", |
|
"question_type": "display", |
|
"sql_features": ["select", "where"] |
|
} |
|
}, |
|
{ |
|
"example_id": "sample_3", |
|
"question": "What is the average salary by department?", |
|
"table_headers": ["id", "name", "age", "department", "salary"], |
|
"sql": "SELECT department, AVG(salary) FROM employees GROUP BY department;", |
|
"difficulty": "medium", |
|
"category": "grouping", |
|
"metadata": { |
|
"source": "sample", |
|
"question_type": "interrogative", |
|
"sql_features": ["select", "avg_aggregation", "group_by"] |
|
} |
|
} |
|
] |
|
|
|
|
|
while len(sample_examples) < num_examples: |
|
            base_example = sample_examples[len(sample_examples) % 3]
            # A shallow copy is enough: only the top-level "example_id" changes;
            # the nested "metadata" dict is shared read-only
            new_example = base_example.copy()
|
new_example["example_id"] = f"sample_{len(sample_examples) + 1}" |
|
sample_examples.append(new_example) |
|
|
|
return sample_examples[:num_examples] |
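

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the public API): builds a
    # synthetic sample set so it runs without downloading WikiSQL, persists
    # it, and prints the generated statistics.
    processor = DataProcessor(data_dir="./data")
    samples = processor.create_sample_dataset(num_examples=10)
    processor._save_processed_data(samples)
    processor._save_statistics(processor._generate_statistics(samples))
    print(json.dumps(processor.get_data_statistics(), indent=2))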
|
|