SorrelC commited on
Commit
82abc14
·
verified ·
1 Parent(s): 62ce8f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -790
app.py CHANGED
@@ -1,574 +1,14 @@
1
- import gradio as gr
2
- import torch
3
- from gliner import GLiNER
4
- import pandas as pd
5
- import warnings
6
- import random
7
- import re
8
- import time
9
- warnings.filterwarnings('ignore')
10
 
11
- # Common NER entity types
12
- STANDARD_ENTITIES = [
13
- 'DATE', 'EVENT', 'FAC', 'GPE', 'LANG', 'LOC',
14
- 'MISC', 'NORP', 'ORG', 'PER', 'PRODUCT', 'Work of Art'
15
- ]
16
-
17
- # Colour schemes
18
- STANDARD_COLORS = {
19
- 'DATE': '#FF6B6B', # Red
20
- 'EVENT': '#4ECDC4', # Teal
21
- 'FAC': '#45B7D1', # Blue
22
- 'GPE': '#F9CA24', # Yellow
23
- 'LANG': '#6C5CE7', # Purple
24
- 'LOC': '#A0E7E5', # Light Cyan
25
- 'MISC': '#FD79A8', # Pink
26
- 'NORP': '#8E8E93', # Grey
27
- 'ORG': '#55A3FF', # Light Blue
28
- 'PER': '#00B894', # Green
29
- 'PRODUCT': '#E17055', # Orange-Red
30
- 'WORK OF ART': '#DDA0DD' # Plum
31
- }
32
-
33
- # Additional colours for custom entities
34
- CUSTOM_COLOR_PALETTE = [
35
- '#FF9F43', '#10AC84', '#EE5A24', '#0FBC89', '#5F27CD',
36
- '#FF3838', '#2F3640', '#3742FA', '#2ED573', '#FFA502',
37
- '#FF6348', '#1E90FF', '#FF1493', '#32CD32', '#FFD700',
38
- '#FF4500', '#DA70D6', '#00CED1', '#FF69B4', '#7B68EE'
39
- ]
40
-
41
- class HybridNERManager:
42
- def __init__(self):
43
- self.gliner_model = None
44
- self.spacy_model = None
45
- self.flair_models = {}
46
- self.all_entity_colors = {}
47
- self.model_names = [
48
- 'entities_flair_ner-large',
49
- 'entities_spacy_en_core_web_trf',
50
- 'entities_flair_ner-ontonotes-large',
51
- 'entities_gliner_knowledgator/modern-gliner-bi-large-v1.0'
52
- ]
53
-
54
- def load_model(self, model_name):
55
- """Load the specified model"""
56
- try:
57
- if 'spacy' in model_name:
58
- return self.load_spacy_model()
59
- elif 'flair' in model_name:
60
- return self.load_flair_model(model_name)
61
- elif 'gliner' in model_name:
62
- return self.load_gliner_model()
63
- except Exception as e:
64
- print(f"Error loading {model_name}: {str(e)}")
65
- return None
66
-
67
- def load_spacy_model(self):
68
- """Load spaCy model for common NER"""
69
- if self.spacy_model is None:
70
- try:
71
- import spacy
72
- try:
73
- # Try transformer model first, fallback to small model
74
- self.spacy_model = spacy.load("en_core_web_trf")
75
- print("✓ spaCy transformer model loaded successfully")
76
- except OSError:
77
- try:
78
- self.spacy_model = spacy.load("en_core_web_sm")
79
- print("✓ spaCy common model loaded successfully")
80
- except OSError:
81
- print("spaCy model not found. Using GLiNER for all entity types.")
82
- return None
83
- except Exception as e:
84
- print(f"Error loading spaCy model: {str(e)}")
85
- return None
86
- return self.spacy_model
87
-
88
- def load_flair_model(self, model_name):
89
- """Load Flair models"""
90
- if model_name not in self.flair_models:
91
- try:
92
- from flair.models import SequenceTagger
93
- if 'ontonotes' in model_name:
94
- model = SequenceTagger.load("flair/ner-english-ontonotes-large")
95
- print("✓ Flair OntoNotes model loaded successfully")
96
- else:
97
- model = SequenceTagger.load("flair/ner-english-large")
98
- print("✓ Flair large model loaded successfully")
99
- self.flair_models[model_name] = model
100
- except Exception as e:
101
- print(f"Error loading {model_name}: {str(e)}")
102
- # Fallback to GLiNER
103
- return self.load_gliner_model()
104
- return self.flair_models[model_name]
105
-
106
- def load_gliner_model(self):
107
- """Load GLiNER model for custom entities"""
108
- if self.gliner_model is None:
109
- try:
110
- # Try the modern GLiNER model first, fallback to stable model
111
- self.gliner_model = GLiNER.from_pretrained("knowledgator/gliner-bi-large-v1.0")
112
- print("✓ GLiNER knowledgator model loaded successfully")
113
- except Exception as e:
114
- print(f"Primary GLiNER model failed: {str(e)}")
115
- try:
116
- # Fallback to stable model
117
- self.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
118
- print("✓ GLiNER fallback model loaded successfully")
119
- except Exception as e2:
120
- print(f"Error loading GLiNER model: {str(e2)}")
121
- return None
122
- return self.gliner_model
123
-
124
- def assign_colours(self, standard_entities, custom_entities):
125
- """Assign colours to all entity types"""
126
- self.all_entity_colors = {}
127
-
128
- # Assign common colours
129
- for entity in standard_entities:
130
- # Handle the special case of "Work of Art"
131
- colour_key = "WORK OF ART" if entity == "Work of Art" else entity.upper()
132
- self.all_entity_colors[entity.upper()] = STANDARD_COLORS.get(colour_key, '#CCCCCC')
133
-
134
- # Assign custom colours
135
- for i, entity in enumerate(custom_entities):
136
- if i < len(CUSTOM_COLOR_PALETTE):
137
- self.all_entity_colors[entity.upper()] = CUSTOM_COLOR_PALETTE[i]
138
- else:
139
- # Generate random colour if we run out
140
- self.all_entity_colors[entity.upper()] = f"#{random.randint(0, 0xFFFFFF):06x}"
141
-
142
- return self.all_entity_colors
143
-
144
- def extract_entities_by_model(self, text, entity_types, model_name, threshold=0.3):
145
- """Extract entities using the specified model"""
146
- if 'spacy' in model_name:
147
- return self.extract_spacy_entities(text, entity_types)
148
- elif 'flair' in model_name:
149
- return self.extract_flair_entities(text, entity_types, model_name)
150
- elif 'gliner' in model_name:
151
- return self.extract_gliner_entities(text, entity_types, threshold, is_custom=False)
152
- else:
153
- return []
154
-
155
- def extract_spacy_entities(self, text, entity_types):
156
- """Extract entities using spaCy"""
157
- model = self.load_spacy_model()
158
- if model is None:
159
- return []
160
-
161
- try:
162
- doc = model(text)
163
- entities = []
164
- for ent in doc.ents:
165
- if ent.label_ in entity_types:
166
- entities.append({
167
- 'text': ent.text,
168
- 'label': ent.label_,
169
- 'start': ent.start_char,
170
- 'end': ent.end_char,
171
- 'confidence': 1.0, # spaCy doesn't provide confidence scores
172
- 'source': 'spaCy'
173
- })
174
- return entities
175
- except Exception as e:
176
- print(f"Error with spaCy extraction: {str(e)}")
177
- return []
178
-
179
- def extract_flair_entities(self, text, entity_types, model_name):
180
- """Extract entities using Flair"""
181
- model = self.load_flair_model(model_name)
182
- if model is None:
183
- return []
184
-
185
- try:
186
- from flair.data import Sentence
187
- sentence = Sentence(text)
188
- model.predict(sentence)
189
- entities = []
190
- for entity in sentence.get_spans('ner'):
191
- # Map Flair labels to our common set
192
- label = entity.tag
193
- if label == 'PERSON':
194
- label = 'PER'
195
- elif label == 'ORGANIZATION':
196
- label = 'ORG'
197
- elif label == 'LOCATION':
198
- label = 'LOC'
199
- elif label == 'MISCELLANEOUS':
200
- label = 'MISC'
201
-
202
- if label in entity_types:
203
- entities.append({
204
- 'text': entity.text,
205
- 'label': label,
206
- 'start': entity.start_position,
207
- 'end': entity.end_position,
208
- 'confidence': entity.score,
209
- 'source': f'Flair-{model_name.split("-")[-1]}'
210
- })
211
- return entities
212
- except Exception as e:
213
- print(f"Error with Flair extraction: {str(e)}")
214
- return []
215
-
216
- def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True):
217
- """Extract entities using GLiNER"""
218
- model = self.load_gliner_model()
219
- if model is None:
220
- return []
221
-
222
- try:
223
- entities = model.predict_entities(text, entity_types, threshold=threshold)
224
- result = []
225
- for entity in entities:
226
- result.append({
227
- 'text': entity['text'],
228
- 'label': entity['label'].upper(),
229
- 'start': entity['start'],
230
- 'end': entity['end'],
231
- 'confidence': entity.get('score', 0.0),
232
- 'source': 'GLiNER-Custom' if is_custom else 'GLiNER-Common'
233
- })
234
- return result
235
- except Exception as e:
236
- print(f"Error with GLiNER extraction: {str(e)}")
237
- return []
238
-
239
- def find_overlapping_entities(entities):
240
- """Find and share overlapping entities - specifically entities found by BOTH common NER models AND custom entities"""
241
  if not entities:
