SorrelC committed on
Commit
59c2efc
·
verified ·
1 Parent(s): 82abc14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +881 -125
app.py CHANGED
@@ -1,10 +1,414 @@
1
- # Alternative approach using collapsible sections instead of tabs
2
- # Replace the create_entity_table_gradio_tabs function and the output display section
 
 
 
 
 
 
 
3
 
4
- def create_entity_results_accordion(entities, entity_colors):
5
- """Create collapsible accordion-style results instead of tabs"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  if not entities:
7
- return "<p>No entities found.</p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Share overlapping entities
10
  shared_entities = find_overlapping_entities(entities)
@@ -22,133 +426,485 @@ def create_entity_results_accordion(entities, entity_colors):
22
  entity_groups[key].append(entity)
23
 
24
  if not entity_groups:
25
- return "<p>No entities found.</p>"
26
-
27
- # Create accordion HTML
28
- accordion_html = """
29
- <style>
30
- .accordion-item {
31
- border: 1px solid #ddd;
32
- border-radius: 8px;
33
- margin-bottom: 10px;
34
- overflow: hidden;
35
- }
36
- .accordion-header {
37
- background-color: #f8f9fa;
38
- padding: 15px;
39
- cursor: pointer;
40
- display: flex;
41
- justify-content: space-between;
42
- align-items: center;
43
- transition: background-color 0.3s;
44
- }
45
- .accordion-header:hover {
46
- background-color: #e9ecef;
47
- }
48
- .accordion-header.active {
49
- background-color: #4ECDC4;
50
- color: white;
51
- }
52
- .accordion-content {
53
- padding: 0 15px;
54
- max-height: 0;
55
- overflow: hidden;
56
- transition: max-height 0.3s ease-out, padding 0.3s ease-out;
57
- }
58
- .accordion-content.show {
59
- padding: 15px;
60
- max-height: 2000px;
61
- }
62
- .entity-badge {
63
- background-color: #007bff;
64
- color: white;
65
- padding: 4px 12px;
66
- border-radius: 15px;
67
- font-size: 14px;
68
- font-weight: bold;
69
- }
70
- .confidence-high { color: #28a745; }
71
- .confidence-medium { color: #ffc107; }
72
- .confidence-low { color: #dc3545; }
73
- </style>
74
-
75
- <div class="accordion-container">
76
- """
77
-
78
- # Add shared entities section if any
79
- if 'SHARED_ENTITIES' in entity_groups:
80
- shared_entities_list = entity_groups['SHARED_ENTITIES']
81
- accordion_html += f"""
82
- <div class="accordion-item">
83
- <div class="accordion-header" onclick="toggleAccordion(this)">
84
- <div>
85
- <span style="font-size: 20px; margin-right: 10px;">🀝</span>
86
- <strong>Shared Entities</strong>
87
- <span class="entity-badge" style="margin-left: 10px;">{len(shared_entities_list)} found</span>
88
- </div>
89
- <span>β–Ό</span>
90
- </div>
91
- <div class="accordion-content">
92
- {create_entity_table_html(shared_entities_list, 'SHARED_ENTITIES', '#666666', is_shared=True)}
93
- </div>
94
- </div>
95
- """
96
 
97
- # Add other entity types
98
  for entity_type, entities_of_type in entity_groups.items():
99
  if entity_type == 'SHARED_ENTITIES':
100
- continue
 
101
 
102
- colour = entity_colors.get(entity_type.upper(), '#f0f0f0')
103
- is_standard = entity_type in STANDARD_ENTITIES
104
- icon = "🎯" if is_standard else "✨"
105
-
106
- accordion_html += f"""
107
- <div class="accordion-item">
108
- <div class="accordion-header" onclick="toggleAccordion(this)">
109
- <div>
110
- <span style="font-size: 20px; margin-right: 10px;">{icon}</span>
111
- <strong>{entity_type}</strong>
112
- <span class="entity-badge" style="margin-left: 10px; background-color: {colour};">{len(entities_of_type)} found</span>
113
- </div>
114
- <span>β–Ό</span>
115
- </div>
116
- <div class="accordion-content">
117
- {create_entity_table_html(entities_of_type, entity_type, colour)}
118
- </div>
119
- </div>
120
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- accordion_html += """
123
- </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- <script>
126
- function toggleAccordion(header) {
127
- const content = header.nextElementSibling;
128
- const arrow = header.querySelector('span:last-child');
129
-
130
- // Toggle active class
131
- header.classList.toggle('active');
132
-
133
- // Toggle content visibility
134
- content.classList.toggle('show');
135
-
136
- // Toggle arrow
137
- arrow.textContent = content.classList.contains('show') ? 'β–²' : 'β–Ό';
138
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- // Auto-expand first accordion
141
- document.addEventListener('DOMContentLoaded', function() {
142
- const firstAccordion = document.querySelector('.accordion-header');
143
- if (firstAccordion) {
144
- toggleAccordion(firstAccordion);
145
- }
146
- });
147
- </script>
 
 
 
 
 
148
  """
 
 
149
 
150
- return accordion_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- # Then in your process_text function, replace the tab creation with:
153
- # tab_contents = create_entity_results_accordion(all_entities, entity_colors)
154
- # And return it as a single HTML output instead of multiple tab outputs
 
1
+ import gradio as gr
2
+ import torch
3
+ from gliner import GLiNER
4
+ import pandas as pd
5
+ import warnings
6
+ import random
7
+ import re
8
+ import time
9
+ warnings.filterwarnings('ignore')
10
 
11
+ # Common NER entity types
12
+ STANDARD_ENTITIES = [
13
+ 'DATE', 'EVENT', 'FAC', 'GPE', 'LANG', 'LOC',
14
+ 'MISC', 'NORP', 'ORG', 'PER', 'PRODUCT', 'Work of Art'
15
+ ]
16
+
17
+ # Colour schemes
18
+ STANDARD_COLORS = {
19
+ 'DATE': '#FF6B6B', # Red
20
+ 'EVENT': '#4ECDC4', # Teal
21
+ 'FAC': '#45B7D1', # Blue
22
+ 'GPE': '#F9CA24', # Yellow
23
+ 'LANG': '#6C5CE7', # Purple
24
+ 'LOC': '#A0E7E5', # Light Cyan
25
+ 'MISC': '#FD79A8', # Pink
26
+ 'NORP': '#8E8E93', # Grey
27
+ 'ORG': '#55A3FF', # Light Blue
28
+ 'PER': '#00B894', # Green
29
+ 'PRODUCT': '#E17055', # Orange-Red
30
+ 'WORK OF ART': '#DDA0DD' # Plum
31
+ }
32
+
33
+ # Additional colours for custom entities
34
+ CUSTOM_COLOR_PALETTE = [
35
+ '#FF9F43', '#10AC84', '#EE5A24', '#0FBC89', '#5F27CD',
36
+ '#FF3838', '#2F3640', '#3742FA', '#2ED573', '#FFA502',
37
+ '#FF6348', '#1E90FF', '#FF1493', '#32CD32', '#FFD700',
38
+ '#FF4500', '#DA70D6', '#00CED1', '#FF69B4', '#7B68EE'
39
+ ]
40
+
41
class HybridNERManager:
    """Lazily load spaCy, Flair and GLiNER models and run NER with them.

    Every ``extract_*`` method returns a list of dicts with the keys
    ``text``, ``label``, ``start``, ``end``, ``confidence`` and ``source``.
    Loading and extraction failures are reported with ``print`` and turned
    into ``None`` / ``[]`` rather than raised (best-effort behaviour).
    """

    # spaCy label names -> the short names offered in STANDARD_ENTITIES.
    _SPACY_LABEL_MAP = {
        'PERSON': 'PER',
        'LANGUAGE': 'LANG',
        'WORK_OF_ART': 'Work of Art',
    }

    # Flair label names -> the short names offered in STANDARD_ENTITIES.
    _FLAIR_LABEL_MAP = {
        'PERSON': 'PER',
        'ORGANIZATION': 'ORG',
        'LOCATION': 'LOC',
        'MISCELLANEOUS': 'MISC',
        'LANGUAGE': 'LANG',
        'WORK_OF_ART': 'Work of Art',
    }

    def __init__(self):
        self.gliner_model = None
        self.spacy_model = None
        self.flair_models = {}       # model_name -> loaded Flair tagger
        self.all_entity_colors = {}  # upper-cased entity type -> hex colour
        self.model_names = [
            'entities_flair_ner-large',
            'entities_spacy_en_core_web_trf',
            'entities_flair_ner-ontonotes-large',
            'entities_gliner_knowledgator/modern-gliner-bi-large-v1.0'
        ]

    def load_model(self, model_name):
        """Load the model family named by a substring of *model_name*.

        Returns the loaded model, or ``None`` when loading raised or when
        *model_name* contains none of ``spacy`` / ``flair`` / ``gliner``.
        """
        try:
            if 'spacy' in model_name:
                return self.load_spacy_model()
            elif 'flair' in model_name:
                return self.load_flair_model(model_name)
            elif 'gliner' in model_name:
                return self.load_gliner_model()
        except Exception as e:
            print(f"Error loading {model_name}: {str(e)}")
        return None

    def load_spacy_model(self):
        """Return the cached spaCy pipeline, loading it on first use.

        Tries ``en_core_web_trf`` first and ``en_core_web_sm`` second.
        Returns ``None`` when neither is installed or spaCy cannot be used.
        """
        if self.spacy_model is None:
            try:
                import spacy
                candidates = (
                    ("en_core_web_trf", "✓ spaCy transformer model loaded successfully"),
                    ("en_core_web_sm", "✓ spaCy common model loaded successfully"),
                )
                for pipeline_name, success_message in candidates:
                    try:
                        self.spacy_model = spacy.load(pipeline_name)
                    except OSError:
                        # Pipeline not installed: try the next candidate.
                        continue
                    print(success_message)
                    break
                else:
                    print("spaCy model not found. Using GLiNER for all entity types.")
                    return None
            except Exception as e:
                print(f"Error loading spaCy model: {str(e)}")
                return None
        return self.spacy_model

    def load_flair_model(self, model_name):
        """Return the Flair tagger for *model_name*, loading it on first use.

        A name containing ``ontonotes`` selects
        ``flair/ner-english-ontonotes-large``; anything else selects
        ``flair/ner-english-large``. If loading fails, the GLiNER model (or
        ``None``) is returned instead and nothing is cached in
        ``self.flair_models``.
        """
        if model_name not in self.flair_models:
            try:
                from flair.models import SequenceTagger
                if 'ontonotes' in model_name:
                    model = SequenceTagger.load("flair/ner-english-ontonotes-large")
                    print("✓ Flair OntoNotes model loaded successfully")
                else:
                    model = SequenceTagger.load("flair/ner-english-large")
                    print("✓ Flair large model loaded successfully")
                self.flair_models[model_name] = model
            except Exception as e:
                print(f"Error loading {model_name}: {str(e)}")
                # Fallback to GLiNER
                return self.load_gliner_model()
        return self.flair_models[model_name]

    def load_gliner_model(self):
        """Return the cached GLiNER model, loading it on first use.

        Tries ``knowledgator/gliner-bi-large-v1.0`` and then
        ``urchade/gliner_medium-v2.1``. Returns ``None`` if both fail.
        """
        if self.gliner_model is None:
            try:
                self.gliner_model = GLiNER.from_pretrained("knowledgator/gliner-bi-large-v1.0")
                print("✓ GLiNER knowledgator model loaded successfully")
            except Exception as e:
                print(f"Primary GLiNER model failed: {str(e)}")
                try:
                    # Fallback to stable model
                    self.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
                    print("✓ GLiNER fallback model loaded successfully")
                except Exception as e2:
                    print(f"Error loading GLiNER model: {str(e2)}")
                    return None
        return self.gliner_model

    def assign_colours(self, standard_entities, custom_entities):
        """Rebuild and return the entity-type -> colour mapping.

        Keys are the upper-cased entity names. Standard entities take their
        colour from ``STANDARD_COLORS`` (``'#CCCCCC'`` when absent); custom
        entities take palette entries in order, then random colours once
        ``CUSTOM_COLOR_PALETTE`` is exhausted.
        """
        self.all_entity_colors = {}

        for entity in standard_entities:
            # "Work of Art".upper() is already the 'WORK OF ART' palette key.
            key = entity.upper()
            self.all_entity_colors[key] = STANDARD_COLORS.get(key, '#CCCCCC')

        for i, entity in enumerate(custom_entities):
            if i < len(CUSTOM_COLOR_PALETTE):
                self.all_entity_colors[entity.upper()] = CUSTOM_COLOR_PALETTE[i]
            else:
                # Generate random colour if we run out
                self.all_entity_colors[entity.upper()] = f"#{random.randint(0, 0xFFFFFF):06x}"

        return self.all_entity_colors

    def extract_entities_by_model(self, text, entity_types, model_name, threshold=0.3):
        """Dispatch extraction to the back end named in *model_name*.

        *threshold* is only forwarded to GLiNER. Returns ``[]`` for a
        *model_name* that matches no back end.
        """
        if 'spacy' in model_name:
            return self.extract_spacy_entities(text, entity_types)
        elif 'flair' in model_name:
            return self.extract_flair_entities(text, entity_types, model_name)
        elif 'gliner' in model_name:
            return self.extract_gliner_entities(text, entity_types, threshold, is_custom=False)
        else:
            return []

    def extract_spacy_entities(self, text, entity_types):
        """Extract entities with spaCy, keeping only the requested types.

        A span is kept when its raw spaCy label is in *entity_types* (label
        reported unchanged), or when its short form from
        ``_SPACY_LABEL_MAP`` is (short form reported). Without the mapping,
        requesting ``PER`` could never match spaCy's ``PERSON``.
        """
        model = self.load_spacy_model()
        if model is None:
            return []

        try:
            doc = model(text)
            entities = []
            for ent in doc.ents:
                raw_label = ent.label_
                if raw_label in entity_types:
                    label = raw_label
                else:
                    label = self._SPACY_LABEL_MAP.get(raw_label)
                    if label not in entity_types:
                        continue
                entities.append({
                    'text': ent.text,
                    'label': label,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'confidence': 1.0,  # spaCy doesn't provide confidence scores
                    'source': 'spaCy'
                })
            return entities
        except Exception as e:
            print(f"Error with spaCy extraction: {str(e)}")
            return []

    def extract_flair_entities(self, text, entity_types, model_name):
        """Extract entities with the Flair tagger selected by *model_name*.

        Labels are translated through ``_FLAIR_LABEL_MAP`` first; a span is
        kept when the translated label is in *entity_types*, or otherwise
        when the raw label is. Returns ``[]`` if no Flair tagger could be
        loaded.
        """
        model = self.load_flair_model(model_name)
        if model is None:
            return []
        if model_name not in self.flair_models:
            # load_flair_model handed back its GLiNER fallback, which is not
            # a Flair SequenceTagger and cannot tag a Flair Sentence.
            print(f"Error with Flair extraction: {model_name} is not available")
            return []

        try:
            from flair.data import Sentence
            sentence = Sentence(text)
            model.predict(sentence)
            source = f'Flair-{model_name.split("-")[-1]}'
            entities = []
            for span in sentence.get_spans('ner'):
                raw_label = span.tag
                label = self._FLAIR_LABEL_MAP.get(raw_label, raw_label)
                if label not in entity_types:
                    if raw_label not in entity_types:
                        continue
                    label = raw_label
                entities.append({
                    'text': span.text,
                    'label': label,
                    'start': span.start_position,
                    'end': span.end_position,
                    'confidence': span.score,
                    'source': source
                })
            return entities
        except Exception as e:
            print(f"Error with Flair extraction: {str(e)}")
            return []

    def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True):
        """Extract entities with GLiNER for the given free-form types.

        Labels are upper-cased in the result. ``source`` is
        ``'GLiNER-Custom'`` when *is_custom* is true, else
        ``'GLiNER-Common'``. Returns ``[]`` when the model is unavailable or
        prediction raises.
        """
        model = self.load_gliner_model()
        if model is None:
            return []

        source = 'GLiNER-Custom' if is_custom else 'GLiNER-Common'
        try:
            predictions = model.predict_entities(text, entity_types, threshold=threshold)
            return [
                {
                    'text': item['text'],
                    'label': item['label'].upper(),
                    'start': item['start'],
                    'end': item['end'],
                    'confidence': item.get('score', 0.0),
                    'source': source
                }
                for item in predictions
            ]
        except Exception as e:
            print(f"Error with GLiNER extraction: {str(e)}")
            return []
238
+
239
def find_overlapping_entities(entities):
    """Group overlapping entities and merge common+custom groups into one.

    Entities are sorted by ``start``. Each not-yet-grouped entity collects
    every later entity that starts inside its span or has the same text
    (case-insensitive). A group is merged with ``share_entities`` only when
    it contains at least one common-model result (``spaCy``,
    ``GLiNER-Common`` or ``Flair-*``) AND at least one ``GLiNER-Custom``
    result; otherwise its members are kept as separate entries.

    Args:
        entities: list of entity dicts with ``text``, ``start``, ``end``
            and (optionally) ``source``. The list is not modified.

    Returns:
        A new list of entity dicts; merged groups are the dicts produced by
        ``share_entities``.
    """
    if not entities:
        return []

    ordered = sorted(entities, key=lambda item: item['start'])
    consumed = [False] * len(ordered)
    results = []

    for i, current in enumerate(ordered):
        if consumed[i]:
            continue

        group = [current]
        current_text = current['text'].lower()
        for j in range(i + 1, len(ordered)):
            if consumed[j]:
                continue
            candidate = ordered[j]
            # Because of the sort, candidate['start'] >= current['start'], so
            # "candidate starts inside current" is the complete positional
            # overlap test; no mirrored clause is needed.
            # NOTE(review): the text comparison also groups identical strings
            # found at different offsets - confirm that is intended.
            if (current['start'] <= candidate['start'] < current['end']
                    or current_text == candidate['text'].lower()):
                group.append(candidate)
                consumed[j] = True

        if len(group) == 1:
            results.append(current)
            continue

        sources = [member.get('source', '') for member in group]
        has_common = any(
            src in ('spaCy', 'GLiNER-Common') or src.startswith('Flair-')
            for src in sources
        )
        has_custom = 'GLiNER-Custom' in sources

        if has_common and has_custom:
            # This is a true shared entity (common + custom)
            results.append(share_entities(group))
        else:
            # Same kind of source on every member: keep them separate.
            results.extend(group)

    return results
293
+
294
def share_entities(entity_list):
    """Collapse a group of overlapping entities into one shared record.

    A single-element list is returned as that element, untouched. Otherwise
    the entity with the longest ``text`` supplies ``text``/``start``/``end``
    and the labels, sources and confidences of every member are collected
    in input order.
    """
    if len(entity_list) == 1:
        return entity_list[0]

    # The longest surface form anchors the merged span.
    anchor = max(entity_list, key=lambda item: len(item['text']))

    merged = {field: anchor[field] for field in ('text', 'start', 'end')}
    merged['labels'] = [member['label'] for member in entity_list]
    merged['sources'] = [member['source'] for member in entity_list]
    merged['confidences'] = [member['confidence'] for member in entity_list]
    merged['is_shared'] = True
    merged['entity_count'] = len(entity_list)
    return merged
317
+
318
def create_highlighted_html(text, entities, entity_colors):
    """Render *text* as HTML with each entity wrapped in a coloured span.

    Args:
        text: the analysed text. It is HTML-escaped before being embedded,
            since it is arbitrary user input.
        entities: entity dicts as produced by the extractors.
        entity_colors: mapping of upper-cased label -> CSS colour.

    Returns:
        An HTML string. With no entities, the escaped text in a plain box;
        otherwise a titled box with the highlighted text.
    """
    from html import escape  # local import: only the HTML builders need it

    if not entities:
        return f"<div style='padding: 15px; border: 1px solid #ddd; border-radius: 5px; background-color: #fafafa;'><p>{escape(text)}</p></div>"

    # Find and share overlapping entities, then walk them left to right.
    merged = find_overlapping_entities(entities)

    html_parts = []
    cursor = 0
    for entity in sorted(merged, key=lambda item: item['start']):
        if entity['start'] < cursor:
            # This span overlaps one that is already rendered. Emitting it
            # too would duplicate part of the text, so it is skipped.
            continue

        # Plain text between the previous entity and this one.
        html_parts.append(escape(text[cursor:entity['start']]))

        if entity.get('is_shared', False):
            html_parts.append(create_shared_entity_html(entity, entity_colors))
        else:
            html_parts.append(create_single_entity_html(entity, entity_colors))

        cursor = entity['end']

    # Text after the last entity.
    html_parts.append(escape(text[cursor:]))

    highlighted_text = ''.join(html_parts)

    return f"""
    <div style='padding: 15px; border: 2px solid #ddd; border-radius: 8px; background-color: #fafafa; margin: 10px 0;'>
        <h4 style='margin: 0 0 15px 0; color: #333;'>📝 Text with Highlighted Entities</h4>
        <div style='line-height: 1.8; font-size: 16px; background-color: white; padding: 15px; border-radius: 5px;'>{highlighted_text}</div>
    </div>
    """
357
+
358
def create_single_entity_html(entity, entity_colors):
    """Return a coloured ``<span>`` for one non-shared entity.

    The background colour is looked up by the upper-cased label
    (``'#CCCCCC'`` when unknown). The tooltip shows label, source and
    confidence. Entity text, label and source are HTML-escaped because they
    come from user text and user-typed entity names.
    """
    from html import escape  # local import: only the HTML builders need it

    label = entity['label']
    colour = entity_colors.get(label.upper(), '#CCCCCC')
    confidence = entity.get('confidence', 0.0)
    source = entity.get('source', 'Unknown')

    tooltip = escape(f"{label} ({source}) - confidence: {confidence:.2f}")

    return (f'<span style="background-color: {colour}; padding: 2px 4px; '
            f'border-radius: 3px; margin: 0 1px; '
            f'border: 1px solid {colour}; color: white; font-weight: bold;" '
            f'title="{tooltip}">'
            f'{escape(entity["text"])}</span>')
370
+
371
def create_shared_entity_html(entity, entity_colors):
    """Return a ``<span>`` for a shared entity, striped with one colour per label.

    *entity* is a merged record carrying parallel ``labels``, ``sources``
    and ``confidences`` lists. Each label contributes an equal-width,
    hard-edged segment of a left-to-right gradient. Text and tooltip are
    HTML-escaped because they originate from user input.
    """
    from html import escape  # local import: only the HTML builders need it

    labels = entity['labels']
    sources = entity['sources']
    confidences = entity['confidences']

    # One colour per label, grey for labels without an assigned colour.
    colours = [entity_colors.get(label.upper(), '#CCCCCC') for label in labels]

    if len(colours) == 2:
        gradient = f"linear-gradient(to right, {colours[0]} 50%, {colours[1]} 50%)"
    else:
        # Repeating each colour at both ends of its segment gives hard stops.
        segment_size = 100 / len(colours)
        gradient_parts = [
            f"{colour} {i * segment_size}%, {colour} {(i + 1) * segment_size}%"
            for i, colour in enumerate(colours)
        ]
        gradient = f"linear-gradient(to right, {', '.join(gradient_parts)})"

    tooltip = " | ".join(
        f"{label} ({sources[i]}) - {confidences[i]:.2f}"
        for i, label in enumerate(labels)
    )

    return (f'<span style="background: {gradient}; padding: 2px 4px; '
            f'border-radius: 3px; margin: 0 1px; '
            f'border: 2px solid #333; color: white; font-weight: bold;" '
            f'title="SHARED: {escape(tooltip)}">'
            f'{escape(entity["text"])} 🤝</span>')
407
+
408
+ def create_entity_table_gradio_tabs(entities, entity_colors):
409
+ """Create Gradio tabs for entity results"""
410
+ if not entities:
411
+ return "No entities found."
412
 
413
  # Share overlapping entities
414
  shared_entities = find_overlapping_entities(entities)
 
426
  entity_groups[key].append(entity)
427
 
428
  if not entity_groups:
429
+ return "No entities found."
430
+
431
+ # Create content for each tab
432
+ tab_contents = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
 
 
434
  for entity_type, entities_of_type in entity_groups.items():
435
  if entity_type == 'SHARED_ENTITIES':
436
+ colour = '#666666'
437
+ header = f"🀝 Shared Entities ({len(entities_of_type)} found)"
438
 
439
+ # Create table for shared entities
440
+ table_html = f"""
441
+ <div style="margin: 15px 0;">
442
+ <h4 style="color: {colour}; margin-bottom: 15px;">{header}</h4>
443
+ <table style="width: 100%; border-collapse: collapse; border: 1px solid #ddd;">
444
+ <thead>
445
+ <tr style="background-color: {colour}; color: white;">
446
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Entity Text</th>
447
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">All Labels</th>
448
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Sources</th>
449
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Count</th>
450
+ </tr>
451
+ </thead>
452
+ <tbody>
453
+ """
454
+
455
+ for entity in entities_of_type:
456
+ labels_text = " | ".join(entity['labels'])
457
+ sources_text = " | ".join(entity['sources'])
458
+
459
+ table_html += f"""
460
+ <tr style="background-color: #fff;">
461
+ <td style="padding: 10px; border: 1px solid #ddd; font-weight: bold;">{entity['text']}</td>
462
+ <td style="padding: 10px; border: 1px solid #ddd;">{labels_text}</td>
463
+ <td style="padding: 10px; border: 1px solid #ddd;">{sources_text}</td>
464
+ <td style="padding: 10px; border: 1px solid #ddd; text-align: center;">
465
+ <span style='background-color: #28a745; color: white; padding: 2px 6px; border-radius: 10px; font-size: 11px;'>
466
+ {entity['entity_count']}
467
+ </span>
468
+ </td>
469
+ </tr>
470
+ """
471
+
472
+ table_html += "</tbody></table></div>"
473
+ tab_contents[f"🀝 SHARED ({len(entities_of_type)})"] = table_html
474
+
475
+ else:
476
+ colour = entity_colors.get(entity_type.upper(), '#f0f0f0')
477
+ # Determine if it's common or custom
478
+ is_standard = entity_type in STANDARD_ENTITIES
479
+ icon = "🎯" if is_standard else "✨"
480
+ source_text = "Common NER" if is_standard else "Custom GLiNER"
481
+ header = f"{icon} {source_text} - {entity_type} ({len(entities_of_type)} found)"
482
+
483
+ # Create table for this entity type
484
+ table_html = f"""
485
+ <div style="margin: 15px 0;">
486
+ <h4 style="color: {colour}; margin-bottom: 15px;">{header}</h4>
487
+ <table style="width: 100%; border-collapse: collapse; border: 1px solid #ddd;">
488
+ <thead>
489
+ <tr style="background-color: {colour}; color: white;">
490
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Entity Text</th>
491
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Confidence</th>
492
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Type</th>
493
+ <th style="padding: 12px; text-align: left; border: 1px solid #ddd;">Source</th>
494
+ </tr>
495
+ </thead>
496
+ <tbody>
497
+ """
498
+
499
+ # Sort by confidence score
500
+ entities_of_type.sort(key=lambda x: x.get('confidence', 0), reverse=True)
501
+
502
+ for entity in entities_of_type:
503
+ confidence = entity.get('confidence', 0.0)
504
+ confidence_colour = "#28a745" if confidence > 0.7 else "#ffc107" if confidence > 0.4 else "#dc3545"
505
+ source = entity.get('source', 'Unknown')
506
+ source_badge = f"<span style='background-color: #007bff; color: white; padding: 2px 6px; border-radius: 10px; font-size: 11px;'>{source}</span>"
507
+
508
+ table_html += f"""
509
+ <tr style="background-color: #fff;">
510
+ <td style="padding: 10px; border: 1px solid #ddd; font-weight: bold;">{entity['text']}</td>
511
+ <td style="padding: 10px; border: 1px solid #ddd;">
512
+ <span style="color: {confidence_colour}; font-weight: bold;">
513
+ {confidence:.3f}
514
+ </span>
515
+ </td>
516
+ <td style="padding: 10px; border: 1px solid #ddd;">{entity['label']}</td>
517
+ <td style="padding: 10px; border: 1px solid #ddd;">{source_badge}</td>
518
+ </tr>
519
+ """
520
+
521
+ table_html += "</tbody></table></div>"
522
+ tab_label = f"{icon} {entity_type} ({len(entities_of_type)})"
523
+ tab_contents[tab_label] = table_html
524
 
525
+ return tab_contents
526
+
527
def create_legend_html(entity_colors, standard_entities, custom_entities):
    """Build the colour legend shown above the highlighted text.

    Args:
        entity_colors: mapping of upper-cased entity type -> CSS colour.
        standard_entities: selected common entity names.
        custom_entities: user-typed entity names. They are HTML-escaped
            before being embedded.

    Returns:
        An HTML string, or ``""`` when *entity_colors* is empty. A section
        is emitted only for a non-empty entity list.
    """
    from html import escape  # local import: only the HTML builders need it

    if not entity_colors:
        return ""

    def badges(entity_types):
        # One pill per entity type, coloured via its upper-cased name.
        return "".join(
            f"<span style='background-color: {entity_colors.get(name.upper(), '#ccc')}; "
            f"padding: 4px 8px; border-radius: 15px; color: white; "
            f"font-weight: bold; font-size: 12px;'>{escape(name)}</span>"
            for name in entity_types
        )

    parts = [
        "<div style='margin: 15px 0; padding: 15px; background-color: #f8f9fa; border-radius: 8px;'>",
        "<h4 style='margin: 0 0 15px 0;'>🎨 Entity Type Legend</h4>",
    ]

    if standard_entities:
        parts.append("<div style='margin-bottom: 15px;'>")
        parts.append("<h5 style='margin: 0 0 8px 0;'>🎯 Common Entities:</h5>")
        parts.append("<div style='display: flex; flex-wrap: wrap; gap: 8px;'>")
        parts.append(badges(standard_entities))
        parts.append("</div></div>")

    if custom_entities:
        parts.append("<div>")
        parts.append("<h5 style='margin: 0 0 8px 0;'>✨ Custom Entities:</h5>")
        parts.append("<div style='display: flex; flex-wrap: wrap; gap: 8px;'>")
        parts.append(badges(custom_entities))
        parts.append("</div></div>")

    parts.append("</div>")
    return "".join(parts)
555
+
556
+ # Initialize the NER manager
557
+ ner_manager = HybridNERManager()
558
+
559
def process_text(text, standard_entities, custom_entities_str, confidence_threshold, selected_model, progress=gr.Progress()):
    """Run the full analysis for the Gradio interface, reporting progress.

    Args:
        text: text to analyse. ``None`` or blank yields an error message.
        standard_entities: selected common entity names (``None`` is
            treated as no selection).
        custom_entities_str: comma-separated custom entity names (``None``
            is treated as empty).
        confidence_threshold: threshold forwarded to the extractors.
        selected_model: model name used for the common entities.
        progress: Gradio progress callback.

    Returns:
        A ``(summary_markdown, legend_and_highlight_html, tab_contents)``
        tuple. On any early exit the tuple is ``(message, "", {})``.
    """
    if not text or not text.strip():
        return "❌ Please enter some text to analyse", "", {}

    progress(0.1, desc="Initialising...")

    # Parse custom entities; blank fragments between commas are dropped.
    custom_entities = [
        name.strip()
        for name in (custom_entities_str or "").split(',')
        if name.strip()
    ]

    # Parse common entities
    selected_standard = [entity for entity in (standard_entities or []) if entity]

    if not selected_standard and not custom_entities:
        return "❌ Please select at least one common entity type OR enter custom entity types", "", {}

    progress(0.2, desc="Loading models...")

    all_entities = []

    # Extract common entities using selected model
    if selected_standard and selected_model:
        progress(0.4, desc="Extracting common entities...")
        all_entities.extend(
            ner_manager.extract_entities_by_model(text, selected_standard, selected_model, confidence_threshold)
        )

    # Extract custom entities using GLiNER
    if custom_entities:
        progress(0.6, desc="Extracting custom entities...")
        all_entities.extend(
            ner_manager.extract_gliner_entities(text, custom_entities, confidence_threshold, is_custom=True)
        )

    if not all_entities:
        return "❌ No entities found. Try lowering the confidence threshold or using different entity types.", "", {}

    progress(0.8, desc="Processing results...")

    # Assign colours
    entity_colors = ner_manager.assign_colours(selected_standard, custom_entities)

    # Create outputs
    legend_html = create_legend_html(entity_colors, selected_standard, custom_entities)
    highlighted_html = create_highlighted_html(text, all_entities, entity_colors)
    tab_contents = create_entity_table_gradio_tabs(all_entities, entity_colors)

    progress(0.9, desc="Creating summary...")

    # Shared entities are those found by BOTH a common NER model AND custom GLiNER.
    total_entities = len(all_entities)  # non-zero: the empty case returned above
    shared_entities = find_overlapping_entities(all_entities)
    final_count = len(shared_entities)
    shared_count = sum(1 for e in shared_entities if e.get('is_shared', False))
    average_confidence = sum(e.get('confidence', 0) for e in all_entities) / total_entities

    summary = (
        "\n"
        "## 📊 Analysis Summary\n"
        f"- **Total entities found:** {total_entities}\n"
        f"- **Final entities displayed:** {final_count}\n"
        f"- **Shared entities:** {shared_count}\n"
        f"- **Average confidence:** {average_confidence:.3f}\n"
    )

    progress(1.0, desc="Complete!")

    return summary, legend_html + highlighted_html, tab_contents
626
+
627
+ # Create Gradio interface
628
# Inline click handler shared by every tab button. It is emitted as an
# ``onclick`` attribute rather than a <script> element because <script>
# elements inserted as HTML markup are not executed by the browser, so a
# script-defined switcher function would never exist. The handler is scoped to
# the nearest ``data-ner-tabs`` container, so several renders can coexist
# without needing unique element IDs.
_TAB_CLICK_JS = (
    "var root=this.closest('[data-ner-tabs]');"
    "var idx=this.getAttribute('data-tab-index');"
    "root.querySelectorAll('[data-tab-panel]').forEach(function(p){"
    "p.style.display=p.getAttribute('data-tab-panel')===idx?'block':'none';});"
    "root.querySelectorAll('[data-tab-index]').forEach(function(b){"
    "var on=b.getAttribute('data-tab-index')===idx;"
    "b.setAttribute('data-active',on?'1':'0');"
    "b.style.backgroundColor=on?'#f8f9fa':'#fff';"
    "b.style.borderBottom=on?'3px solid #4ECDC4':'none';});"
    "return false;"
)

# Hover handlers. Mouse-out restores the colour from the button's *current*
# active state instead of the state it had when the HTML was generated.
_TAB_HOVER_IN_JS = "this.style.backgroundColor='#e9ecef'"
_TAB_HOVER_OUT_JS = (
    "this.style.backgroundColor="
    "this.getAttribute('data-active')==='1'?'#f8f9fa':'#fff'"
)


def _render_result_tabs(tab_contents):
    """Render a mapping of tab name -> HTML fragment as a tabbed HTML widget.

    Args:
        tab_contents: Non-empty dict whose keys are tab labels and whose values
            are HTML fragments. The first entry is shown initially.

    Returns:
        A single HTML string containing the tab buttons and one panel per tab.
    """
    # Local import: the file's top-level import block is outside this block.
    import html

    buttons = []
    panels = []
    for index, (tab_name, content) in enumerate(tab_contents.items()):
        is_active = index == 0
        background = "#f8f9fa" if is_active else "#fff"
        border_bottom = "3px solid #4ECDC4" if is_active else "none"
        buttons.append(
            f'<button type="button" data-tab-index="{index}" '
            f'data-active="{int(is_active)}" '
            f'onclick="{_TAB_CLICK_JS}" '
            f'onmouseover="{_TAB_HOVER_IN_JS}" '
            f'onmouseout="{_TAB_HOVER_OUT_JS}" '
            f'style="padding: 12px 24px; margin-right: 5px; '
            f'border: 1px solid #ddd; border-bottom: {border_bottom}; '
            f'cursor: pointer; font-weight: bold; '
            f'background-color: {background}; transition: all 0.3s ease;">'
            # Tab labels can derive from user-typed entity names, so escape them.
            f'{html.escape(str(tab_name))}'
            f'</button>'
        )
        display = "block" if is_active else "none"
        # NOTE(review): content is inserted verbatim on the assumption that it
        # is HTML already produced by this app's table builders - confirm.
        panels.append(
            f'<div data-tab-panel="{index}" style="display: {display};">'
            f'{content}'
            f'</div>'
        )

    return (
        '<div data-ner-tabs="1" style="margin: 20px 0;">'
        '<div style="border-bottom: 2px solid #ddd; margin-bottom: 20px;">'
        + "".join(buttons)
        + '</div>'
        + "".join(panels)
        + '</div>'
    )


def create_interface():
    """Build the Gradio Blocks UI for the hybrid NER + GLiNER tool.

    Lays out the text input, confidence slider, model dropdown, common and
    custom entity selectors, the analyse button, the three output areas and a
    set of examples, and wires the button to ``process_text``.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(title="Hybrid NER + GLiNER Tool", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎯 Hybrid NER + Custom GLiNER Entity Recognition Tool

        Combine common NER categories with your own custom entity types! This tool uses both traditional NER models and GLiNER for comprehensive entity extraction.

        ## 🤝 NEW: Overlapping entities are automatically shared with split-colour highlighting!

        ### How to use:
        1. **📝 Enter your text** in the text area below
        2. **🎯 Select a model** from the dropdown for common entities
        3. **☑️ Select common entities** you want to find (PER, ORG, LOC, etc.)
        4. **✨ Add custom entities** (comma-separated) like "relationships, occupations, skills"
        5. **⚙️ Adjust confidence threshold**
        6. **🔍 Click "Analyse Text"** to see results with tabbed output
        """)

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="📝 Text to Analyse",
                    placeholder="Enter your text here...",
                    lines=6,
                    max_lines=10,
                )

            with gr.Column(scale=1):
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.1,
                    label="🎚️ Confidence Threshold",
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎯 Common Entity Types")

                # Model selector
                model_dropdown = gr.Dropdown(
                    choices=ner_manager.model_names,
                    value=ner_manager.model_names[0],
                    label="Select Model for Common Entities",
                    info="Choose which model to use for common NER",
                )

                # Common entities with select-all functionality
                standard_entities = gr.CheckboxGroup(
                    choices=STANDARD_ENTITIES,
                    value=['PER', 'ORG', 'LOC', 'MISC'],  # Default selection
                    label="Select Common Entities",
                )

                # Select/Deselect All button; it starts as "Deselect All"
                # because a default selection is present.
                with gr.Row():
                    select_all_btn = gr.Button("🔘 Deselect All", size="sm")

                def toggle_all_entities(current_selection):
                    """Clear the selection if anything is ticked, else tick everything.

                    Returns the new checkbox selection and the new button label.
                    """
                    if current_selection:
                        return [], "☑️ Select All"
                    return STANDARD_ENTITIES, "🔘 Deselect All"

                select_all_btn.click(
                    fn=toggle_all_entities,
                    inputs=[standard_entities],
                    outputs=[standard_entities, select_all_btn],
                )

            with gr.Column():
                gr.Markdown("### ✨ Custom Entity Types")
                custom_entities = gr.Textbox(
                    label="Custom Entities (comma-separated)",
                    placeholder="e.g. relationships, occupations, skills, emotions",
                    lines=3,
                )
                gr.Markdown("""
                **Examples:**
                - relationships, occupations, skills
                - emotions, actions, objects
                - medical conditions, treatments
                - financial terms, business roles
                """)

        analyse_btn = gr.Button("🔍 Analyse Text", variant="primary", size="lg")

        # Output sections
        with gr.Row():
            summary_output = gr.Markdown(label="Summary")

        with gr.Row():
            highlighted_output = gr.HTML(label="Highlighted Text")

        # Detailed results: one HTML component holding a hand-built tab widget.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 📋 Detailed Results")
                results_container = gr.HTML(label="Results")

        def process_and_display(text, selected_entities, custom_entity_text,
                                threshold, selected_model):
            """Run the analysis and turn its per-tab output into one HTML string.

            Returns the summary, the highlighted-text HTML and the results HTML,
            in the order of the ``outputs`` wired to the analyse button.
            """
            summary, highlighted, tab_contents = process_text(
                text, selected_entities, custom_entity_text, threshold, selected_model
            )

            if isinstance(tab_contents, dict) and tab_contents:
                results_display = _render_result_tabs(tab_contents)
            else:
                # Anything that is not a non-empty dict is shown as plain text.
                results_display = str(tab_contents) if tab_contents else "No results to display"

            return summary, highlighted, results_display

        # Connect the button to the processing function
        analyse_btn.click(
            fn=process_and_display,
            inputs=[
                text_input,
                standard_entities,
                custom_entities,
                confidence_threshold,
                model_dropdown,
            ],
            outputs=[summary_output, highlighted_output, results_container],
        )

        # NOTE(review): the entity labels and model names below are assumed to
        # exist in STANDARD_ENTITIES and ner_manager.model_names - confirm.
        gr.Examples(
            examples=[
                [
                    "John Smith works at Google in New York. He graduated from Stanford University in 2015 and specialises in artificial intelligence research. His wife Sarah is a doctor at Mount Sinai Hospital.",
                    ["PER", "ORG", "LOC", "DATE"],
                    "relationships, occupations, educational background",
                    0.3,
                    "entities_spacy_en_core_web_trf",
                ],
                [
                    "The meeting between CEO Jane Doe and the board of directors at Microsoft headquarters in Seattle discussed the Q4 financial results and the new AI strategy for 2024.",
                    ["PER", "ORG", "LOC", "DATE"],
                    "corporate roles, business events, financial terms",
                    0.4,
                    "entities_flair_ner-ontonotes-large",
                ],
                [
                    "Dr. Emily Watson published a research paper on machine learning algorithms at MIT. She collaborates with her colleague Prof. David Chen on natural language processing projects.",
                    ["PER", "ORG", "Work of Art"],
                    "academic titles, research topics, collaborations",
                    0.3,
                    "entities_gliner_knowledgator/modern-gliner-bi-large-v1.0",
                ],
            ],
            inputs=[
                text_input,
                standard_entities,
                custom_entities,
                confidence_threshold,
                model_dropdown,
            ],
        )

    return demo
907
 
908
if __name__ == "__main__":
    # Build the interface and start the Gradio server only when this file is
    # executed as a script, not when it is imported.
    demo = create_interface()
    demo.launch()