|
import gradio as gr |
|
import pandas as pd |
|
import os |
|
from datetime import datetime |
|
import json |
|
from huggingface_hub import HfApi, create_repo, upload_file |
|
from datasets import Dataset |
|
|
|
|
|
# Local files: the CSV of items to compare, the CSV where choices are stored,
# and a JSON Lines log that is mirrored to the HuggingFace dataset repo.
INPUT_CSV = "summaries.csv"
OUTPUT_CSV = "results.csv"
TEMP_JSON = "temp_results.jsonl"

# HuggingFace settings come from the environment. An empty token or username
# means "not configured" and disables the upload path.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Dataset repository name (created under HF_USERNAME). It can be overridden with
# the HF_DATASET_REPO environment variable; the default is the original name.
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "boe-preference-summaries-results")
HF_USERNAME = os.environ.get("HF_USERNAME", "")
|
|
|
# Columns every input row must provide. They are also the schema of the empty
# placeholder frame returned when no input data is available.
_REQUIRED_COLUMNS = ["id", "text", "summary_a", "summary_b"]


def load_data(path=None):
    """Load the items to compare from a CSV file.

    Args:
        path: CSV file to read. Defaults to ``INPUT_CSV`` when omitted.

    Returns:
        A DataFrame with at least the columns ``id``, ``text``, ``summary_a``
        and ``summary_b``. If the file does not exist or is completely empty,
        an empty DataFrame with exactly those columns is returned.

    Raises:
        ValueError: if the file exists but lacks one of the required columns.
    """
    if path is None:
        path = INPUT_CSV
    if not os.path.exists(path):
        return pd.DataFrame(columns=_REQUIRED_COLUMNS)
    try:
        # keep_default_na=False keeps blank cells and literals such as "NA" or
        # "null" as strings. Otherwise they would become float NaN, which shows
        # up as "nan" in the UI and serializes as non-standard JSON.
        df = pd.read_csv(path, keep_default_na=False)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header at all; treat it like a missing file.
        return pd.DataFrame(columns=_REQUIRED_COLUMNS)
    missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"{path} is missing required column(s): {', '.join(missing)}")
    return df
|
|
|
def _restore_local_log(repo_id):
    """Seed ``TEMP_JSON`` with the repo's ``data.jsonl``, or create it empty if that fails."""
    # Local imports keep the module's top-level import block unchanged.
    import shutil
    from huggingface_hub import hf_hub_download

    try:
        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename="data.jsonl",
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except Exception as e:
        # Expected for a brand-new repo that has no data.jsonl yet.
        # NOTE(review): a transient network error also lands here. In that case
        # the next upload will replace the remote file with only the rows
        # logged from now on.
        print(f"No existing data.jsonl restored from {repo_id}: {e}")
        with open(TEMP_JSON, "w", encoding="utf-8"):
            pass
        return
    shutil.copyfile(cached_path, TEMP_JSON)


def initialize_hf_dataset():
    """Ensure the HuggingFace dataset repo exists and prepare the local log.

    Creates the repository ``{HF_USERNAME}/{HF_DATASET_REPO}`` if needed.
    ``push_to_hf_dataset`` uploads the entire local ``TEMP_JSON`` file as the
    repo's ``data.jsonl``. So when no local log exists (e.g. after a restart),
    the log is first restored from the repo. Without that step, the next
    upload would overwrite previously pushed rows.

    Returns:
        ``(True, repo_id)`` on success, otherwise ``(False, message)``
        explaining why the integration is unavailable. Errors are reported,
        never raised.
    """
    if not HF_TOKEN or not HF_USERNAME:
        return False, "HuggingFace credentials not found. Please set HF_TOKEN and HF_USERNAME environment variables."

    repo_id = f"{HF_USERNAME}/{HF_DATASET_REPO}"
    try:
        print(f"Ensuring repository {repo_id} exists")
        # exist_ok=True turns this into a no-op for an existing repo. It avoids
        # treating every repo_info failure (auth, network) as "repo missing".
        create_repo(repo_id=repo_id, repo_type="dataset", token=HF_TOKEN, exist_ok=True)

        if not os.path.exists(TEMP_JSON):
            _restore_local_log(repo_id)

        return True, f"{repo_id}"
    except Exception as e:
        return False, f"Error initializing HuggingFace dataset: {str(e)}"
