import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
# Map display names to Hugging Face model IDs
models = {
    "gte base (gender v3.1)": "breadlicker45/gte-gender-v3.1-test",
    "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
    "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
    "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
    "ModernBERT Large (gender)": "breadlicker45/ModernBERT-large-gender",
}
# Map raw model labels to user-friendly ones. Some models report
# "LABEL_0"/"LABEL_1" while others use bare "0"/"1", so both forms are covered.
label_map = {
    "LABEL_0": "Male (0)",
    "0": "Male (0)",
    "LABEL_1": "Female (1)",
    "1": "Female (1)",
}
# Cache loaded models/pipelines to speed up subsequent requests
model_cache = {}

# Determine the device to run on (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Main classification function; handles both model types
def classify_text(model_name, text):
    try:
        processed_results = {}
        model_id = models[model_name]

        # --- SPECIAL HANDLING FOR THE GTE MODEL ---
        if "gte-gender" in model_id:
            # Check whether the model/tokenizer pair is already cached
            if model_id not in model_cache:
                print(f"Loading GTE model and tokenizer manually: {model_id}...")
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_id, trust_remote_code=True
                ).to(device)
                model_cache[model_id] = (model, tokenizer)  # Cache both
            model, tokenizer = model_cache[model_id]

            # Tokenize the input text and move it to the selected device
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

            # Get model predictions
            with torch.no_grad():
                logits = model(**inputs).logits

            # Convert logits to probabilities with softmax
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
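            # Worked example: logits [2.0, -1.0] -> softmax -> probabilities ~[0.95, 0.05]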
            # Format results to match the pipeline's output style
            processed_results[label_map["LABEL_0"]] = probabilities[0].item()
            processed_results[label_map["LABEL_1"]] = probabilities[1].item()

        # --- STANDARD HANDLING FOR PIPELINE-COMPATIBLE MODELS ---
        else:
            # Check whether the pipeline is already cached
            if model_id not in model_cache:
                print(f"Loading pipeline for model: {model_id}...")
                # Load and cache the pipeline
                model_cache[model_id] = pipeline(
                    "text-classification",
                    model=model_id,
                    top_k=None,
                    device=device,  # Use the determined device
                )
            classifier = model_cache[model_id]
            predictions = classifier(text)
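            # With top_k=None, the pipeline returns one list of {label, score}
            # dicts per input, e.g.
            # [[{'label': 'LABEL_0', 'score': 0.97}, {'label': 'LABEL_1', 'score': 0.03}]]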
            # Map predictions onto friendly labels
            if predictions and isinstance(predictions, list) and predictions[0]:
                for pred in predictions[0]:
                    raw_label = pred["label"]
                    score = pred["score"]
                    friendly_label = label_map.get(raw_label, raw_label)
                    processed_results[friendly_label] = score

        return processed_results
    except Exception as e:
        print(f"Error: {e}")
        # gr.Label expects {label: float}, so put the message in the key
        return {f"Error: failed to process ({e})": 0.0}
# Create the Gradio interface
interface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Dropdown(
            list(models.keys()),
            label="Select Model",
            value="gte base (gender v3.1)",  # Default model
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter text to classify for perceived gender...",
            value="This is an example sentence.",
        ),
    ],
    # classify_text always returns a {label: score} dict, so gr.Label
    # can render it directly with confidence bars.
    outputs=gr.Label(num_top_classes=2, label="Classification Results"),
    title="ModernBERT & GTE Gender Classifier",
    description=(
        "Select a model and enter a sentence to see the perceived gender "
        "classification (Male=0, Female=1) and confidence scores. Note: "
        "text-based gender classification can be unreliable and reflect "
        "societal biases."
    ),
    allow_flagging="never",
)
# Launch the app
if __name__ == "__main__":
    interface.launch()
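
# Optional (sketch): once the app is running, you can query it programmatically
# with gradio_client (installed separately via `pip install gradio_client`).
# The URL below assumes Gradio's local default of http://127.0.0.1:7860 and the
# default "/predict" endpoint that gr.Interface exposes; adjust as needed.
#
# from gradio_client import Client
#
# client = Client("http://127.0.0.1:7860/")
# result = client.predict(
#     "gte base (gender v3.1)",        # model dropdown value
#     "This is an example sentence.",  # text to classify
#     api_name="/predict",
# )
# print(result)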