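# Hugging Face file-view header (page residue, kept as comments so the module parses):
# author: lvwerra (HF Staff)
# commit message: "Update app.py"
# commit: aefc510 (verified)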
import gradio as gr
from datasets import load_dataset
from huggingface_hub import HfApi, list_repo_refs
import pandas as pd
import os
# Hub API client; the access token comes from the HF_TOKEN environment
# variable and is None when that variable is unset.
api = HfApi(token=os.environ.get("HF_TOKEN"))
def get_branches(repo_id="lvwerra/fineweb-ultra"):
    """Return the names of all non-"main" branches of a dataset repository.

    The names are returned in descending string order. Any exception raised
    while listing the refs is printed and an empty list is returned instead.
    """
    try:
        repo_refs = list_repo_refs(
            repo_id,
            repo_type="dataset",
            token=os.getenv("HF_TOKEN"),
        )
        # Descending string order; presumably branch names are timestamps,
        # which would put the newest branch first.
        return sorted(
            (ref.name for ref in repo_refs.branches if ref.name != "main"),
            reverse=True,
        )
    except Exception as e:
        print(f"Error fetching branches: {e}")
        return []
def load_branch_data(repo_id, branch_name):
    """Load the "train" split of ``repo_id`` at revision ``branch_name``.

    Returns the loaded dataset, or None (after printing the error) when
    loading raises any exception.
    """
    try:
        return load_dataset(repo_id, revision=branch_name, split="train")
    except Exception as e:
        print(f"Error loading branch {branch_name}: {e}")
    # Only reached when the load above failed.
    return None
def update_branch_dropdown():
    """Build a refreshed branch dropdown from the current branch list.

    Selects the first branch returned by get_branches(); when no branches
    are available the dropdown is empty and labelled accordingly.
    """
    branch_names = get_branches()
    if not branch_names:
        return gr.Dropdown(choices=[], value=None, label="No branches found")
    return gr.Dropdown(
        choices=branch_names,
        value=branch_names[0],
        label="Select Branch",
    )
def load_dataset_for_branch(branch_name):
    """Load the dataset for the selected branch and show its first sample.

    Returns a 4-tuple ``(dataset, slider, original_text, rephrased_text)``:
    the loaded dataset (or None), a slider configured for the dataset's index
    range, and the two texts of sample 0. Placeholder values are returned
    when no branch is given, when loading fails, or when the dataset is empty.
    """
    if not branch_name:
        return None, gr.Slider(maximum=0, value=0), "", ""

    dataset = load_branch_data("lvwerra/fineweb-ultra", branch_name)
    if dataset is None:
        return None, gr.Slider(maximum=0, value=0), "Error loading dataset", "Error loading dataset"

    if len(dataset) == 0:
        # An empty branch has no sample 0; without this guard dataset[0]
        # raises IndexError and the slider maximum would become -1.
        return dataset, gr.Slider(maximum=0, value=0), "No data available", "No data available"

    max_samples = len(dataset) - 1

    # Load first sample
    sample = dataset[0]
    original_text = sample.get("text")
    if original_text is None:
        # Covers both a missing key and an explicit None value.
        original_text = "No original text found"
    # Newlines become <br> so line breaks show up in the rendered text.
    original_text = original_text.replace('\n', '<br>')
    rephrased_text = sample.get("ultra_text", "No rephrased text found")

    slider = gr.Slider(
        maximum=max_samples,
        value=0,
        step=1,
        label=f"Sample Index (0-{max_samples})",
    )
    return dataset, slider, original_text, rephrased_text
def update_sample(dataset, sample_idx):
    """Return ``(original_text, rephrased_text)`` for the sample at ``sample_idx``.

    ``sample_idx`` may be a float (as a slider delivers it) and is truncated
    to an int. When there is no dataset, no index, or the index falls outside
    ``0 <= index < len(dataset)``, both texts are "No data available".
    """
    no_data = ("No data available", "No data available")
    if dataset is None or sample_idx is None:
        return no_data

    # Convert before the bounds check so the value tested is the value used,
    # and reject negatives instead of letting them wrap around to the end.
    index = int(sample_idx)
    if not 0 <= index < len(dataset):
        return no_data

    sample = dataset[index]
    original_text = sample.get("text")
    if original_text is None:
        # Covers both a missing key and an explicit None value.
        original_text = "No original text found"
    # Newlines become <br> so line breaks show up in the rendered text.
    original_text = original_text.replace('\n', '<br>')
    rephrased_text = sample.get("ultra_text", "No rephrased text found")
    return original_text, rephrased_text
def format_text_for_display(text, title):
    """Return ``text`` preceded by ``title`` as a level-two Markdown heading."""
    heading = "## {}".format(title)
    return "{}\n\n{}".format(heading, text)
# Create Gradio interface
with gr.Blocks(title="Dataset Branch Viewer", theme=gr.themes.Soft()) as demo:
    # Fetch the branch list once: every get_branches() call lists the repo
    # refs again, and separate calls could disagree with each other.
    initial_branches = get_branches()

    gr.Markdown("# Dataset Branch Viewer")
    gr.Markdown("Compare original and rephrased text samples from different dataset branches")

    # Store dataset in state
    dataset_state = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Branches", variant="secondary")
            branch_dropdown = gr.Dropdown(
                choices=initial_branches,
                value=initial_branches[0] if initial_branches else None,
                label="Select Branch",
                info="Choose a timestamp branch to view"
            )
            sample_slider = gr.Slider(
                minimum=0,
                maximum=0,
                value=0,
                step=1,
                label="Sample Index",
                info="Navigate through samples"
            )

    with gr.Row():
        gr.Markdown("### Sample Info")
        sample_info = gr.Markdown("Select a branch to start")

    with gr.Row():
        with gr.Column():
            original_display = gr.Markdown(
                "## Original Text\n\nSelect a branch and sample to view content",
                label="Original Text"
            )
        with gr.Column():
            rephrased_display = gr.Markdown(
                "## Rephrased Text\n\nSelect a branch and sample to view content",
                label="Rephrased Text"
            )

    # Event handlers
    refresh_btn.click(
        fn=update_branch_dropdown,
        outputs=[branch_dropdown]
    )
    branch_dropdown.change(
        fn=load_dataset_for_branch,
        inputs=[branch_dropdown],
        outputs=[dataset_state, sample_slider, original_display, rephrased_display]
    )
    sample_slider.change(
        fn=update_sample,
        inputs=[dataset_state, sample_slider],
        outputs=[original_display, rephrased_display]
    )

    # Update sample info when slider changes
    def update_sample_info(dataset, sample_idx):
        """Return the Markdown status line for the sample at ``sample_idx``.

        Reports the 1-based position, the dataset size and the sample's "id"
        field ("Unknown" when absent). Returns a placeholder message when no
        dataset is loaded or the index is missing or out of range.
        """
        if dataset is None:
            return "No dataset loaded"
        if sample_idx is None:
            return "No data available"
        total_samples = len(dataset)
        current_sample = int(sample_idx)
        # Guard the lookup: an empty dataset or a stale slider value would
        # otherwise raise IndexError inside the event handler.
        if not 0 <= current_sample < total_samples:
            return "No data available"
        sample = dataset[current_sample]
        sample_id = sample.get("id", "Unknown")
        return f"**Sample {current_sample + 1} of {total_samples}** | ID: `{sample_id}`"

    sample_slider.change(
        fn=update_sample_info,
        inputs=[dataset_state, sample_slider],
        outputs=[sample_info]
    )

    # Load initial data if branches exist
    if initial_branches:
        demo.load(
            fn=load_dataset_for_branch,
            inputs=[gr.State(initial_branches[0])],
            outputs=[dataset_state, sample_slider, original_display, rephrased_display]
        )

# Launch the app
if __name__ == "__main__":
    demo.launch()