|
|
|
def _json_default(value):
    """Fallback for ``json.dumps``: convert numpy/pandas scalars (e.g. an int64 id).

    Objects exposing a callable ``.item()`` are converted to the equivalent
    builtin value. Anything else still raises ``TypeError``, as json does by default.
    """
    to_builtin = getattr(value, "item", None)
    if callable(to_builtin):
        return to_builtin()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def push_to_hf_dataset(data_row):
    """Append one result row to the local log and upload the log to HuggingFace.

    The row is appended as a JSON line to ``TEMP_JSON``. The whole file is then
    uploaded as ``data.jsonl`` in the dataset repo, replacing the previous copy.

    Args:
        data_row: dict of result fields, as built by ``save_choice``.

    Returns:
        ``(True, message)`` on success, ``(False, error_message)`` otherwise.
        Errors are reported, never raised.
    """
    if not HF_TOKEN or not HF_USERNAME:
        return False, "HuggingFace credentials not found"

    try:
        # Serialize first so a bad row never touches the file. Values such as
        # pandas int64 ids are converted by _json_default instead of failing.
        line = json.dumps(data_row, ensure_ascii=False, default=_json_default)
        # Append before uploading, so a row survives a failed upload and is
        # included in the next successful one.
        with open(TEMP_JSON, "a", encoding="utf-8") as f:
            f.write(line + "\n")

        repo_id = f"{HF_USERNAME}/{HF_DATASET_REPO}"
        upload_file(
            path_or_fileobj=TEMP_JSON,
            path_in_repo="data.jsonl",
            repo_id=repo_id,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        return True, f"Data pushed to {repo_id}"
    except Exception as e:
        return False, f"Error pushing to HuggingFace: {str(e)}"
|
|
|
def save_choice(text_id, original_text, summary_a, summary_b, choice, notes="", request_id=""):
    """Record which summary the user preferred, locally and on HuggingFace.

    The row is added to ``OUTPUT_CSV`` (created if missing) and then passed to
    ``push_to_hf_dataset``. A failed push does not undo the local save.

    Args:
        text_id, original_text, summary_a, summary_b: the item being judged.
        choice: the radio value, ``"Summary A"`` or ``"Summary B"``. Any other
            value (e.g. ``None`` when nothing was selected) is rejected and
            nothing is saved.
        notes: optional free-text notes.
        request_id: optional id taken from the page URL, stored with the row.

    Returns:
        A status message for the UI.
    """
    if choice not in ("Summary A", "Summary B"):
        # Without this guard, an unselected radio (None) would be stored as "B".
        return (
            f"No summary selected for text ID: {text_id}. "
            "Please choose Summary A or Summary B before submitting."
        )

    new_row = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "text_id": text_id,
        "original_text": original_text,
        "summary_a": summary_a,
        "summary_b": summary_b,
        "chosen_summary": "A" if choice == "Summary A" else "B",
        "notes": notes,
        "request_id": request_id,
    }

    if os.path.exists(OUTPUT_CSV):
        # Re-read and concat so columns are aligned by name with existing rows.
        results_df = pd.concat(
            [pd.read_csv(OUTPUT_CSV), pd.DataFrame([new_row])], ignore_index=True
        )
    else:
        results_df = pd.DataFrame([new_row])
    results_df.to_csv(OUTPUT_CSV, index=False)

    success, message = push_to_hf_dataset(new_row)

    request_id_msg = f" (Request ID: {request_id})" if request_id else ""
    if success:
        return f"Selection saved for text ID: {text_id}{request_id_msg}! You chose {choice}. Pushed to HuggingFace."
    return f"Selection saved locally for text ID: {text_id}{request_id_msg}. HuggingFace push failed: {message}"
