Spaces:
Paused
Paused
from typing import List, Dict, Any | |
from datasets import load_dataset, Dataset | |
import logging | |
class DatasetManagementService:
    """Read and update a Hugging Face Hub dataset of paper metadata.

    Rows are identified by their ``entry_id`` value. Every operation works on
    the dataset's ``"train"`` split and reloads it via ``load_dataset``;
    nothing is cached on the instance.
    """

    def __init__(self, dataset_name: str):
        """Store the dataset identifier used by every operation.

        Args:
            dataset_name: Identifier passed to ``load_dataset`` and
                ``Dataset.push_to_hub``.
        """
        self.dataset_name = dataset_name

    @staticmethod
    def _pad_columns(data: Dict[str, List[Any]], length: int) -> None:
        """Extend each column of *data* in place with ``None`` up to *length*."""
        for column in data.values():
            missing = length - len(column)
            if missing > 0:
                column.extend([None] * missing)

    @staticmethod
    def _merge_metadata(
        current_data: Dict[str, List[Any]], new_metadata: List[Dict[str, Any]]
    ) -> bool:
        """Merge paper dicts into column-oriented *current_data* in place.

        A paper whose ``entry_id`` is not present yet is appended as a new row.
        A paper whose ``entry_id`` already exists overwrites the differing
        fields of the first matching row. All columns are kept the same
        length: a column that a new paper does not mention gets ``None`` in
        that row, and a column first introduced by a paper is back-filled
        with ``None`` for earlier rows.

        Args:
            current_data: Mapping of column name -> list of row values (the
                shape produced by ``Dataset.to_dict()``). Mutated in place.
            new_metadata: Paper dicts. Each must contain an ``'entry_id'`` key
                (``KeyError`` otherwise), and ``entry_id`` values must be
                hashable.

        Returns:
            True if any row was added or any value changed.
        """
        num_rows = max((len(column) for column in current_data.values()), default=0)
        # Ragged columns would make Dataset.from_dict fail; normalise up front.
        DatasetManagementService._pad_columns(current_data, num_rows)
        entry_ids = current_data.setdefault('entry_id', [None] * num_rows)

        # entry_id -> row index, built once instead of an O(n) scan per paper.
        # setdefault keeps the first occurrence, matching list.index().
        positions: Dict[Any, int] = {}
        for row, entry_id in enumerate(entry_ids):
            positions.setdefault(entry_id, row)

        updated = False
        for paper in new_metadata:
            entry_id = paper['entry_id']
            # Introduce unseen columns, back-filled for the existing rows.
            for key in paper:
                if key not in current_data:
                    current_data[key] = [None] * num_rows

            row = positions.get(entry_id)
            if row is None:
                # New paper: append one value to every column.
                for key, column in current_data.items():
                    column.append(paper.get(key))
                positions[entry_id] = num_rows
                num_rows += 1
                updated = True
            else:
                # Existing paper: overwrite only the fields that changed.
                for key, value in paper.items():
                    if current_data[key][row] != value:
                        current_data[key][row] = value
                        updated = True
        return updated

    def update_dataset(self, new_metadata: List[Dict[str, Any]]) -> str:
        """Insert or update papers in the dataset and push it if anything changed.

        Args:
            new_metadata: Paper dicts, each containing an ``'entry_id'`` key.

        Returns:
            A human-readable status string. Failures are logged and reported
            in the returned string rather than raised.
        """
        if not new_metadata:
            return "No new data to update."
        try:
            try:
                dataset = load_dataset(self.dataset_name, split="train")
                current_data = dataset.to_dict()
            except FileNotFoundError:
                # Only a missing dataset starts from scratch. Any other load
                # failure (network, auth, ...) must not lead to pushing
                # new_metadata alone, which would replace the existing rows.
                # NOTE(review): assumes datasets' not-found errors (e.g.
                # DatasetNotFoundError) subclass FileNotFoundError; confirm
                # for the pinned datasets version.
                logging.warning(
                    "Dataset %s not found; creating it from new metadata",
                    self.dataset_name,
                )
                current_data = {}

            if not current_data:
                # Seed the column order from the first incoming paper.
                current_data = {key: [] for key in new_metadata[0]}

            if not self._merge_metadata(current_data, new_metadata):
                return "No new data to update."

            updated_dataset = Dataset.from_dict(current_data)
            updated_dataset.push_to_hub(self.dataset_name, split="train")
            return f"Successfully updated dataset with {len(new_metadata)} papers"
        except Exception as e:
            logging.exception("Failed to update dataset %s", self.dataset_name)
            return f"Failed to update dataset: {str(e)}"

    def get_dataset_size(self) -> int:
        """Return the number of rows in the ``"train"`` split, or 0 on error."""
        try:
            dataset = load_dataset(self.dataset_name, split="train")
            size = len(dataset)
            logging.info("Dataset size: %d", size)
            return size
        except Exception:
            logging.exception("Error getting dataset size")
            return 0

    def get_dataset_records(self, page: int, page_size: int) -> List[Dict[str, Any]]:
        """Return one page of rows from the ``"train"`` split as dicts.

        Args:
            page: 1-based page number.
            page_size: Number of rows per page.

        Returns:
            The rows on that page. There may be fewer than *page_size* rows on
            the last page and none past the end. A non-positive *page* or
            *page_size* yields ``[]``. On a load failure the result is a
            single-element list ``[{"error": ...}]``.
        """
        if page < 1 or page_size < 1:
            # A negative start index would silently slice from the end.
            logging.warning("Invalid page=%r / page_size=%r", page, page_size)
            return []
        try:
            dataset = load_dataset(self.dataset_name, split="train")
            start_idx = (page - 1) * page_size
            # Slicing yields a dict of column -> list of values.
            records = dataset[start_idx:start_idx + page_size]
            # Transpose the column lists into one dict per row.
            records_list = [
                dict(zip(records.keys(), values)) for values in zip(*records.values())
            ]
            logging.info("Number of records: %d", len(records_list))
            return records_list
        except Exception as e:
            logging.exception("Error loading dataset records")
            return [{"error": f"Error loading dataset: {str(e)}"}]
# Usage:
# dataset_service = DatasetManagementService("your_dataset_name")
# result = dataset_service.update_dataset(new_metadata)
# records = dataset_service.get_dataset_records(page=1, page_size=10)