breadlicker45 committed on
Commit
ceabca1
·
verified ·
1 Parent(s): f49e0cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -26
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
  # Define model names
5
  models = {
 
6
  "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
7
  "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
8
  "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
@@ -10,8 +12,6 @@ models = {
10
  }
11
 
12
  # Define the mapping for user-friendly labels
13
- # Note: Transformers pipelines often output 'LABEL_0', 'LABEL_1'.
14
- # We handle potential variations like just '0', '1'.
15
  label_map = {
16
  "LABEL_0": "Male (0)",
17
  "0": "Male (0)",
@@ -19,27 +19,74 @@ label_map = {
19
  "1": "Female (1)"
20
  }
21
 
22
- # Function to load the selected model and classify text
 
 
 
 
 
 
 
 
23
  def classify_text(model_name, text):
24
  try:
25
- classifier = pipeline("text-classification", model=models[model_name], top_k=None)
26
- predictions = classifier(text)
27
-
28
- # Process predictions to use friendly labels
29
  processed_results = {}
30
- if predictions and isinstance(predictions, list) and predictions[0]:
31
- # predictions[0] should be a list of label dicts like [{'label': 'LABEL_1', 'score': 0.9...}, ...]
32
- for pred in predictions[0]:
33
- raw_label = pred["label"]
34
- score = pred["score"]
35
- # Use the map to get a friendly name, fallback to the raw label if not found
36
- friendly_label = label_map.get(raw_label, raw_label)
37
- processed_results[friendly_label] = score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  return processed_results
 
39
  except Exception as e:
40
- # Handle potential errors during model loading or inference
41
  print(f"Error: {e}")
42
- # Return an error message suitable for gr.Label
43
  return {"Error": f"Failed to process: {e}"}
44
 
45
 
@@ -50,20 +97,22 @@ interface = gr.Interface(
50
  gr.Dropdown(
51
  list(models.keys()),
52
  label="Select Model",
53
- value="ModernBERT Large (gender)" # Default model
54
  ),
55
  gr.Textbox(
56
  lines=2,
57
- placeholder="Enter text to classify for perceived gender...", # Corrected placeholder
58
- value="This is an example sentence." # Changed example text
59
  )
60
  ],
61
- # The gr.Label component works well for showing classification scores
62
- outputs=gr.Label(num_top_classes=2), # Show both classes explicitly
63
- title="ModernBERT Gender Classifier",
64
- description="Select a model and enter a sentence to see the perceived gender classification (Male=0, Female=1) and confidence scores. Note: Text-based gender classification can be unreliable and reflect societal biases.", # Updated description
 
 
65
  )
66
 
67
  # Launch the app
68
  if __name__ == "__main__":
69
- interface.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
 
5
  # Define model names
6
  models = {
7
+ "gte base (gender v3.1)": "breadlicker45/gte-gender-v3.1-test",
8
  "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
9
  "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
10
  "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
 
12
  }
13
 
14
  # Define the mapping for user-friendly labels
 
 
15
  label_map = {
16
  "LABEL_0": "Male (0)",
17
  "0": "Male (0)",
 
19
  "1": "Female (1)"
20
  }
21
 
22
+ # A cache to store loaded models/pipelines to speed up subsequent requests
23
+ model_cache = {}
24
+
25
+ # Determine the device to run on (GPU if available, otherwise CPU)
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ print(f"Using device: {device}")
28
+
29
+
30
# The main classification function; dispatches to the right loader/runner
def classify_text(model_name, text):
    """Classify *text* with the selected model and return ``{label: score}``.

    Parameters
    ----------
    model_name : str
        Key into the module-level ``models`` dict (the Gradio dropdown value).
    text : str
        The sentence to classify.

    Returns
    -------
    dict
        Friendly label -> probability, suitable for ``gr.Label``. On any
        failure a single ``{"Error": ...}`` entry is returned instead of
        raising, so the UI always gets something renderable.
    """
    try:
        model_id = models[model_name]
        # The GTE checkpoint ships custom modeling code, so it cannot go
        # through the plain text-classification pipeline; everything else can.
        if "gte-gender" in model_id:
            return _classify_with_gte(model_id, text)
        return _classify_with_pipeline(model_id, text)
    except Exception as e:  # app-boundary handler: surface the error in the UI
        print(f"Error: {e}")
        # Return an error message suitable for gr.Label or gr.JSON
        return {"Error": f"Failed to process: {e}"}


def _classify_with_gte(model_id, text):
    """Run the GTE model manually (tokenize -> logits -> softmax) on *text*."""
    if model_id not in model_cache:
        print(f"Loading GTE model and tokenizer manually: {model_id}...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # NOTE(security): trust_remote_code executes Python shipped inside the
        # model repo. Acceptable here only because the repo is the app
        # author's own; do not extend this to arbitrary user-supplied IDs.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id, trust_remote_code=True
        ).to(device)
        model_cache[model_id] = (model, tokenizer)  # cache both together

    model, tokenizer = model_cache[model_id]

    # Tokenize and move tensors to the same device as the model.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    # Assumes logit index 0 == LABEL_0 (Male) and 1 == LABEL_1 (Female);
    # presumably matches the checkpoint's training config — confirm against
    # model.config.id2label if labels ever look swapped.
    return {
        label_map["LABEL_0"]: probabilities[0].item(),
        label_map["LABEL_1"]: probabilities[1].item(),
    }


def _classify_with_pipeline(model_id, text):
    """Run a cached transformers text-classification pipeline on *text*."""
    if model_id not in model_cache:
        print(f"Loading pipeline for model: {model_id}...")
        # Load once and cache; subsequent requests reuse the pipeline.
        model_cache[model_id] = pipeline(
            "text-classification",
            model=model_id,
            top_k=None,
            device=device,  # "cuda" or "cpu", chosen once at startup
        )

    classifier = model_cache[model_id]
    predictions = classifier(text)

    processed_results = {}
    # With top_k=None a single input yields a nested list:
    # [[{'label': 'LABEL_0', 'score': ...}, {'label': 'LABEL_1', ...}]]
    if predictions and isinstance(predictions, list) and predictions[0]:
        for pred in predictions[0]:
            # Map raw labels to friendly names; fall back to the raw label.
            friendly_label = label_map.get(pred["label"], pred["label"])
            processed_results[friendly_label] = pred["score"]
    return processed_results
91
 
92
 
 
97
  gr.Dropdown(
98
  list(models.keys()),
99
  label="Select Model",
100
+ value="gte base (gender v3.1)" # Default model
101
  ),
102
  gr.Textbox(
103
  lines=2,
104
+ placeholder="Enter text to classify for perceived gender...",
105
+ value="This is an example sentence."
106
  )
107
  ],
108
+ # Since we now consistently return a dictionary of {label: score},
109
+ # we can go back to using the nicer-looking gr.Label component!
110
+ outputs=gr.Label(num_top_classes=2, label="Classification Results"),
111
+ title="ModernBERT & GTE Gender Classifier",
112
+ description="Select a model and enter a sentence to see the perceived gender classification (Male=0, Female=1) and confidence scores. Note: Text-based gender classification can be unreliable and reflect societal biases.",
113
+ allow_flagging="never",
114
  )
115
 
116
  # Launch the app
117
  if __name__ == "__main__":
118
+ interface.launch()