SorrelC committed on
Commit
63fb06a
·
verified ·
1 Parent(s): f32b5c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +327 -111
app.py CHANGED
@@ -41,7 +41,14 @@ class HybridNERManager:
41
  def __init__(self):
42
  self.gliner_model = None
43
  self.spacy_model = None
 
44
  self.all_entity_colors = {}
 
 
 
 
 
 
45
 
46
  def load_gliner_model(self):
47
  """Load GLiNER model for custom entities"""
@@ -55,12 +62,24 @@ class HybridNERManager:
55
  return None
56
  return self.gliner_model
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def load_spacy_model(self):
59
  """Load spaCy model for standard NER"""
60
  if self.spacy_model is None:
61
  try:
62
  import spacy
63
- # Try to load the transformer model first, fallback to smaller model
64
  try:
65
  self.spacy_model = spacy.load("en_core_web_sm")
66
  print("βœ“ spaCy model loaded successfully")
@@ -72,6 +91,46 @@ class HybridNERManager:
72
  return None
73
  return self.spacy_model
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def assign_colors(self, standard_entities, custom_entities):
76
  """Assign colors to all entity types"""
77
  self.all_entity_colors = {}
@@ -90,28 +149,52 @@ class HybridNERManager:
90
 
91
  return self.all_entity_colors
92
 
93
- def extract_spacy_entities(self, text, entity_types):
94
- """Extract entities using spaCy"""
95
- model = self.load_spacy_model()
 
 
 
 
 
 
 
 
 
 
 
96
  if model is None:
97
  return []
98
 
99
  try:
100
- doc = model(text)
 
 
101
  entities = []
102
- for ent in doc.ents:
103
- if ent.label_ in entity_types:
 
 
 
 
 
 
 
 
 
 
 
104
  entities.append({
105
- 'text': ent.text,
106
- 'label': ent.label_,
107
- 'start': ent.start_char,
108
- 'end': ent.end_char,
109
- 'confidence': 1.0, # spaCy doesn't provide confidence scores
110
- 'source': 'spaCy'
111
  })
112
  return entities
113
  except Exception as e:
114
- print(f"Error with spaCy extraction: {str(e)}")
115
  return []
116
 
117
  def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True):
@@ -138,13 +221,13 @@ class HybridNERManager:
138
  return []
139
 
140
  def find_overlapping_entities(entities):
141
- """Find and merge overlapping entities"""
142
  if not entities:
143
  return []
144
 
145
  # Sort entities by start position
146
  sorted_entities = sorted(entities, key=lambda x: x['start'])
147
- merged_entities = []
148
 
149
  i = 0
150
  while i < len(sorted_entities):
@@ -165,19 +248,19 @@ def find_overlapping_entities(entities):
165
  else:
166
  j += 1
167
 
168
- # Create merged entity
169
  if len(overlapping_entities) == 1:
170
- merged_entities.append(overlapping_entities[0])
171
  else:
172
- merged_entity = merge_entities(overlapping_entities)
173
- merged_entities.append(merged_entity)
174
 
175
  i += 1
176
 
177
- return merged_entities
178
 
179
- def merge_entities(entity_list):
180
- """Merge multiple overlapping entities into one"""
181
  if len(entity_list) == 1:
182
  return entity_list[0]
183
 
@@ -196,7 +279,7 @@ def merge_entities(entity_list):
196
  'labels': labels,
197
  'sources': sources,
198
  'confidences': confidences,
199
- 'is_merged': True,
200
  'entity_count': len(entity_list)
201
  }
202
 
@@ -205,11 +288,11 @@ def create_highlighted_html(text, entities, entity_colors):
205
  if not entities:
206
  return f"<div style='padding: 15px; border: 1px solid #ddd; border-radius: 5px; background-color: #fafafa;'><p>{text}</p></div>"
207
 
208
- # Find and merge overlapping entities
209
- merged_entities = find_overlapping_entities(entities)
210
 
211
  # Sort by start position
212
- sorted_entities = sorted(merged_entities, key=lambda x: x['start'])
213
 
214
  # Create HTML with highlighting
215
  html_parts = []
@@ -219,9 +302,9 @@ def create_highlighted_html(text, entities, entity_colors):
219
  # Add text before entity
220
  html_parts.append(text[last_end:entity['start']])
221
 
222
- if entity.get('is_merged', False):
223
- # Handle merged entity with multiple colors
224
- html_parts.append(create_merged_entity_html(entity, entity_colors))
225
  else:
226
  # Handle single entity
227
  html_parts.append(create_single_entity_html(entity, entity_colors))
@@ -253,8 +336,8 @@ def create_single_entity_html(entity, entity_colors):
253
  f'title="{label} ({source}) - confidence: {confidence:.2f}">'
254
  f'{entity["text"]}</span>')
255
 
