Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import torch | |
# Hugging Face Hub checkpoint backing this classifier.
model_id = "Rerandaka/Cild_safety_bigbird"

# Tokenizer and classification model are both loaded from the same checkpoint.
# use_fast=False selects the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
# Human-readable names for the model's output class indices
# (optional — edit to match the checkpoint's label order).
# Index 0 -> "Safe / Normal", index 1 -> "Inappropriate / Unsafe".
label_map = dict(enumerate(("Safe / Normal", "Inappropriate / Unsafe")))
# Inference function
def classify_text(text: str):
    """Classify *text* with the child-safety model.

    Args:
        text: The string to classify. It is truncated to at most 256 tokens
            before being passed to the model.

    Returns:
        A ``(label, confidence)`` tuple, ordered to match the interface's two
        output components ("Predicted Label", "Confidence"):

        * ``label`` -- the class name from ``label_map``, or the raw class
          index as a string when the index has no entry there.
        * ``confidence`` -- the softmax probability of the predicted class,
          rounded to 4 decimal places.
    """
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=256
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    predicted = torch.argmax(probs, dim=1).item()
    # A single input string yields a batch of one, so row 0 holds its scores.
    confidence = probs[0][predicted].item()
    label = label_map.get(predicted, str(predicted))
    # With several output components, Gradio assigns returned values
    # positionally, so return a tuple rather than a string-keyed dict.
    return label, round(confidence, 4)
# Define Gradio Interface.
# The endpoint name is an Interface option, not a launch() option, so the API
# endpoint is exposed explicitly as "predict" here.
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(label="Enter text to classify"),
    # Two outputs: classify_text returns (label, confidence) in this order.
    outputs=[
        gr.Textbox(label="Predicted Label"),
        gr.Textbox(label="Confidence"),
    ],
    title="Child-Safety Text Classifier",
    description="This model detects if text content is unsafe or inappropriate for children.",
    allow_flagging="never",
    api_name="predict",
)

# Start the Gradio app.
demo.launch()