242
- return []
243
-
244
- # Sort entities by start position
245
- sorted_entities = sorted(entities, key=lambda x: x['start'])
246
- shared_entities = []
247
-
248
- i = 0
249
- while i < len(sorted_entities):
250
- current_entity = sorted_entities[i]
251
- overlapping_entities = [current_entity]
252
-
253
- # Find all entities that overlap with current entity
254
- j = i + 1
255
- while j < len(sorted_entities):
256
- next_entity = sorted_entities[j]
257
-
258
- # Check if entities overlap (same text span or overlapping positions)
259
- if (current_entity['start'] <= next_entity['start'] < current_entity['end'] or
260
- next_entity['start'] <= current_entity['start'] < current_entity['end'] or
261
- current_entity['text'].lower() == next_entity['text'].lower()):
262
- overlapping_entities.append(next_entity)
263
- sorted_entities.pop(j)
264
- else:
265
- j += 1
266
-
267
- # Create shared entity only if we have BOTH common and custom entities
268
- if len(overlapping_entities) == 1:
269
- shared_entities.append(overlapping_entities[0])
270
- else:
271
- # Check if this is a true "shared" entity (common + custom)
272
- has_common = False
273
- has_custom = False
274
-
275
- for entity in overlapping_entities:
276
- source = entity.get('source', '')
277
- if source in ['spaCy', 'GLiNER-Common'] or source.startswith('Flair-'):
278
- has_common = True
279
- elif source == 'GLiNER-Custom':
280
- has_custom = True
281
-
282
- if has_common and has_custom:
283
- # This is a true shared entity (common + custom)
284
- shared_entity = share_entities(overlapping_entities)
285
- shared_entities.append(shared_entity)
286
- else:
287
- # These are just overlapping entities from the same source type, keep separate
288
- shared_entities.extend(overlapping_entities)
289
-
290
- i += 1
291
 
