import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch

# Define model names
models = {
    "gte base (gender v3.1)": "breadlicker45/gte-gender-v3.1-test",
    "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
    "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
    "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
    "ModernBERT Large (gender)": "breadlicker45/ModernBERT-large-gender",
}

# Define the mapping for user-friendly labels
label_map = {
    "LABEL_0": "Male (0)",
    "0": "Male (0)",
    "LABEL_1": "Female (1)",
    "1": "Female (1)",
}

# A cache of loaded models/pipelines to speed up subsequent requests
model_cache = {}

# Determine the device to run on (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


# The main classification function; handles both model types
def classify_text(model_name, text):
    try:
        processed_results = {}
        model_id = models[model_name]

        # --- SPECIAL HANDLING FOR THE GTE MODEL ---
        if "gte-gender" in model_id:
            # Check if the model/tokenizer pair is already in the cache
            if model_id not in model_cache:
                print(f"Loading GTE model and tokenizer manually: {model_id}...")
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_id, trust_remote_code=True
                ).to(device)
                model_cache[model_id] = (model, tokenizer)  # Cache both

            model, tokenizer = model_cache[model_id]

            # Tokenize the input text and move it to the correct device
            inputs = tokenizer(
                text, return_tensors="pt", truncation=True, padding=True
            ).to(device)

            # Get model predictions
            with torch.no_grad():
                logits = model(**inputs).logits

            # Convert logits to probabilities using softmax
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

            # Format results to match the pipeline's output style
            processed_results[label_map["LABEL_0"]] = probabilities[0].item()
            processed_results[label_map["LABEL_1"]] = probabilities[1].item()

        # --- STANDARD HANDLING FOR PIPELINE-COMPATIBLE MODELS ---
        else:
            # Check if the pipeline is already in the cache
            if model_id not in model_cache:
                print(f"Loading pipeline for model: {model_id}...")
                # Load and cache the pipeline
                model_cache[model_id] = pipeline(
                    "text-classification",
                    model=model_id,
                    top_k=None,
                    device=device,  # Use the determined device
                )

            classifier = model_cache[model_id]
            predictions = classifier(text)

            # Process predictions to use friendly labels
            if predictions and isinstance(predictions, list) and predictions[0]:
                for pred in predictions[0]:
                    raw_label = pred["label"]
                    score = pred["score"]
                    friendly_label = label_map.get(raw_label, raw_label)
                    processed_results[friendly_label] = score

        return processed_results

    except Exception as e:
        print(f"Error: {e}")
        # Return an error message suitable for gr.Label or gr.JSON
        return {"Error": f"Failed to process: {e}"}


# Create the Gradio interface
interface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Dropdown(
            list(models.keys()),
            label="Select Model",
            value="gte base (gender v3.1)",  # Default model
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter text to classify for perceived gender...",
            value="This is an example sentence.",
        ),
    ],
    # Since we consistently return a dictionary of {label: score},
    # we can use the nicer-looking gr.Label component.
    outputs=gr.Label(num_top_classes=2, label="Classification Results"),
    title="ModernBERT & GTE Gender Classifier",
    description=(
        "Select a model and enter a sentence to see the perceived gender "
        "classification (Male=0, Female=1) and confidence scores. Note: "
        "text-based gender classification can be unreliable and reflect "
        "societal biases."
    ),
    allow_flagging="never",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()