256
- def create_merged_entity_html(entity, entity_colors):
257
- """Create HTML for a merged entity with multiple colors"""
258
  labels = entity['labels']
259
  sources = entity['sources']
260
  confidences = entity['confidences']
@@ -287,22 +370,22 @@ def create_merged_entity_html(entity, entity_colors):
287
  return (f'<span style="background: {gradient}; padding: 2px 4px; '
288
  f'border-radius: 3px; margin: 0 1px; '
289
  f'border: 2px solid #333; color: white; font-weight: bold;" '
290
- f'title="MERGED: {tooltip}">'
291
  f'{entity["text"]} πŸ”—</span>')
292
 
293
  def create_entity_table_html(entities, entity_colors):
294
- """Create HTML table of entities"""
295
  if not entities:
296
  return "<p>No entities found.</p>"
297
 
298
- # Merge overlapping entities
299
- merged_entities = find_overlapping_entities(entities)
300
 
301
  # Group entities by type
302
  entity_groups = {}
303
- for entity in merged_entities:
304
- if entity.get('is_merged', False):
305
- key = 'MERGED_ENTITIES'
306
  else:
307
  key = entity['label']
308
 
@@ -310,54 +393,162 @@ def create_entity_table_html(entities, entity_colors):
310
  entity_groups[key] = []
311
  entity_groups[key].append(entity)
312
 
313
- # Create HTML table
314
- html = "<div style='margin: 15px 0;'>"
315
-
316
- for entity_type, entities_of_type in entity_groups.items():
317
- if entity_type == 'MERGED_ENTITIES':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  color = '#666666'
319
- header = f"πŸ”— Merged Entities ({len(entities_of_type)})"
320
  else:
321
- color = entity_colors.get(entity_type.upper(), '#CCCCCC')
322
- header = f"{entity_type} ({len(entities_of_type)})"
323
-
324
- html += f"""
325
- <h4 style="color: {color}; margin: 15px 0 10px 0;">{header}</h4>
326
- <table style="width: 100%; border-collapse: collapse; margin-bottom: 20px; border: 1px solid #ddd;">
327
- <thead>
328
- <tr style="background-color: {color}; color: white;">
329
- <th style="padding: 10px; text-align: left; border: 1px solid #ddd;">Entity Text</th>
330
- <th style="padding: 10px; text-align: left; border: 1px solid #ddd;">Label(s)</th>
331
- <th style="padding: 10px; text-align: left; border: 1px solid #ddd;">Source(s)</th>
332
- <th style="padding: 10px; text-align: left; border: 1px solid #ddd;">Confidence</th>
333
- </tr>
334
- </thead>
335
- <tbody>
336
  """
337
 
338
- for entity in entities_of_type:
339
- if entity.get('is_merged', False):
 
 
 
 
 
 
 
 
 
 
 
340
  labels_text = " | ".join(entity['labels'])
341
  sources_text = " | ".join(entity['sources'])
342
- conf_text = " | ".join([f"{c:.2f}" for c in entity['confidences']])
343
- else:
344
- labels_text = entity['label']
345
- sources_text = entity['source']
346
- conf_text = f"{entity['confidence']:.2f}"
347
-
348
- html += f"""
349
- <tr style="background-color: #fff;">
350
- <td style="padding: 8px; border: 1px solid #ddd; font-weight: bold;">{entity['text']}</td>
351
- <td style="padding: 8px; border: 1px solid #ddd;">{labels_text}</td>
352
- <td style="padding: 8px; border: 1px solid #ddd;">{sources_text}</td>
353
- <td style="padding: 8px; border: 1px solid #ddd;">{conf_text}</td>
354
- </tr>
 
 
 
 
 
 
 
 
 
 
355
  """
356
 
357
- html += "</tbody></table>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
- html += "</div>"
360
- return html
361
 
362
  def create_legend_html(entity_colors, standard_entities, custom_entities):
363
  """Create a legend showing entity colors"""
@@ -391,7 +582,7 @@ def create_legend_html(entity_colors, standard_entities, custom_entities):
391
  # Initialize the NER manager
392
  ner_manager = HybridNERManager()
393
 
394
- def process_text(text, standard_entities, custom_entities_str, confidence_threshold, use_spacy, use_gliner_standard):
395
  """Main processing function for Gradio interface"""
396
  if not text.strip():
397
  return "❌ Please enter some text to analyze", "", ""