292
- return shared_entities
293
-
294
- def share_entities(entity_list):
295
- """Share multiple overlapping entities into one"""
296
- if len(entity_list) == 1:
297
- return entity_list[0]
298
-
299
- # Use the entity with the longest text span as the base
300
- base_entity = max(entity_list, key=lambda x: len(x['text']))
301
-
302
- # Collect all labels and sources
303
- labels = [entity['label'] for entity in entity_list]
304
- sources = [entity['source'] for entity in entity_list]
305
- confidences = [entity['confidence'] for entity in entity_list]
306
-
307
- return {
308
- 'text': base_entity['text'],
309
- 'start': base_entity['start'],
310
- 'end': base_entity['end'],
311
- 'labels': labels,
312
- 'sources': sources,
313
- 'confidences': confidences,
314
- 'is_shared': True,
315
- 'entity_count': len(entity_list)
316
- }
317
-
318
- def create_highlighted_html(text, entities, entity_colors):
319
- """Create HTML with highlighted entities"""
320
- if not entities:
321
- return f"<div style='padding: 15px; border: 1px solid #ddd; border-radius: 5px; background-color: #fafafa;'><p>{text}</p></div>"
322
-
323
- # Find and share overlapping entities
324
  shared_entities = find_overlapping_entities(entities)
325
 
