Spaces:
Running
Running
import gradio as gr | |
from datasets import load_dataset | |
from huggingface_hub import HfApi, list_repo_refs | |
import pandas as pd | |
import os | |
# Shared Hub client, authenticated with the HF_TOKEN environment variable
# when it is set (the token is None otherwise).
api = HfApi(token=os.environ.get("HF_TOKEN"))
def get_branches(repo_id="lvwerra/fineweb-ultra"):
    """Return the names of all non-``main`` branches of a dataset repo.

    The names come back in descending (reverse lexicographic) order, which
    puts the newest first provided the branch names are timestamps.
    Any failure is printed and reported as an empty list.
    """
    try:
        repo_refs = list_repo_refs(
            repo_id,
            repo_type="dataset",
            token=os.getenv("HF_TOKEN"),
        )
        # Everything except the default branch, highest name first
        return sorted(
            (ref.name for ref in repo_refs.branches if ref.name != "main"),
            reverse=True,
        )
    except Exception as e:
        print(f"Error fetching branches: {e}")
        return []
def load_branch_data(repo_id, branch_name):
    """Fetch the ``train`` split of ``repo_id`` at revision ``branch_name``.

    Returns the loaded dataset, or ``None`` (after printing the error) when
    loading fails for any reason.
    """
    try:
        return load_dataset(repo_id, revision=branch_name, split="train")
    except Exception as e:
        print(f"Error loading branch {branch_name}: {e}")
    return None
def update_branch_dropdown():
    """Rebuild the branch dropdown from a fresh branch listing.

    With at least one branch, the first entry is pre-selected; otherwise an
    empty dropdown labelled "No branches found" is returned.
    """
    available = get_branches()
    if not available:
        return gr.Dropdown(choices=[], value=None, label="No branches found")
    return gr.Dropdown(choices=available, value=available[0], label="Select Branch")
def load_dataset_for_branch(branch_name):
    """Load the dataset for the selected branch and show its first sample.

    Args:
        branch_name: Branch (revision) name chosen in the dropdown. May be
            empty or ``None`` when nothing is selected.

    Returns:
        A 4-tuple ``(dataset, slider, original_text, rephrased_text)``:
        the value for the dataset state (``None`` when nothing usable was
        loaded), a ``gr.Slider`` update covering the valid sample indices,
        and the two texts of sample 0 (or a status message).
    """
    if not branch_name:
        return None, gr.Slider(maximum=0, value=0), "", ""

    dataset = load_branch_data("lvwerra/fineweb-ultra", branch_name)
    if dataset is None:
        return None, gr.Slider(maximum=0, value=0), "Error loading dataset", "Error loading dataset"

    # An empty split has no sample 0: indexing it would raise and the slider
    # would be given a maximum of -1. Report it as "no data" instead.
    if len(dataset) == 0:
        return None, gr.Slider(maximum=0, value=0), "No data available", "No data available"

    max_samples = len(dataset) - 1

    # Load first sample
    sample = dataset[0]

    # Fall back to the placeholder both when the column is missing and when
    # the stored value is None (calling .replace on None would raise).
    original_text = sample.get("text")
    if original_text is None:
        original_text = "No original text found"
    # Markdown collapses single newlines, so render them as explicit breaks
    original_text = original_text.replace("\n", "<br>")

    rephrased_text = sample.get("ultra_text")
    if rephrased_text is None:
        rephrased_text = "No rephrased text found"

    slider = gr.Slider(
        maximum=max_samples,
        value=0,
        step=1,
        label=f"Sample Index (0-{max_samples})",
    )
    return dataset, slider, original_text, rephrased_text
