import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch

# Model display names mapped to their Hugging Face Hub IDs
models = {
    "gte base (gender v3.1)": "breadlicker45/gte-gender-v3.1-test",
    "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
    "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
    "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
    "ModernBERT Large (gender)": "breadlicker45/ModernBERT-large-gender",
}

# Map the models' raw output labels to user-friendly labels.
# Both key styles are covered: the default "LABEL_0"/"LABEL_1" ids and
# checkpoints whose id2label config uses bare "0"/"1" strings.
label_map = {
    "LABEL_0": "Male (0)",
    "0": "Male (0)",
    "LABEL_1": "Female (1)",
    "1": "Female (1)",
}

# A cache to store loaded models/pipelines to speed up subsequent requests
model_cache = {}
# Determine the device to run on (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# The main classification function; handles both model-loading paths
def classify_text(model_name, text):
    try:
        processed_results = {}
        model_id = models[model_name]

        # --- SPECIAL HANDLING FOR THE GTE MODEL ---
        if "gte-gender" in model_id:
            # Check if the model/tokenizer is already in our cache
            if model_id not in model_cache:
                print(f"Loading GTE model and tokenizer manually: {model_id}...")
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_id, trust_remote_code=True
                ).to(device)
                model_cache[model_id] = (model, tokenizer)  # Cache both
            model, tokenizer = model_cache[model_id]

            # Tokenize the input text and move it to the correct device
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

            # Get model predictions
            with torch.no_grad():
                logits = model(**inputs).logits

            # Convert logits to probabilities using softmax
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
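
            # Assumes logit index 0 is male and index 1 is female, matching
            # label_map above; a checkpoint with a different id2label ordering
            # would need these two lines adjusted accordingly.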
            # Format results to match the pipeline's output style
            processed_results[label_map["LABEL_0"]] = probabilities[0].item()
            processed_results[label_map["LABEL_1"]] = probabilities[1].item()

        # --- STANDARD HANDLING FOR PIPELINE-COMPATIBLE MODELS ---
        else:
            # Check if the pipeline is already in our cache
            if model_id not in model_cache:
                print(f"Loading pipeline for model: {model_id}...")
                # Load and cache the pipeline
                model_cache[model_id] = pipeline(
                    "text-classification",
                    model=model_id,
                    top_k=None,  # return scores for every label, not just the top one
                    device=device,  # Use the determined device
                )
            classifier = model_cache[model_id]
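
            # The code below assumes the pipeline returns one list of
            # {label, score} dicts per input, hence predictions[0] for
            # our single input string.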
            predictions = classifier(text)

            # Process predictions to use friendly labels
            if predictions and isinstance(predictions, list) and predictions[0]:
                for pred in predictions[0]:
                    raw_label = pred["label"]
                    score = pred["score"]
                    friendly_label = label_map.get(raw_label, raw_label)
                    processed_results[friendly_label] = score

        return processed_results
    except Exception as e:
        print(f"Error: {e}")
        # gr.Label expects {label: float}, so put the error text in the key
        # rather than the value, which must be a score
        return {f"Error: {e}": 0.0}

# Create the Gradio interface
interface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Dropdown(
            list(models.keys()),
            label="Select Model",
            value="gte base (gender v3.1)"  # Default model
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter text to classify for perceived gender...",
            value="This is an example sentence."
        )
    ],
    # Since both branches return a dictionary of {label: score},
    # the nicer-looking gr.Label component works for every model.
    outputs=gr.Label(num_top_classes=2, label="Classification Results"),
    title="ModernBERT & GTE Gender Classifier",
    description="Select a model and enter a sentence to see the perceived gender classification (Male=0, Female=1) and confidence scores. Note: Text-based gender classification can be unreliable and reflect societal biases.",
    allow_flagging="never",
)
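
# Note: newer Gradio releases (5.x) rename allow_flagging to flagging_mode.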
# Launch the app
if __name__ == "__main__":
    interface.launch()