|
|
|
class SummaryChooser:
    """Navigation state and event handlers for the summary-comparison UI.

    Holds the loaded items (``df``), the index of the item on screen, the
    HuggingFace initialization result and the request id from the page URL.

    NOTE(review): all of this state lives on the instance. When a single
    instance serves every browser session (as with a module-level ``app``),
    concurrent users share ``current_index`` and ``request_id``.
    """

    def __init__(self):
        self.df = load_data()
        print(self.df)
        self.current_index = 0
        self.total_items = len(self.df)
        print("Total items: ", self.total_items)
        # (success, message) tuple returned by initialize_hf_dataset().
        self.hf_status = initialize_hf_dataset()
        self.request_id = ""

    @staticmethod
    def _to_builtin(value):
        """Return numpy/pandas scalars (e.g. an int64 id) as plain Python values."""
        to_builtin = getattr(value, "item", None)
        return to_builtin() if callable(to_builtin) else value

    def set_request_id(self, request: gr.Request):
        """Store the page URL's ``id`` query parameter as the request id.

        Returns a status string for the Request ID label.
        """
        try:
            self.request_id = request.query_params.get("id", "")
        except Exception:
            # Best effort, e.g. when no usable request object was passed.
            self.request_id = ""
            return "Failed to get Request ID"
        return f"Request ID: {self.request_id}" if self.request_id else "No Request ID provided"

    def get_current_item(self):
        """Return ``(id, text, summary_a, summary_b, progress)`` for the current item.

        With no data, the first four fields are empty strings and the last
        element is an explanatory message instead of a progress string.
        """
        if self.total_items == 0:
            return "", "", "", "", f"No data found in {INPUT_CSV}. Please check the file path."

        row = self.df.iloc[self.current_index]
        progress = f"Item {self.current_index + 1} of {self.total_items}"
        # Convert numpy scalars so downstream code (CSV/JSON) gets builtins.
        return (
            self._to_builtin(row["id"]),
            self._to_builtin(row["text"]),
            self._to_builtin(row["summary_a"]),
            self._to_builtin(row["summary_b"]),
            progress,
        )

    def next_item(self, choice, notes):
        """Save the choice for the current item, then advance (wrapping around).

        If no summary is selected, nothing is saved and the current item stays
        on screen with a prompt as the result message.

        Returns:
            ``(id, text, summary_a, summary_b, progress, result_message)``.
        """
        if self.total_items == 0:
            return "", "", "", "", "No data available", ""

        text_id, text, summary_a, summary_b, progress = self.get_current_item()
        if choice not in ("Summary A", "Summary B"):
            return (
                text_id, text, summary_a, summary_b, progress,
                "Please select Summary A or Summary B before submitting.",
            )

        result_message = save_choice(text_id, text, summary_a, summary_b, choice, notes, self.request_id)
        self.current_index = (self.current_index + 1) % self.total_items

        text_id, text, summary_a, summary_b, progress = self.get_current_item()
        return text_id, text, summary_a, summary_b, progress, result_message

    def prev_item(self):
        """Step back one item (wrapping to the last) without saving anything.

        Returns the same 6-tuple shape as ``next_item``, with an empty message.
        """
        if self.total_items == 0:
            return "", "", "", "", "No data available", ""

        self.current_index = (self.current_index - 1) % self.total_items
        return (*self.get_current_item(), "")

    def get_hf_status(self):
        """Format the stored HuggingFace initialization result for display."""
        success, message = self.hf_status
        return f"{'Connected' if success else 'Not Connected'} - {message}"
|
|
|
|
|
app = SummaryChooser()

# The first item is fetched before building the layout. Its values are then
# passed through each component's `value=` argument (the documented way to set
# an initial value), instead of being assigned to `.value` after construction.
text_id, text, sum_a, sum_b, prog = app.get_current_item()

with gr.Blocks(title="Summary Chooser") as interface:
    gr.Markdown("# Summary Comparison Tool")
    gr.Markdown("Choose the better summary for each text")

    with gr.Row():
        with gr.Column():
            progress_label = gr.Label(label="Progress", value=prog)

        with gr.Column():
            hf_status = gr.Label(label="HuggingFace Status", value=app.get_hf_status())

        with gr.Column():
            # Filled in by the page-load handler registered below.
            request_id_label = gr.Label(label="Request ID")

        with gr.Column():
            text_id_box = gr.Textbox(label="Text ID", value=text_id, interactive=False)

    with gr.Row():
        text_box = gr.TextArea(label="Original Text", value=text, lines=8)

    with gr.Row():
        with gr.Column():
            summary_a = gr.TextArea(label="Summary A", value=sum_a, lines=5)
        with gr.Column():
            summary_b = gr.TextArea(label="Summary B", value=sum_b, lines=5)

    with gr.Row():
        choice_radio = gr.Radio(
            choices=["Summary A", "Summary B"],
            label="Select the better summary",
        )

    with gr.Row():
        notes_box = gr.TextArea(label="Notes (optional)", lines=2)

    with gr.Row():
        prev_button = gr.Button("Previous")
        submit_button = gr.Button("Submit and Next", variant="primary")

    with gr.Row():
        result_box = gr.Textbox(label="Result")

    # Both navigation handlers return the same 6-tuple, matching these outputs.
    item_outputs = [text_id_box, text_box, summary_a, summary_b, progress_label, result_box]

    submit_button.click(
        fn=app.next_item,
        inputs=[choice_radio, notes_box],
        outputs=item_outputs,
    )

    prev_button.click(
        fn=app.prev_item,
        inputs=[],
        outputs=item_outputs,
    )

    # set_request_id receives the gr.Request through its type annotation.
    interface.load(
        fn=app.set_request_id,
        inputs=[],
        outputs=[request_id_label],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Launch with a public share link and server-side rendering turned off.
    launch_options = {"ssr_mode": False, "share": True}
    interface.launch(**launch_options)