# MGTD-Demo / app.py
# (Hugging Face Space file — header below preserved from the web page capture:
#  minemaster01's picture / Update app.py / 349bcd6 verified / raw / history blame / 2.65 kB)
import gradio as gr
import datetime
import torch
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import Dataset, DatasetDict, disable_caching
import pandas as pd
from huggingface_hub import HfApi, HfFolder
# CONFIG
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"  # Change if needed
HF_DATASET_REPO = "your-username/your-logging-dataset"  # Must be created beforehand

# Token from environment in Spaces
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# Tracks whether the logging dataset repo has been created this session.
# Bug fix: setup_hf_dataset() declares `global DATASET_CREATED` but the flag
# was never initialized at module level, so its first call raised NameError.
DATASET_CREATED = False

# Load model + tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# In-memory log of inference requests; pushed to the Hub by save_to_hf().
log_entries = []
def setup_hf_dataset():
    """Ensure the Hugging Face dataset repo used for logging exists.

    Idempotent: creates HF_DATASET_REPO (with exist_ok=True) at most once per
    session when a token is available; otherwise prints a warning that data
    will only be stored locally. Errors are caught and reported, not raised.
    """
    global DATASET_CREATED
    # Defensive default: the flag may not be initialized at module level,
    # in which case reading it below would raise NameError.
    if "DATASET_CREATED" not in globals():
        DATASET_CREATED = False
    if not DATASET_CREATED and HF_TOKEN:
        try:
            api = HfApi()
            # Bug fix: the original called undefined `create_repo` with the
            # undefined constant `DATASET_NAME`. Use the HfApi method and the
            # configured repo id instead.
            api.create_repo(
                HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True
            )
            DATASET_CREATED = True
            print(f"Dataset {HF_DATASET_REPO} is ready")
        except Exception as e:
            print(f"Error setting up dataset: {e}")
    elif not HF_TOKEN:
        print("Warning: HF_TOKEN not set. Data will be stored locally only.")
def infer_and_log(text_input):
    """Run the classifier on `text_input`, record the request, return the label.

    Appends a {timestamp, input, logits} entry to the module-level
    `log_entries` list as a side effect.
    """
    encoded = tokenizer(text_input, return_tensors="pt", truncation=True)
    with torch.no_grad():
        result = model(**encoded)
    raw_scores = result.logits
    best_idx = torch.argmax(raw_scores, dim=-1).item()
    log_entries.append({
        "timestamp": datetime.datetime.now().isoformat(),
        "input": text_input,
        "logits": raw_scores.tolist(),
    })
    return model.config.id2label[best_idx]
def clear_fields():
    """Return empty strings for both the input and output textboxes."""
    return ("", "")
def save_to_hf():
    """Push accumulated log entries to HF_DATASET_REPO and clear the buffer.

    Returns a human-readable status string; bails out early when no token is
    configured or there is nothing to push.
    """
    if not HF_TOKEN:
        return "No Hugging Face token found in environment. Cannot push dataset."
    if not log_entries:
        return "No logs to push."
    frame = pd.DataFrame(log_entries)
    Dataset.from_pandas(frame).push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
    pushed = len(frame)
    log_entries.clear()
    return f"Pushed {pushed} logs to {HF_DATASET_REPO}!"
# Gradio UI: a text input, a read-only prediction output, and two buttons
# wired to the inference and clearing callbacks above.
with gr.Blocks() as demo:
    gr.Markdown("## AI-generated text detector")
    with gr.Row():
        text_in = gr.Textbox(label="Input Text", lines=6, interactive=True)
        label_out = gr.Textbox(label="Predicted Label", lines=6)
    with gr.Row():
        run_button = gr.Button("Submit")
        reset_button = gr.Button("Clear")
    run_button.click(fn=infer_and_log, inputs=text_in, outputs=label_out)
    reset_button.click(fn=clear_fields, outputs=[text_in, label_out])

if __name__ == "__main__":
    demo.launch()