326
- # Sort by start position
327
- sorted_entities = sorted(shared_entities, key=lambda x: x['start'])
328
-
329
- # Create HTML with highlighting
330
- html_parts = []
331
- last_end = 0
332
-
333
- for entity in sorted_entities:
334
- # Add text before entity
335
- html_parts.append(text[last_end:entity['start']])
336
-
337
- if entity.get('is_shared', False):
338
- # Handle shared entity with multiple colours
339
- html_parts.append(create_shared_entity_html(entity, entity_colors))
340
- else:
341
- # Handle single entity
342
- html_parts.append(create_single_entity_html(entity, entity_colors))
343
-
344
- last_end = entity['end']
345
-
346
- # Add remaining text
347
- html_parts.append(text[last_end:])
348
-
349
- highlighted_text = ''.join(html_parts)
350
-
351
- return f"""
352
- <div style='padding: 15px; border: 2px solid #ddd; border-radius: 8px; background-color: #fafafa; margin: 10px 0;'>
353
- <h4 style='margin: 0 0 15px 0; color: #333;'>📝 Text with Highlighted Entities</h4>
354
- <div style='line-height: 1.8; font-size: 16px; background-color: white; padding: 15px; border-radius: 5px;'>{highlighted_text}</div>
355
- </div>
356
- """
357
-
358
- def create_single_entity_html(entity, entity_colors):
359
- """Create HTML for a single entity"""
360
- label = entity['label']
361
- colour = entity_colors.get(label.upper(), '#CCCCCC')
362
- confidence = entity.get('confidence', 0.0)
363
- source = entity.get('source', 'Unknown')
364
-
365
- return (f'<span style="background-color: {colour}; padding: 2px 4px; '
366
- f'border-radius: 3px; margin: 0 1px; '
367
- f'border: 1px solid {colour}; color: white; font-weight: bold;" '
368
- f'title="{label} ({source}) - confidence: {confidence:.2f}">'
369
- f'{entity["text"]}</span>')
370
-
371
- def create_shared_entity_html(entity, entity_colors):
372
- """Create HTML for a shared entity with multiple colours"""
373
- labels = entity['labels']
374
- sources = entity['sources']
375
- confidences = entity['confidences']
376
-
377
- # Get colours for each label
378
- colours = []
379
- for label in labels:
380
- colour = entity_colors.get(label.upper(), '#CCCCCC')
381
- colours.append(colour)
382
-
383
- # Create gradient background
384
- if len(colours) == 2:
385
- gradient = f"linear-gradient(to right, {colours[0]} 50%, {colours[1]} 50%)"
386
- else:
387
- # For more colours, create equal segments
388
- segment_size = 100 / len(colours)
389
- gradient_parts = []
390
- for i, colour in enumerate(colours):
391
- start = i * segment_size
392
- end = (i + 1) * segment_size
393
- gradient_parts.append(f"{colour} {start}%, {colour} {end}%")
394
- gradient = f"linear-gradient(to right, {', '.join(gradient_parts)})"
395
-
396
- # Create tooltip
397
- tooltip_parts = []
398
- for i, label in enumerate(labels):
399
- tooltip_parts.append(f"{label} ({sources[i]}) - {confidences[i]:.2f}")
400
- tooltip = " | ".join(tooltip_parts)
401
-
402
- return (f'<span style="background: {gradient}; padding: 2px 4px; '
403
- f'border-radius: 3px; margin: 0 1px; '
404
- f'border: 2px solid #333; color: white; font-weight: bold;" '
405
- f'title="SHARED: {tooltip}">'
406
- f'{entity["text"]} 🤝</span>')
407
-
408
- def create_entity_table_html(entities_of_type, entity_type, colour, is_shared=False):
409
- """Create HTML table for a specific entity type"""
410
- if is_shared:
411
- header = f"🤝 Shared Entities ({len(entities_of_type)} found)"
412
-
413
- table_html = f"""
414
- <div style="margin: 15px 0;">
415
- <h4 style="color: {colour}; margin-bottom: 15px;">{header}</h4>
416
- <table style="width: 100%; border-collapse: collapse; border: 1px solid #ddd;">
417
- <thead>
418
- <tr style="background-color: {colour}; color: white;">
419
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Entity Text</th>
420
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">All Labels</th>
421
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Sources</th>
422
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Count</th>
423
- </tr>
424
- </thead>
425
- <tbody>
426
- """
427
-
428
- for entity in entities_of_type:
429
- labels_text = " | ".join(entity['labels'])
430
- sources_text = " | ".join(entity['sources'])
431
-
432
- table_html += f"""
433
- <tr style="background-color: #fff;">
434
- <td style="padding: 10px; border: 1px solid #ddd; font-weight: bold;">{entity['text']}</td>
435
- <td style="padding: 10px; border: 1px solid #ddd;">{labels_text}</td>
436
- <td style="padding: 10px; border: 1px solid #ddd;">{sources_text}</td>
437
- <td style="padding: 10px; border: 1px solid #ddd; text-align: center;">
438
- <span style='background-color: #28a745; color: white; padding: 2px 6px; border-radius: 10px; font-size: 11px;'>
439
- {entity['entity_count']}
440
- </span>
441
- </td>
442
- </tr>
443
- """
444
- else:
445
- # Determine if it's common or custom
446
- is_standard = entity_type in STANDARD_ENTITIES
447
- icon = "🎯" if is_standard else "✨"
448
- source_text = "Common NER" if is_standard else "Custom GLiNER"
449
- header = f"{icon} {source_text} - {entity_type} ({len(entities_of_type)} found)"
450
-
451
- table_html = f"""
452
- <div style="margin: 15px 0;">
453
- <h4 style="color: {colour}; margin-bottom: 15px;">{header}</h4>
454
- <table style="width: 100%; border-collapse: collapse; border: 1px solid #ddd;">
455
- <thead>
456
- <tr style="background-color: {colour}; color: white;">
457
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Entity Text</th>
458
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Confidence</th>
459
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Type</th>
460
- <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Source</th>
461
- </tr>
462
- </thead>
463
- <tbody>
464
- """
465
-
466
- # Sort by confidence score
467
- entities_of_type.sort(key=lambda x: x.get('confidence', 0), reverse=True)
468
-
469
- for entity in entities_of_type:
470
- confidence = entity.get('confidence', 0.0)
471
- confidence_colour = "#28a745" if confidence > 0.7 else "#ffc107" if confidence > 0.4 else "#dc3545"
472
- source = entity.get('source', 'Unknown')
473
- source_badge = f"<span style='background-color: #007bff; color: white; padding: 2px 6px; border-radius: 10px; font-size: 11px;'>{source}</span>"
474
-
475
- table_html += f"""
476
- <tr style="background-color: #fff;">
477
- <td style="padding: 10px; border: 1px solid #ddd; font-weight: bold;">{entity['text']}</td>
478
- <td style="padding: 10px; border: 1px solid #ddd;">
479
- <span style="color: {confidence_colour}; font-weight: bold;">
480
- {confidence:.3f}
481
- </span>
482
- </td>
483
- <td style="padding: 10px; border: 1px solid #ddd;">{entity['label']}</td>
484
- <td style="padding: 10px; border: 1px solid #ddd;">{source_badge}</td>
485
- </tr>
486
- """
487
-
488
- table_html += "</tbody></table></div>"
489
- return table_html
490
-
491
- def create_legend_html(entity_colors, standard_entities, custom_entities):
492
- """Create a legend showing entity colours"""
493
- if not entity_colors:
494
- return ""
495
-
496
- html = "<div style='margin: 15px 0; padding: 15px; background-color: #f8f9fa; border-radius: 8px;'>"
497
- html += "<h4 style='margin: 0 0 15px 0;'>🎨 Entity Type Legend</h4>"
498
-
499
- if standard_entities:
500
- html += "<div style='margin-bottom: 15px;'>"
501
- html += "<h5 style='margin: 0 0 8px 0;'>🎯 Common Entities:</h5>"
502
- html += "<div style='display: flex; flex-wrap: wrap; gap: 8px;'>"
503
- for entity_type in standard_entities:
504
- colour = entity_colors.get(entity_type.upper(), '#ccc')
505
- html += f"<span style='background-color: {colour}; padding: 4px 8px; border-radius: 15px; color: white; font-weight: bold; font-size: 12px;'>{entity_type}</span>"
506
- html += "</div></div>"
507
-
508
- if custom_entities:
509
- html += "<div>"
510
- html += "<h5 style='margin: 0 0 8px 0;'>✨ Custom Entities:</h5>"
511
- html += "<div style='display: flex; flex-wrap: wrap; gap: 8px;'>"
512
- for entity_type in custom_entities:
513
- colour = entity_colors.get(entity_type.upper(), '#ccc')
514
- html += f"<span style='background-color: {colour}; padding: 4px 8px; border-radius: 15px; color: white; font-weight: bold; font-size: 12px;'>{entity_type}</span>"
515
- html += "</div></div>"
516
-
517
- html += "</div>"
518
- return html
519
-
520
- # Initialize the NER manager
521
- ner_manager = HybridNERManager()
522
-
523
- def process_text(text, standard_entities, custom_entities_str, confidence_threshold, selected_model, progress=gr.Progress()):
524
- """Main processing function for Gradio interface with progress tracking"""
525
- if not text.strip():
526
- return "❌ Please enter some text to analyse", "", None, None, None, None, None, None, None, None
527
-
528
- progress(0.1, desc="Initialising...")
529
-
530
- # Parse custom entities
531
- custom_entities = []
532
- if custom_entities_str.strip():
533
- custom_entities = [entity.strip() for entity in custom_entities_str.split(',') if entity.strip()]
534
-
535
- # Parse common entities
536
- selected_standard = [entity for entity in standard_entities if entity]
537
-
538
- if not selected_standard and not custom_entities:
539
- return "❌ Please select at least one common entity type OR enter custom entity types", "", None, None, None, None, None, None, None, None
540
-
541
- progress(0.2, desc="Loading models...")
542
-
543
- all_entities = []
544
-
545
- # Extract common entities using selected model
546
- if selected_standard and selected_model:
547
- progress(0.4, desc="Extracting common entities...")
548
- standard_entities_results = ner_manager.extract_entities_by_model(text, selected_standard, selected_model, confidence_threshold)
549
- all_entities.extend(standard_entities_results)
550
-
551
- # Extract custom entities using GLiNER
552
- if custom_entities:
553
- progress(0.6, desc="Extracting custom entities...")
554
- custom_entity_results = ner_manager.extract_gliner_entities(text, custom_entities, confidence_threshold, is_custom=True)
555
- all_entities.extend(custom_entity_results)
556
-
557
- if not all_entities:
558
- return "❌ No entities found. Try lowering the confidence threshold or using different entity types.", "", None, None, None, None, None, None, None, None
559
-
560
- progress(0.8, desc="Processing results...")
561
-
562
- # Assign colours
563
- entity_colors = ner_manager.assign_colours(selected_standard, custom_entities)
564
-
565
- # Create outputs
566
- legend_html = create_legend_html(entity_colors, selected_standard, custom_entities)
567
- highlighted_html = create_highlighted_html(text, all_entities, entity_colors)
568
-
569
- # Share overlapping entities
570
- shared_entities = find_overlapping_entities(all_entities)
571
-
572
  # Group entities by type
