# MGTD-Demo / app.py
# Captured from the Hugging Face Space file page: last commit 71f5043
# ("Update app.py", verified) by minemaster01; file size 2.33 kB.
import gradio as gr
import datetime
import torch
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import Dataset, DatasetDict, disable_caching
import pandas as pd
from huggingface_hub import HfApi, HfFolder
# CONFIG
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"  # Change if needed
HF_DATASET_REPO = "your-username/your-logging-dataset"  # Must be created beforehand
# Token from environment in Spaces. HUGGINGFACE_TOKEN is checked first for
# backward compatibility; HF_TOKEN (the name huggingface_hub itself reads and
# the usual name for a Spaces secret) is used as a fallback.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
# Load model + tokenizer once at import time; every request reuses these globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# In-memory buffer of inference records (dicts) waiting to be pushed by
# save_to_hf. It lives only in this process, so unpushed entries are lost on restart.
log_entries = []
def infer_and_log(text_input):
    """Classify ``text_input`` and append a record of the prediction to ``log_entries``.

    Args:
        text_input: Text from the input box. Empty, whitespace-only, or ``None``
            input is neither classified nor logged.

    Returns:
        The predicted label name from ``model.config.id2label``, or ``""`` when
        there was no text to classify.
    """
    # Guard: an empty textbox would otherwise be classified (as special tokens
    # only) and add a meaningless entry to the log.
    if text_input is None or not text_input.strip():
        return ""

    inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**inputs)

    # A single string is tokenized, so the batch holds one sequence and the
    # argmax over the last dim is a single class index (required by .item()).
    predicted = torch.argmax(outputs.logits, dim=-1).item()
    output_label = model.config.id2label[predicted]

    log_entries.append({
        # Timezone-aware UTC so entries are comparable regardless of host locale.
        "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "input": text_input,
        # Nested list as produced by .tolist() on the batched logits tensor.
        "logits": outputs.logits.tolist(),
        # Record the prediction itself so the pushed dataset is self-describing.
        "label": output_label,
    })
    return output_label
def clear_fields():
    """Reset the input and output textboxes.

    Returns:
        Two empty strings, one per wired output component (input, output).
    """
    blank = ""
    return blank, blank
def save_to_hf():
    """Push the buffered ``log_entries`` to ``HF_DATASET_REPO`` and report the outcome.

    Returns:
        A human-readable status string for the status textbox. A missing token,
        an empty buffer, or an error raised while pushing is reported in the
        returned string rather than raised. On failure the buffered logs are
        kept so a later save can retry.
    """
    if not HF_TOKEN:
        return "No Hugging Face token found in environment. Cannot push dataset."

    # Snapshot the buffer before the slow, network-bound push. If the UI runs
    # inference concurrently with this handler, entries appended mid-push must
    # not be discarded along with the ones actually uploaded.
    batch = list(log_entries)
    if not batch:
        return "No logs to push."

    df = pd.DataFrame(batch)
    try:
        # NOTE(review): push_to_hub appears to replace the repo's existing split
        # on each call, so previously pushed logs may be overwritten — confirm
        # against the `datasets` docs and consider an append strategy.
        Dataset.from_pandas(df).push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
    except Exception as err:  # UI boundary: surface the failure in the status box
        return f"Failed to push logs to {HF_DATASET_REPO}: {err}"

    # infer_and_log only appends, so the first len(batch) entries are exactly
    # the ones pushed; anything appended since stays queued for the next save.
    del log_entries[:len(batch)]
    return f"Pushed {len(df)} logs to {HF_DATASET_REPO}!"
# UI layout: text in / label out, action buttons, and an upload status line.
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Text Classification with Logging")

    # Input text on the left, predicted label on the right.
    with gr.Row():
        input_box = gr.Textbox(label="Input Text", lines=4, interactive=True)
        output_box = gr.Textbox(label="Predicted Label", lines=2)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    # Read-only feedback written by save_to_hf.
    status_box = gr.Textbox(label="Status", interactive=False)

    # Event wiring, registered in the same order as before.
    submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=output_box)
    clear_btn.click(fn=clear_fields, outputs=[input_box, output_box])
    save_btn = gr.Button("Save Logs to HF Dataset")
    save_btn.click(fn=save_to_hf, outputs=status_box)

if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    demo.launch()