def update_sample(dataset, sample_idx):
    """Return the texts to display for the sample at ``sample_idx``.

    Args:
        dataset: Indexable collection of dict-like samples, or ``None`` when
            no dataset has been loaded.
        sample_idx: Slider value (int or float); may be ``None``.

    Returns:
        ``(original_text, rephrased_text)``. Both are "No data available"
        when there is no dataset or the index is missing or out of range.
        Newlines in the original text are replaced with ``<br>``.
    """
    no_data = ("No data available", "No data available")
    if dataset is None or sample_idx is None:
        return no_data

    index = int(sample_idx)
    # Reject negative indices too: they would silently wrap around to the
    # end of the dataset instead of signalling an invalid position.
    if not 0 <= index < len(dataset):
        return no_data

    sample = dataset[index]

    # Fall back to the placeholder both when the column is missing and when
    # the stored value is None (calling .replace on None would raise).
    original_text = sample.get("text")
    if original_text is None:
        original_text = "No original text found"
    original_text = original_text.replace("\n", "<br>")

    rephrased_text = sample.get("ultra_text")
    if rephrased_text is None:
        rephrased_text = "No rephrased text found"

    return original_text, rephrased_text
def format_text_for_display(text, title):
    """Prefix ``text`` with ``title`` as a level-2 Markdown heading."""
    template = "## {}\n\n{}"
    return template.format(title, text)
# Create Gradio interface
# NOTE(review): the widget nesting below is reconstructed from the statement
# order; confirm it matches the intended layout.
with gr.Blocks(title="Dataset Branch Viewer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Dataset Branch Viewer")
    gr.Markdown("Compare original and rephrased text samples from different dataset branches")

    # List the branches once and reuse the result: every get_branches() call
    # is a request to the Hub.
    initial_branches = get_branches()

    # Store dataset in state
    dataset_state = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Branches", variant="secondary")
            branch_dropdown = gr.Dropdown(
                choices=initial_branches,
                value=initial_branches[0] if initial_branches else None,
                label="Select Branch",
                info="Choose a timestamp branch to view",
            )
            sample_slider = gr.Slider(
                minimum=0,
                maximum=0,
                value=0,
                step=1,
                label="Sample Index",
                info="Navigate through samples",
            )

    with gr.Row():
        gr.Markdown("### Sample Info")
        sample_info = gr.Markdown("Select a branch to start")

    with gr.Row():
        with gr.Column():
            original_display = gr.Markdown(
                "## Original Text\n\nSelect a branch and sample to view content",
                label="Original Text",
            )
        with gr.Column():
            rephrased_display = gr.Markdown(
                "## Rephrased Text\n\nSelect a branch and sample to view content",
                label="Rephrased Text",
            )

    def update_sample_info(dataset, sample_idx):
        """Return the Markdown status line ("Sample i of n | ID") for a sample."""
        if dataset is None:
            return "No dataset loaded"
        if sample_idx is None:
            return "No data available"
        total_samples = len(dataset)
        current_sample = int(sample_idx)
        # Guard the lookup: an empty dataset or a stale slider value would
        # otherwise raise inside the event handler.
        if not 0 <= current_sample < total_samples:
            return "No data available"
        sample = dataset[current_sample]
        sample_id = sample.get("id", "Unknown")
        return f"**Sample {current_sample + 1} of {total_samples}** | ID: `{sample_id}`"

    # Event handlers
    refresh_btn.click(
        fn=update_branch_dropdown,
        outputs=[branch_dropdown],
    )

    # Loading a branch resets the slider to 0. The slider's change event is
    # not guaranteed to fire when the value was already 0, so the info line
    # is refreshed explicitly once the load has finished.
    branch_dropdown.change(
        fn=load_dataset_for_branch,
        inputs=[branch_dropdown],
        outputs=[dataset_state, sample_slider, original_display, rephrased_display],
    ).then(
        fn=update_sample_info,
        inputs=[dataset_state, sample_slider],
        outputs=[sample_info],
    )

    sample_slider.change(
        fn=update_sample,
        inputs=[dataset_state, sample_slider],
        outputs=[original_display, rephrased_display],
    )

    # Update sample info when slider changes
    sample_slider.change(
        fn=update_sample_info,
        inputs=[dataset_state, sample_slider],
        outputs=[sample_info],
    )

    # Load initial data if branches exist. Skipped otherwise so that the
    # placeholder texts stay visible.
    if initial_branches:
        demo.load(
            fn=load_dataset_for_branch,
            inputs=[branch_dropdown],
            outputs=[dataset_state, sample_slider, original_display, rephrased_display],
        ).then(
            fn=update_sample_info,
            inputs=[dataset_state, sample_slider],
            outputs=[sample_info],
        )
# Start the Gradio server only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()