@@ -409,17 +600,12 @@ def process_text(text, standard_entities, custom_entities_str, confidence_thresh
409
 
410
  all_entities = []
411
 
412
- # Extract standard entities
413
- if selected_standard:
414
- if use_spacy:
415
- spacy_entities = ner_manager.extract_spacy_entities(text, selected_standard)
416
- all_entities.extend(spacy_entities)
417
-
418
- if use_gliner_standard:
419
- gliner_standard_entities = ner_manager.extract_gliner_entities(text, selected_standard, confidence_threshold, is_custom=False)
420
- all_entities.extend(gliner_standard_entities)
421
 
422
- # Extract custom entities
423
  if custom_entities:
424
  custom_entity_results = ner_manager.extract_gliner_entities(text, custom_entities, confidence_threshold, is_custom=True)
425
  all_entities.extend(custom_entity_results)
@@ -435,17 +621,17 @@ def process_text(text, standard_entities, custom_entities_str, confidence_thresh
435
  highlighted_html = create_highlighted_html(text, all_entities, entity_colors)
436
  table_html = create_entity_table_html(all_entities, entity_colors)
437
 
438
- # Create summary
439
  total_entities = len(all_entities)
440
- merged_entities = find_overlapping_entities(all_entities)
441
- final_count = len(merged_entities)
442
- merged_count = sum(1 for e in merged_entities if e.get('is_merged', False))
443
 
444
  summary = f"""
445
  ## πŸ“Š Analysis Summary
446
  - **Total entities found:** {total_entities}
447
  - **Final entities displayed:** {final_count}
448
- - **Merged entities:** {merged_count}
449
  - **Average confidence:** {sum(e.get('confidence', 0) for e in all_entities) / total_entities:.3f}
450
  """
451
 
@@ -459,14 +645,15 @@ def create_interface():
459
 
460
  Combine standard NER categories with your own custom entity types! This tool uses both traditional NER models and GLiNER for comprehensive entity extraction.
461
 
462
- ## πŸ”— NEW: Overlapping entities are automatically merged with split-color highlighting!
463
 
464
  ### How to use:
465
  1. **πŸ“ Enter your text** in the text area below
466
- 2. **🎯 Select standard entities** you want to find (PER, ORG, LOC, etc.)
467
- 3. **✨ Add custom entities** (comma-separated) like "relationships, occupations, skills"
468
- 4. **βš™οΈ Choose models** and adjust confidence threshold
469
- 5. **πŸ” Click "Analyze Text"** to see results
 
470
  """)
471
 
472
  with gr.Row():
@@ -490,15 +677,40 @@ def create_interface():
490
  with gr.Row():
491
  with gr.Column():
492
  gr.Markdown("### 🎯 Standard Entity Types")
 
 
 
 
 
 
 
 
 
 
493
  standard_entities = gr.CheckboxGroup(
494
  choices=STANDARD_ENTITIES,
495
  value=['PER', 'ORG', 'LOC', 'MISC'], # Default selection
496
  label="Select Standard Entities"
497
  )
498
 
 
499
  with gr.Row():
500
- use_spacy = gr.Checkbox(label="Use spaCy", value=True)
501
- use_gliner_standard = gr.Checkbox(label="Use GLiNER for Standard", value=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
  with gr.Column():
504
  gr.Markdown("### ✨ Custom Entity Types")
@@ -510,8 +722,9 @@ def create_interface():
510
  gr.Markdown("""
511
  **Examples:**
512
  - relationships, occupations, skills
513
- - emotions, actions, objects
514
  - medical conditions, treatments
 
515
  """)
516
 
517
  analyze_btn = gr.Button("πŸ” Analyze Text", variant="primary", size="lg")
@@ -524,7 +737,7 @@ def create_interface():
524
  highlighted_output = gr.HTML(label="Highlighted Text")
525
 
526
  with gr.Row():
527
- table_output = gr.HTML(label="Detailed Results")
528
 
529
  # Connect the button to the processing function
530
  analyze_btn.click(
@@ -534,8 +747,7 @@ def create_interface():
534
  standard_entities,
535
  custom_entities,
536
  confidence_threshold,
537
- use_spacy,
538
- use_gliner_standard
539
  ],
540
  outputs=[summary_output, highlighted_output, table_output]
541
  )
@@ -548,16 +760,21 @@ def create_interface():
548
  ["PER", "ORG", "LOC", "DATE"],
549
  "relationships, occupations, educational background",
550
  0.3,
551
- True,
552
- False
553
  ],
554
  [
555
  "The meeting between CEO Jane Doe and the board of directors at Microsoft headquarters in Seattle discussed the Q4 financial results and the new AI strategy for 2024.",
556
  ["PER", "ORG", "LOC", "DATE"],
557
  "corporate roles, business events, financial terms",
558
  0.4,
559
- True,
560
- True
 
 
 
 
 
 
561
  ]
562
  ],
563
  inputs=[
@@ -565,8 +782,7 @@ def create_interface():
565
  standard_entities,
566
  custom_entities,
567
  confidence_threshold,
568
- use_spacy,
569
- use_gliner_standard
570
  ]
571
  )
572
 
 
41
  def __init__(self):
42
  self.gliner_model = None
43
  self.spacy_model = None
44
+ self.flair_models = {}
45
  self.all_entity_colors = {}
46
+ self.model_names = [
47
+ 'spacy_en_core_web_sm',
48
+ 'flair_ner-ontonotes-large',
49
+ 'flair_ner-large',
50
+ 'gliner_medium-v2.1'
51
+ ]
52
 
53
  def load_gliner_model(self):
54
  """Load GLiNER model for custom entities"""
 
62
  return None
63
  return self.gliner_model
64
 
65
def load_model(self, model_name):
    """Load and return the backend model identified by *model_name*.

    Dispatches to the matching loader on this manager:

    * ``'spacy_en_core_web_sm'`` -> ``load_spacy_model()``
    * any name containing ``'flair'`` -> ``load_flair_model(model_name)``
    * any name containing ``'gliner'`` -> ``load_gliner_model()``

    Returns:
        Whatever the selected loader returns, or ``None`` when the name
        is not recognised or the loader raises (the error is printed).
    """
    try:
        if model_name == 'spacy_en_core_web_sm':
            return self.load_spacy_model()
        if 'flair' in model_name:
            return self.load_flair_model(model_name)
        if 'gliner' in model_name:
            return self.load_gliner_model()
    except Exception as e:
        print(f"Error loading {model_name}: {str(e)}")
        return None

    # No branch matched: report it instead of silently falling off the end.
    print(f"Unknown model name: {model_name}")
    return None
77
+
78
  def load_spacy_model(self):
79
  """Load spaCy model for standard NER"""
80
  if self.spacy_model is None:
81
  try:
82
  import spacy
 
83
  try:
84
  self.spacy_model = spacy.load("en_core_web_sm")
85
  print("βœ“ spaCy model loaded successfully")
 
91
  return None
92
  return self.spacy_model
93
 
94
def load_flair_model(self, model_name):
    """Return the Flair tagger for *model_name*, loading it on first use.

    Loaded taggers are cached in ``self.flair_models`` keyed by
    *model_name*, so each one is loaded at most once per manager.

    Names containing ``'ontonotes'`` load
    ``flair/ner-english-ontonotes-large``; every other name loads
    ``flair/ner-english-large``.

    Returns:
        The cached or freshly loaded tagger, or ``None`` if importing
        Flair or loading the model fails (the error is printed).
    """
    if model_name in self.flair_models:
        return self.flair_models[model_name]

    try:
        # Imported lazily so the app still starts when Flair is absent.
        from flair.models import SequenceTagger
        if 'ontonotes' in model_name:
            model = SequenceTagger.load("flair/ner-english-ontonotes-large")
        else:
            model = SequenceTagger.load("flair/ner-english-large")
        self.flair_models[model_name] = model
        print(f"✓ {model_name} loaded successfully")
    except Exception as e:
        print(f"Error loading {model_name}: {str(e)}")
        return None
    return self.flair_models[model_name]
109
+
110
def extract_spacy_entities(self, text, entity_types):
    """Extract entities from *text* with the spaCy model.

    Long-form labels are normalised to the short forms used elsewhere in
    this file (``PERSON`` -> ``PER`` and so on), using the same mapping
    that ``extract_flair_entities`` applies. Without it, a request for
    ``PER`` never matches spaCy's ``PERSON``. A raw label that already
    appears in *entity_types* is kept unchanged, so callers that pass
    spaCy's own label names keep working.

    Args:
        text: The text to analyse.
        entity_types: Collection of labels to keep.

    Returns:
        A list of dicts with keys ``text``, ``label``, ``start``, ``end``,
        ``confidence`` (always 1.0) and ``source`` (``'spaCy'``). Empty if
        the model is unavailable or extraction raises (error is printed).
    """
    model = self.load_spacy_model()
    if model is None:
        return []

    label_aliases = {
        'PERSON': 'PER',
        'ORGANIZATION': 'ORG',
        'LOCATION': 'LOC',
        'MISCELLANEOUS': 'MISC',
    }

    try:
        doc = model(text)
        entities = []
        for ent in doc.ents:
            label = ent.label_
            if label not in entity_types:
                # Fall back to the short-form alias, if there is one.
                label = label_aliases.get(label)
                if label not in entity_types:
                    continue
            entities.append({
                'text': ent.text,
                'label': label,
                'start': ent.start_char,
                'end': ent.end_char,
                'confidence': 1.0,  # spaCy doesn't provide confidence scores
                'source': 'spaCy'
            })
        return entities
    except Exception as e:
        print(f"Error with spaCy extraction: {str(e)}")
        return []
133
+
134
  def assign_colors(self, standard_entities, custom_entities):
135
  """Assign colors to all entity types"""
136
  self.all_entity_colors = {}
 
149
 
150
  return self.all_entity_colors
151
 
152
def extract_entities_by_model(self, text, entity_types, model_name, threshold=0.3):
    """Run the extractor that corresponds to *model_name*.

    Args:
        text: The text to analyse.
        entity_types: Labels to extract.
        model_name: ``'spacy_en_core_web_sm'``, or a name containing
            ``'flair'`` or ``'gliner'``.
        threshold: Passed to the GLiNER extractor only.

    Returns:
        The selected extractor's result, or an empty list when
        *model_name* matches none of the known backends.
    """
    if model_name == 'spacy_en_core_web_sm':
        return self.extract_spacy_entities(text, entity_types)

    if 'flair' in model_name:
        return self.extract_flair_entities(text, entity_types, model_name)

    if 'gliner' in model_name:
        # Standard (non-custom) labels are being requested on this path.
        return self.extract_gliner_entities(text, entity_types, threshold, is_custom=False)

    return []
162
+
163
def extract_flair_entities(self, text, entity_types, model_name):
    """Extract entities from *text* with the Flair tagger *model_name*.

    Long-form tags are mapped to the short forms used in this file
    (``PERSON`` -> ``PER`` and so on) before being checked against
    *entity_types*.

    Args:
        text: The text to analyse.
        entity_types: Collection of labels to keep.
        model_name: Key passed to ``load_flair_model``.

    Returns:
        A list of dicts with keys ``text``, ``label``, ``start``, ``end``,
        ``confidence`` and ``source``. Empty if the model is unavailable
        or extraction raises (the error is printed).
    """
    model = self.load_flair_model(model_name)
    if model is None:
        return []

    # Map Flair labels to our standard set
    label_aliases = {
        'PERSON': 'PER',
        'ORGANIZATION': 'ORG',
        'LOCATION': 'LOC',
        'MISCELLANEOUS': 'MISC',
    }

    # Keep the whole variant suffix so the two Flair models stay
    # distinguishable: 'flair_ner-ontonotes-large' -> 'Flair-ontonotes-large',
    # 'flair_ner-large' -> 'Flair-large'. Taking only the last '-' segment
    # gave 'Flair-large' for both.
    source = f'Flair-{model_name.replace("flair_ner-", "", 1)}'

    try:
        from flair.data import Sentence
        sentence = Sentence(text)
        model.predict(sentence)
        entities = []
        for span in sentence.get_spans('ner'):
            label = label_aliases.get(span.tag, span.tag)
            if label in entity_types:
                entities.append({
                    'text': span.text,
                    'label': label,
                    'start': span.start_position,
                    'end': span.end_position,
                    'confidence': span.score,
                    'source': source
                })
        return entities
    except Exception as e:
        print(f"Error with Flair extraction: {str(e)}")
        return []
199
 
200
  def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True):
 
221
  return []
222
 
223
  def find_overlapping_entities(entities):
224
+ """Find and share overlapping entities"""
225
  if not entities:
226
  return []
227
 
228
  # Sort entities by start position
229
  sorted_entities = sorted(entities, key=lambda x: x['start'])
230
+ shared_entities = []
231
 
232
  i = 0
233
  while i < len(sorted_entities):
 
248
  else:
249
  j += 1
250
 
251
+ # Create shared entity
252
  if len(overlapping_entities) == 1:
253
+ shared_entities.append(overlapping_entities[0])
254
  else:
255
+ shared_entity = share_entities(overlapping_entities)
256
+ shared_entities.append(shared_entity)
257
 
258
  i += 1
259
 
260
+ return shared_entities
261
 
262
+ def share_entities(entity_list):
263
+ """Share multiple overlapping entities into one"""
264
  if len(entity_list) == 1:
265
  return entity_list[0]
266
 
 
279
  'labels': labels,
280
  'sources': sources,
281
  'confidences': confidences,
282
+ 'is_shared': True,
283
  'entity_count': len(entity_list)
284
  }
285
 
 
288
  if not entities:
289
  return f"<div style='padding: 15px; border: 1px solid #ddd; border-radius: 5px; background-color: #fafafa;'><p>{text}</p></div>"
290
 
291
+ # Find and share overlapping entities
292
+ shared_entities = find_overlapping_entities(entities)
293
 
294
  # Sort by start position
295
+ sorted_entities = sorted(shared_entities, key=lambda x: x['start'])
296
 
297
  # Create HTML with highlighting
298
  html_parts = []
 
302
  # Add text before entity
303
  html_parts.append(text[last_end:entity['start']])
304
 
305
+ if entity.get('is_shared', False):
306
+ # Handle shared entity with multiple colors
307
+ html_parts.append(create_shared_entity_html(entity, entity_colors))
308
  else:
309
  # Handle single entity
310
  html_parts.append(create_single_entity_html(entity, entity_colors))
 
336
  f'title="{label} ({source}) - confidence: {confidence:.2f}">'
337
  f'{entity["text"]}</span>')
338
 
339
+ def create_shared_entity_html(entity, entity_colors):
340
+ """Create HTML for a shared entity with multiple colors"""
341
  labels = entity['labels']
342
  sources = entity['sources']
343
  confidences = entity['confidences']
 
370
  return (f'<span style="background: {gradient}; padding: 2px 4px; '
371
  f'border-radius: 3px; margin: 0 1px; '
372
  f'border: 2px solid #333; color: white; font-weight: bold;" '
373
+ f'title="SHARED: {tooltip}">'
374
  f'{entity["text"]} πŸ”—</span>')
375
 
376
  def create_entity_table_html(entities, entity_colors):
377
+ """Create HTML table with tabbed interface like the original"""
378
  if not entities:
379
  return "<p>No entities found.</p>"
380
 
381
+ # Share overlapping entities
382
+ shared_entities = find_overlapping_entities(entities)
383
 
384
  # Group entities by type
385
  entity_groups = {}
386
+ for entity in shared_entities:
387
+ if entity.get('is_shared', False):
388
+ key = 'SHARED_ENTITIES'
389
  else:
390
  key = entity['label']
391
 
 
393
  entity_groups[key] = []
394
  entity_groups[key].append(entity)
395
 
396
+ if not entity_groups:
397
+ return "<p>No entities found.</p>"
398
+
399
+ # Create tabbed interface
400
+ tab_html = "<div style='margin: 20px 0;'>"
401
+
402
+ # Tab headers
403
+ tab_html += "<div style='border-bottom: 2px solid #ddd; margin-bottom: 20px;'>"
404
+ tab_headers = []
405
+
406
+ for i, entity_type in enumerate(sorted(entity_groups.keys())):
407
+ count = len(entity_groups[entity_type])
408
+
409
+ if entity_type == 'SHARED_ENTITIES':
410
+ color = '#666666'
411
+ icon = "πŸ”—"
412
+ display_name = "SHARED"
413
+ else:
414
+ color = entity_colors.get(entity_type.upper(), '#f0f0f0')
415
+ # Determine if it's standard or custom
416
+ is_standard = entity_type in STANDARD_ENTITIES
417
+ icon = "🎯" if is_standard else "✨"
418
+ display_name = entity_type
419
+
420
+ active_style = f"background-color: #f8f9fa; border-bottom: 3px solid {color};" if i == 0 else "background-color: #fff;"
421
+ tab_headers.append(f"""
422
+ <button onclick="showTab('{entity_type}')" id="tab-{entity_type}"
423
+ style="padding: 12px 24px; margin-right: 5px; border: 1px solid #ddd;
424
+ border-bottom: none; cursor: pointer; font-weight: bold; {active_style}">
425
+ {icon} {display_name} ({count})
426
+ </button>
427
+ """)
428
+
429
+ tab_html += ''.join(tab_headers)
430
+ tab_html += "</div>"
431
+
432
+ # Tab content
433
+ for i, entity_type in enumerate(sorted(entity_groups.keys())):
434
+ entities_of_type = entity_groups[entity_type]
435
+ display_style = "display: block;" if i == 0 else "display: none;"
436
+
437
+ if entity_type == 'SHARED_ENTITIES':
438
  color = '#666666'
439
+ header_text = f"πŸ”— Shared Entities ({len(entities_of_type)} found)"
440
  else:
441
+ color = entity_colors.get(entity_type.upper(), '#f0f0f0')
442
+ source_type = entities_of_type[0].get('source', 'Unknown')
443
+ is_standard = entity_type in STANDARD_ENTITIES
444
+ source_icon = "🎯 Standard NER" if is_standard else "✨ Custom GLiNER"
445
+ header_text = f"{source_icon} - {entity_type} Entities ({len(entities_of_type)} found)"
446
+
447
+ tab_html += f"""
448
+ <div id="content-{entity_type}" style="{display_style}">
449
+ <h4 style="color: {color}; margin-bottom: 15px;">{header_text}</h4>
450
+ <table style="width: 100%; border-collapse: collapse; margin-bottom: 20px;">
451
+ <thead>
 
 
 
 
452
  """
453
 
454
+ if entity_type == 'SHARED_ENTITIES':
455
+ tab_html += f"""
456
+ <tr style="background-color: {color}; color: white;">
457
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Entity Text</th>
458
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">All Labels</th>
459
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Sources</th>
460
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Count</th>
461
+ </tr>
462
+ </thead>
463
+ <tbody>
464
+ """
465
+
466
+ for entity in entities_of_type:
467
  labels_text = " | ".join(entity['labels'])
468
  sources_text = " | ".join(entity['sources'])
469
+
470
+ tab_html += f"""
471
+ <tr style="background-color: #fff;">
472
+ <td style="padding: 10px; border: 1px solid #ddd; font-weight: bold;">{entity['text']}</td>
473
+ <td style="padding: 10px; border: 1px solid #ddd;">{labels_text}</td>
474
+ <td style="padding: 10px; border: 1px solid #ddd;">{sources_text}</td>
475
+ <td style="padding: 10px; border: 1px solid #ddd; text-align: center;">
476
+ <span style='background-color: #28a745; color: white; padding: 2px 6px; border-radius: 10px; font-size: 11px;'>
477
+ {entity['entity_count']}
478
+ </span>
479
+ </td>
480
+ </tr>
481
+ """
482
+ else:
483
+ tab_html += f"""
484
+ <tr style="background-color: {color}; color: white;">
485
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Entity Text</th>
486
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Confidence</th>
487
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Type</th>
488
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Source</th>
489
+ </tr>
490
+ </thead>
491
+ <tbody>
492
  """
493
 
494
+ # Sort by confidence score
495
+ entities_of_type.sort(key=lambda x: x.get('confidence', 0), reverse=True)
496
+
497
+ for entity in entities_of_type:
498
+ confidence = entity.get('confidence', 0.0)
499
+ confidence_color = "#28a745" if confidence > 0.7 else "#ffc107" if confidence > 0.4 else "#dc3545"
500
+ source = entity.get('source', 'Unknown')
501
+ source_badge = f"<span style='background-color: #007bff; color: white; padding: 2px 6px; border-radius: 10px; font-size: 11px;'>{source}</span>"
502
+
503
+ tab_html += f"""
504
+ <tr style="background-color: #fff;">
505
+ <td style="padding: 10px; border: 1px solid #ddd; font-weight: bold;">{entity['text']}</td>
506
+ <td style="padding: 10px; border: 1px solid #ddd;">
507
+ <span style="color: {confidence_color}; font-weight: bold;">
508
+ {confidence:.3f}
509
+ </span>
510
+ </td>
511
+ <td style="padding: 10px; border: 1px solid #ddd;">{entity['label']}</td>
512
+ <td style="padding: 10px; border: 1px solid #ddd;">{source_badge}</td>
513
+ </tr>
514
+ """
515
+
516
+ tab_html += """
517
+ </tbody>
518
+ </table>
519
+ </div>
520
+ """
521
+
522
+ # JavaScript for tab switching
523
+ tab_html += """
524
+ <script>
525
+ function showTab(entityType) {
526
+ // Hide all content
527
+ var contents = document.querySelectorAll('[id^="content-"]');
528
+ contents.forEach(function(content) {
529
+ content.style.display = 'none';
530
+ });
531
+
532
+ // Reset all tab styles
533
+ var tabs = document.querySelectorAll('[id^="tab-"]');
534
+ tabs.forEach(function(tab) {
535
+ tab.style.backgroundColor = '#fff';
536
+ tab.style.borderBottom = 'none';
537
+ });
538
+
539
+ // Show selected content
540
+ document.getElementById('content-' + entityType).style.display = 'block';
541
+
542
+ // Highlight selected tab
543
+ var activeTab = document.getElementById('tab-' + entityType);
544
+ activeTab.style.backgroundColor = '#f8f9fa';
545
+ activeTab.style.borderBottom = '3px solid #4ECDC4';
546
+ }
547
+ </script>
548
+ """
549
 
550
+ tab_html += "</div>"
551
+ return tab_html
552
 
553
  def create_legend_html(entity_colors, standard_entities, custom_entities):
554
  """Create a legend showing entity colors"""
 
582
  # Initialize the NER manager
583
  ner_manager = HybridNERManager()
584
 
585
+ def process_text(text, standard_entities, custom_entities_str, confidence_threshold, selected_model):
586
  """Main processing function for Gradio interface"""
587
  if not text.strip():
588
  return "❌ Please enter some text to analyze", "", ""
 
600
 
601
  all_entities = []
602
 
603
+ # Extract standard entities using selected model
604
+ if selected_standard and selected_model:
605
+ standard_entities_results = ner_manager.extract_entities_by_model(text, selected_standard, selected_model, confidence_threshold)
606
+ all_entities.extend(standard_entities_results)
 
 
 
 
 
607
 
608
+ # Extract custom entities using GLiNER
609
  if custom_entities:
610
  custom_entity_results = ner_manager.extract_gliner_entities(text, custom_entities, confidence_threshold, is_custom=True)
611
  all_entities.extend(custom_entity_results)
 
621
  highlighted_html = create_highlighted_html(text, all_entities, entity_colors)
622
  table_html = create_entity_table_html(all_entities, entity_colors)
623
 
624
+ # Create summary with shared entities terminology
625
  total_entities = len(all_entities)
626
+ shared_entities = find_overlapping_entities(all_entities)
627
+ final_count = len(shared_entities)
628
+ shared_count = sum(1 for e in shared_entities if e.get('is_shared', False))
629
 
630
  summary = f"""
631
  ## πŸ“Š Analysis Summary
632
  - **Total entities found:** {total_entities}
633
  - **Final entities displayed:** {final_count}
634
+ - **Shared entities:** {shared_count}
635
  - **Average confidence:** {sum(e.get('confidence', 0) for e in all_entities) / total_entities:.3f}
636
  """
637
 
 
645
 
646
  Combine standard NER categories with your own custom entity types! This tool uses both traditional NER models and GLiNER for comprehensive entity extraction.
647
 
648
+ ## πŸ”— NEW: Overlapping entities are automatically shared with split-color highlighting!
649
 
650
  ### How to use:
651
  1. **πŸ“ Enter your text** in the text area below
652
+ 2. **🎯 Select a model** from the dropdown for standard entities
653
+ 3. **β˜‘οΈ Select standard entities** you want to find (PER, ORG, LOC, etc.)
654
+ 4. **✨ Add custom entities** (comma-separated) like "relationships, occupations, skills"
655
+ 5. **βš™οΈ Adjust confidence threshold**
656
+ 6. **πŸ” Click "Analyze Text"** to see results with tabbed output
657
  """)
658
 
659
  with gr.Row():
 
677
  with gr.Row():
678
  with gr.Column():
679
  gr.Markdown("### 🎯 Standard Entity Types")
680
+
681
+ # Model selector
682
+ model_dropdown = gr.Dropdown(
683
+ choices=ner_manager.model_names,
684
+ value=ner_manager.model_names[0],
685
+ label="Select Model for Standard Entities",
686
+ info="Choose which model to use for standard NER"
687
+ )
688
+
689
+ # Standard entities with select all functionality
690
  standard_entities = gr.CheckboxGroup(
691
  choices=STANDARD_ENTITIES,
692
  value=['PER', 'ORG', 'LOC', 'MISC'], # Default selection
693
  label="Select Standard Entities"
694
  )
695
 
696
+ # Select/Deselect All button
697
  with gr.Row():
698
+ select_all_btn = gr.Button("πŸ”˜ Deselect All", size="sm")
699
+
700
+ # Function for select/deselect all
701
+ def toggle_all_entities(current_selection):
702
+ if len(current_selection) > 0:
703
+ # If any are selected, deselect all
704
+ return [], "β˜‘οΈ Select All"
705
+ else:
706
+ # If none selected, select all
707
+ return STANDARD_ENTITIES, "πŸ”˜ Deselect All"
708
+
709
+ select_all_btn.click(
710
+ fn=toggle_all_entities,
711
+ inputs=[standard_entities],
712
+ outputs=[standard_entities, select_all_btn]
713
+ )
714
 
715
  with gr.Column():
716
  gr.Markdown("### ✨ Custom Entity Types")
 
722
  gr.Markdown("""
723
  **Examples:**
724
  - relationships, occupations, skills
725
+ - emotions, actions, objects
726
  - medical conditions, treatments
727
+ - financial terms, business roles
728
  """)
729
 
730
  analyze_btn = gr.Button("πŸ” Analyze Text", variant="primary", size="lg")
 
737
  highlighted_output = gr.HTML(label="Highlighted Text")
738
 
739
  with gr.Row():
740
+ table_output = gr.HTML(label="Detailed Results (Tabbed)")
741
 
742
  # Connect the button to the processing function
743
  analyze_btn.click(
 
747
  standard_entities,
748
  custom_entities,
749
  confidence_threshold,
750
+ model_dropdown
 
751
  ],
752
  outputs=[summary_output, highlighted_output, table_output]
753
  )
 
760
  ["PER", "ORG", "LOC", "DATE"],
761
  "relationships, occupations, educational background",
762
  0.3,
763
+ "spacy_en_core_web_sm"
 
764
  ],
765
  [
766
  "The meeting between CEO Jane Doe and the board of directors at Microsoft headquarters in Seattle discussed the Q4 financial results and the new AI strategy for 2024.",
767
  ["PER", "ORG", "LOC", "DATE"],
768
  "corporate roles, business events, financial terms",
769
  0.4,
770
+ "flair_ner-ontonotes-large"
771
+ ],
772
+ [
773
+ "Dr. Emily Watson published a research paper on machine learning algorithms at MIT. She collaborates with her colleague Prof. David Chen on natural language processing projects.",
774
+ ["PER", "ORG", "WORK_OF_ART"],
775
+ "academic titles, research topics, collaborations",
776
+ 0.3,
777
+ "gliner_medium-v2.1"
778
  ]
779
  ],
780
  inputs=[
 
782
  standard_entities,
783
  custom_entities,
784
  confidence_threshold,
785
+ model_dropdown
 
786
  ]
787
  )
788