573
  entity_groups = {}
574
  for entity in shared_entities:
@@ -581,238 +21,134 @@ def process_text(text, standard_entities, custom_entities_str, confidence_thresh
581
  entity_groups[key] = []
582
  entity_groups[key].append(entity)
583
 
584
- progress(0.9, desc="Creating summary...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
- # Create summary
587
- total_entities = len(all_entities)
588
- final_count = len(shared_entities)
589
- shared_count = sum(1 for e in shared_entities if e.get('is_shared', False))
590
-
591
- summary = f"""
592
- ## 📊 Analysis Summary
593
- - **Total entities found:** {total_entities}
594
- - **Final entities displayed:** {final_count}
595
- - **Shared entities:** {shared_count}
596
- - **Average confidence:** {sum(e.get('confidence', 0) for e in all_entities) / total_entities:.3f}
597
  """
598
-
599
- progress(1.0, desc="Complete!")
600
-
601
- # Create HTML tables for each entity type
602
- tables = {}
603
 
604
- # Create table for shared entities if any
605
  if 'SHARED_ENTITIES' in entity_groups:
606
- tables['shared'] = create_entity_table_html(entity_groups['SHARED_ENTITIES'], 'SHARED_ENTITIES', '#666666', is_shared=True)
607
- else:
608
- tables['shared'] = None
609
-
610
- # Create tables for other entity types (up to 7 for the interface)
611
- entity_types = [k for k in entity_groups.keys() if k != 'SHARED_ENTITIES']
612
- for i in range(7):
613
- if i < len(entity_types):
614
- entity_type = entity_types[i]
615
- colour = entity_colors.get(entity_type.upper(), '#f0f0f0')
616
- tables[f'tab{i+1}'] = create_entity_table_html(entity_groups[entity_type], entity_type, colour)
617
- else:
618
- tables[f'tab{i+1}'] = None
 
 
 
619
 
620
- return (summary, legend_html + highlighted_html,
621
- tables.get('shared'), tables.get('tab1'), tables.get('tab2'),
622
- tables.get('tab3'), tables.get('tab4'), tables.get('tab5'),
623
- tables.get('tab6'), tables.get('tab7'))
624
-
625
- # Create Gradio interface
626
- def create_interface():
627
- with gr.Blocks(title="Hybrid NER + GLiNER Tool", theme=gr.themes.Soft()) as demo:
628
- gr.Markdown("""
629
- # 🎯 Hybrid NER + Custom GLiNER Entity Recognition Tool
630
-
631
- Combine common NER categories with your own custom entity types! This tool uses both traditional NER models and GLiNER for comprehensive entity extraction.
632
-
633
- ## 🤝 NEW: Overlapping entities are automatically shared with split-colour highlighting!
634
-
635
- ### How to use:
636
- 1. **📝 Enter your text** in the text area below
637
- 2. **🎯 Select a model** from the dropdown for common entities
638
- 3. **☑️ Select common entities** you want to find (PER, ORG, LOC, etc.)
639
- 4. **✨ Add custom entities** (comma-separated) like "relationships, occupations, skills"
640
- 5. **⚙️ Adjust confidence threshold**
641
- 6. **🔍 Click "Analyse Text"** to see results with tabbed output
642
- """)
643
-
644
- with gr.Row():
645
- with gr.Column(scale=2):
646
- text_input = gr.Textbox(
647
- label="📝 Text to Analyse",
648
- placeholder="Enter your text here...",
649
- lines=6,
650
- max_lines=10
651
- )
652
 
653
- with gr.Column(scale=1):
654
- confidence_threshold = gr.Slider(
655
- minimum=0.1,
656
- maximum=0.9,
657
- value=0.3,
658
- step=0.1,
659
- label="🎚️ Confidence Threshold"
660
- )
661
-
662
- with gr.Row():
663
- with gr.Column():
664
- gr.Markdown("### 🎯 Common Entity Types")
665
-
666
- # Model selector
667
- model_dropdown = gr.Dropdown(
668
- choices=ner_manager.model_names,
669
- value=ner_manager.model_names[0],
670
- label="Select Model for Common Entities",
671
- info="Choose which model to use for common NER"
672
- )
673
-
674
- # Common entities with select all functionality
675
- standard_entities = gr.CheckboxGroup(
676
- choices=STANDARD_ENTITIES,
677
- value=['PER', 'ORG', 'LOC', 'MISC'], # Default selection
678
- label="Select Common Entities"
679
- )
680
-
681
- # Select/Deselect All button
682
- with gr.Row():
683
- select_all_btn = gr.Button("🔘 Deselect All", size="sm")
684
-
685
- # Function for select/deselect all
686
- def toggle_all_entities(current_selection):
687
- if len(current_selection) > 0:
688
- # If any are selected, deselect all
689
- return [], "☑️ Select All"
690
- else:
691
- # If none selected, select all
692
- return STANDARD_ENTITIES, "🔘 Deselect All"
693
-
694
- select_all_btn.click(
695
- fn=toggle_all_entities,
696
- inputs=[standard_entities],
697
- outputs=[standard_entities, select_all_btn]
698
- )
699
-
700
- with gr.Column():
701
- gr.Markdown("### ✨ Custom Entity Types")
702
- custom_entities = gr.Textbox(
703
- label="Custom Entities (comma-separated)",
704
- placeholder="e.g. relationships, occupations, skills, emotions",
705
- lines=3
706
- )
707
- gr.Markdown("""
708
- **Examples:**
709
- - relationships, occupations, skills
710
- - emotions, actions, objects
711
- - medical conditions, treatments
712
- - financial terms, business roles
713
- """)
714
-
715
- analyse_btn = gr.Button("🔍 Analyse Text", variant="primary", size="lg")
716
-
717
- # Output sections
718
- with gr.Row():
719
- summary_output = gr.Markdown(label="Summary")
720
 
721
- with gr.Row():
722
- highlighted_output = gr.HTML(label="Highlighted Text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
723
 
724
- # Results section with native Gradio tabs
725
- with gr.Row():
726
- with gr.Column():
727
- gr.Markdown("### 📋 Detailed Results")
728
-
729
- with gr.Tabs() as results_tabs:
730
- # Pre-create tabs for different entity types
731
- with gr.Tab("🤝 Shared", visible=False) as shared_tab:
732
- shared_output = gr.HTML()
733
-
734
- with gr.Tab("Entity Type 1", visible=False) as tab1:
735
- tab1_output = gr.HTML()
736
-
737
- with gr.Tab("Entity Type 2", visible=False) as tab2:
738
- tab2_output = gr.HTML()
739
-
740
- with gr.Tab("Entity Type 3", visible=False) as tab3:
741
- tab3_output = gr.HTML()
742
-
743
- with gr.Tab("Entity Type 4", visible=False) as tab4:
744
- tab4_output = gr.HTML()
745
-
746
- with gr.Tab("Entity Type 5", visible=False) as tab5:
747
- tab5_output = gr.HTML()
748
-
749
- with gr.Tab("Entity Type 6", visible=False) as tab6:
750
- tab6_output = gr.HTML()
751
-
752
- with gr.Tab("Entity Type 7", visible=False) as tab7:
753
- tab7_output = gr.HTML()
754
 
755
- # Connect the button to the processing function
756
- analyse_btn.click(
757
- fn=process_text,
758
- inputs=[
759
- text_input,
760
- standard_entities,
761
- custom_entities,
762
- confidence_threshold,
763
- model_dropdown
764
- ],
765
- outputs=[
766
- summary_output,
767
- highlighted_output,
768
- shared_output,
769
- tab1_output,
770
- tab2_output,
771
- tab3_output,
772
- tab4_output,
773
- tab5_output,
774
- tab6_output,
775
- tab7_output
776
- ]
777
- )
778
 
779
- # Add examples
780
- gr.Examples(
781
- examples=[
782
- [
783
- "John Smith works at Google in New York. He graduated from Stanford University in 2015 and specialises in artificial intelligence research. His wife Sarah is a doctor at Mount Sinai Hospital.",
784
- ["PER", "ORG", "LOC", "DATE"],
785
- "relationships, occupations, educational background",
786
- 0.3,
787
- "entities_spacy_en_core_web_trf"
788
- ],
789
- [
790
- "The meeting between CEO Jane Doe and the board of directors at Microsoft headquarters in Seattle discussed the Q4 financial results and the new AI strategy for 2024.",
791
- ["PER", "ORG", "LOC", "DATE"],
792
- "corporate roles, business events, financial terms",
793
- 0.4,
794
- "entities_flair_ner-ontonotes-large"
795
- ],
796
- [
797
- "Dr. Emily Watson published a research paper on machine learning algorithms at MIT. She collaborates with her colleague Prof. David Chen on natural language processing projects.",
798
- ["PER", "ORG", "Work of Art"],
799
- "academic titles, research topics, collaborations",
800
- 0.3,
801
- "entities_gliner_knowledgator/modern-gliner-bi-large-v1.0"
802
- ]
803
- ],
804
- inputs=[
805
- text_input,
806
- standard_entities,
807
- custom_entities,
808
- confidence_threshold,
809
- model_dropdown
810
- ]
811
- )
812
-
813
- return demo
814
 
815
- if __name__ == "__main__":
816
- demo = create_interface()
817
- demo.launch()
818
-
 
1
+ # Alternative approach using collapsible sections instead of tabs
2
+ # Replace the create_entity_table_gradio_tabs function and the output display section
 
 
 
 
 
 
 
3
 
4
+ def create_entity_results_accordion(entities, entity_colors):
5
+ """Create collapsible accordion-style results instead of tabs"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  if not entities:
7
+ return "<p>No entities found.</p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Share overlapping entities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  shared_entities = find_overlapping_entities(entities)
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Group entities by type
13
  entity_groups = {}
14
  for entity in shared_entities:
 
21
  entity_groups[key] = []
22
  entity_groups[key].append(entity)
23
 
24
+ if not entity_groups:
25
+ return "<p>No entities found.</p>"
26
+
27
+ # Create accordion HTML
28
+ accordion_html = """
29
+ <style>
30
+ .accordion-item {
31
+ border: 1px solid #ddd;
32
+ border-radius: 8px;
33
+ margin-bottom: 10px;
34
+ overflow: hidden;
35
+ }
36
+ .accordion-header {
37
+ background-color: #f8f9fa;
38
+ padding: 15px;
39
+ cursor: pointer;
40
+ display: flex;
41
+ justify-content: space-between;
42
+ align-items: center;
43
+ transition: background-color 0.3s;
44
+ }
45
+ .accordion-header:hover {
46
+ background-color: #e9ecef;
47
+ }
48
+ .accordion-header.active {
49
+ background-color: #4ECDC4;
50
+ color: white;
51
+ }
52
+ .accordion-content {
53
+ padding: 0 15px;
54
+ max-height: 0;
55
+ overflow: hidden;
56
+ transition: max-height 0.3s ease-out, padding 0.3s ease-out;
57
+ }
58
+ .accordion-content.show {
59
+ padding: 15px;
60
+ max-height: 2000px;
61
+ }
62
+ .entity-badge {
63
+ background-color: #007bff;
64
+ color: white;
65
+ padding: 4px 12px;
66
+ border-radius: 15px;
67
+ font-size: 14px;
68
+ font-weight: bold;
69
+ }
70
+ .confidence-high { color: #28a745; }
71
+ .confidence-medium { color: #ffc107; }
72
+ .confidence-low { color: #dc3545; }
73
+ </style>
74
 
75
+ <div class="accordion-container">
 
 
 
 
 
 
 
 
 
 
76
  """
 
 
 
 
 
77
 
78
+ # Add shared entities section if any
79
  if 'SHARED_ENTITIES' in entity_groups:
80
+ shared_entities_list = entity_groups['SHARED_ENTITIES']
81
+ accordion_html += f"""
82
+ <div class="accordion-item">
83
+ <div class="accordion-header" onclick="toggleAccordion(this)">
84
+ <div>
85
+ <span style="font-size: 20px; margin-right: 10px;">🤝</span>
86
+ <strong>Shared Entities</strong>
87
+ <span class="entity-badge" style="margin-left: 10px;">{len(shared_entities_list)} found</span>
88
+ </div>
89
+ <span>▼</span>
90
+ </div>
91
+ <div class="accordion-content">
92
+ {create_entity_table_html(shared_entities_list, 'SHARED_ENTITIES', '#666666', is_shared=True)}
93
+ </div>
94
+ </div>
95
+ """
96
 
97
+ # Add other entity types
98
+ for entity_type, entities_of_type in entity_groups.items():
99
+ if entity_type == 'SHARED_ENTITIES':
100
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ colour = entity_colors.get(entity_type.upper(), '#f0f0f0')
103
+ is_standard = entity_type in STANDARD_ENTITIES
104
+ icon = "🎯" if is_standard else "✨"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ accordion_html += f"""
107
+ <div class="accordion-item">
108
+ <div class="accordion-header" onclick="toggleAccordion(this)">
109
+ <div>
110
+ <span style="font-size: 20px; margin-right: 10px;">{icon}</span>
111
+ <strong>{entity_type}</strong>
112
+ <span class="entity-badge" style="margin-left: 10px; background-color: {colour};">{len(entities_of_type)} found</span>
113
+ </div>
114
+ <span>▼</span>
115
+ </div>
116
+ <div class="accordion-content">
117
+ {create_entity_table_html(entities_of_type, entity_type, colour)}
118
+ </div>
119
+ </div>
120
+ """
121
+
122
+ accordion_html += """
123
+ </div>
124
+
125
+ <script>
126
+ function toggleAccordion(header) {
127
+ const content = header.nextElementSibling;
128
+ const arrow = header.querySelector('span:last-child');
129
 
130
+ // Toggle active class
131
+ header.classList.toggle('active');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ // Toggle content visibility
134
+ content.classList.toggle('show');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ // Toggle arrow
137
+ arrow.textContent = content.classList.contains('show') ? '▲' : '▼';
138
+ }
139
+
140
+ // Auto-expand first accordion
141
+ document.addEventListener('DOMContentLoaded', function() {
142
+ const firstAccordion = document.querySelector('.accordion-header');
143
+ if (firstAccordion) {
144
+ toggleAccordion(firstAccordion);
145
+ }
146
+ });
147
+ </script>
148
+ """
149
+
150
+ return accordion_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ # Then in your process_text function, replace the tab creation with:
153
+ # tab_contents = create_entity_results_accordion(all_entities, entity_colors)
154
+ # And return it as a single HTML output instead of multiple tab outputs