|
import gradio as gr |
|
import datetime |
|
import torch |
|
import os |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from datasets import Dataset, DatasetDict, disable_caching |
|
import pandas as pd |
|
from huggingface_hub import HfApi, HfFolder |
|
|
|
|
|
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Hub dataset that save_to_hf() pushes logs to. The default is a placeholder;
# set the HF_DATASET_REPO environment variable to override it without editing code.
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "your-username/your-logging-dataset")

# Access token used for Hub writes. HUGGINGFACE_TOKEN is checked first for
# backward compatibility; HF_TOKEN (the name huggingface_hub itself uses, and
# the name in this file's warning messages) is accepted as a fallback.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")

# Load the classifier once at import time; infer_and_log() reuses these globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# In-memory buffer of inference records awaiting upload by save_to_hf().
log_entries = []
|
|
|
# Module flag: set once the Hub dataset repo has been created or verified in
# this process, so setup_hf_dataset() does not repeat the network call.
DATASET_CREATED = False


def setup_hf_dataset():
    """Ensure the Hub dataset repository ``HF_DATASET_REPO`` exists.

    When no token is available, print a warning and return, because data then
    stays local only. Otherwise, the first successful call creates the repo
    (``exist_ok=True`` makes this a no-op if it already exists) and sets
    ``DATASET_CREATED``; later calls then return immediately.

    Errors from the Hub are printed rather than raised, so the app keeps
    running with local-only logging.
    """
    global DATASET_CREATED

    if not HF_TOKEN:
        print("Warning: HF_TOKEN not set. Data will be stored locally only.")
        return
    if DATASET_CREATED:
        return

    try:
        HfApi().create_repo(
            HF_DATASET_REPO,
            token=HF_TOKEN,
            repo_type="dataset",
            exist_ok=True,
        )
    except Exception as e:  # best-effort: keep the UI usable if the Hub fails
        print(f"Error setting up dataset: {e}")
    else:
        DATASET_CREATED = True
        print(f"Dataset {HF_DATASET_REPO} is ready")
|
|
|
def infer_and_log(text_input):
    """Classify ``text_input`` and append a record of the run to ``log_entries``.

    The text is tokenized (with ``truncation=True``) and passed through
    ``model`` with gradient tracking disabled. The arg-max class index is then
    mapped through ``model.config.id2label``.

    The appended record holds the current local timestamp (ISO format), the
    raw input text, and the model logits converted to a Python list.

    Returns:
        The predicted label, for display in the output textbox.
    """
    encoded = tokenizer(text_input, return_tensors="pt", truncation=True)
    with torch.no_grad():
        scores = model(**encoded).logits

    best_index = torch.argmax(scores, dim=-1).item()
    label = model.config.id2label[best_index]

    record = {
        "timestamp": datetime.datetime.now().isoformat(),
        "input": text_input,
        "logits": scores.tolist(),
    }
    log_entries.append(record)
    return label
|
|
|
def clear_fields():
    """Return empty strings for the input and output textboxes (in that order)."""
    blank = ""
    return blank, blank
|
|
|
def save_to_hf():
    """Push the buffered inference logs to the Hub dataset ``HF_DATASET_REPO``.

    Always returns a human-readable status string for the UI: missing token,
    nothing to push, success, or failure.

    Only the entries present when the call starts are uploaded, and only those
    are then removed from ``log_entries``. Entries appended by other requests
    while the upload is in flight are kept for the next push rather than being
    silently discarded. On failure nothing is removed, so the push can be
    retried.
    """
    if not HF_TOKEN:
        return "No Hugging Face token found in environment. Cannot push dataset."

    # Snapshot the buffer: event handlers may run in worker threads, so
    # infer_and_log() can append while the (slow) upload below is running.
    pending = list(log_entries)
    if not pending:
        return "No logs to push."

    df = pd.DataFrame(pending)
    dataset = Dataset.from_pandas(df)

    # NOTE(review): Dataset.push_to_hub appears to replace the existing split
    # on the Hub rather than append to it, so each push may overwrite earlier
    # logs. Confirm this, and if so merge with the existing data before pushing.
    try:
        dataset.push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"Failed to push logs to {HF_DATASET_REPO}: {e}"

    # New entries are only ever appended at the end, so the first
    # len(pending) items are exactly the ones that were just uploaded.
    del log_entries[: len(pending)]
    return f"Pushed {len(pending)} logs to {HF_DATASET_REPO}!"
|
|
|
with gr.Blocks() as demo:
    # NOTE(review): MODEL_NAME looks like an SST-2 sentiment classifier, which
    # does not match this heading. Confirm the intended model/title.
    gr.Markdown("## AI-generated text detector")

    with gr.Row():
        input_box = gr.Textbox(label="Input Text", lines=6, interactive=True)
        output_box = gr.Textbox(label="Predicted Label", lines=6)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
        push_btn = gr.Button("Push logs to Hub")

    # Shows the status string returned by save_to_hf().
    push_status = gr.Textbox(label="Push status", interactive=False)

    submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=output_box)
    clear_btn.click(fn=clear_fields, outputs=[input_box, output_box])
    # Without this event, save_to_hf() is never called and logs are only
    # held in memory until the process exits.
    push_btn.click(fn=save_to_hf, outputs=push_status)


if __name__ == "__main__":
    demo.launch()
|
|