import gradio as gr
import torch
from gliner import GLiNER
import pandas as pd
import warnings
import random
import re
import time
warnings.filterwarnings('ignore')
# Entity types offered in the "common entities" selector.
STANDARD_ENTITIES = [
    "DATE",
    "EVENT",
    "FAC",
    "GPE",
    "LANG",
    "LOC",
    "MISC",
    "NORP",
    "ORG",
    "PER",
    "PRODUCT",
    "Work of Art",
]

# Highlight colour per common entity type. Keys are upper-cased labels, so
# "Work of Art" is stored as "WORK OF ART".
STANDARD_COLORS = {
    "DATE": "#FF6B6B",         # red
    "EVENT": "#4ECDC4",        # teal
    "FAC": "#45B7D1",          # blue
    "GPE": "#F9CA24",          # yellow
    "LANG": "#6C5CE7",         # purple
    "LOC": "#A0E7E5",          # light cyan
    "MISC": "#FD79A8",         # pink
    "NORP": "#8E8E93",         # grey
    "ORG": "#55A3FF",          # light blue
    "PER": "#00B894",          # green
    "PRODUCT": "#E17055",      # orange-red
    "WORK OF ART": "#DDA0DD",  # plum
}

# Human-readable description of each common entity type (glossary text).
ENTITY_DEFINITIONS = {
    "DATE": "Absolute or relative dates or periods",
    "EVENT": "Named hurricanes, battles, wars, sports events, etc.",
    "FAC": "Facilities - Buildings, airports, highways, bridges, etc.",
    "GPE": "Geopolitical entities - Countries, cities, states",
    "LANG": "Any named language",
    "LOC": "Non-GPE locations - Mountain ranges, bodies of water",
    "MISC": "Miscellaneous entities - Things that don't fit elsewhere",
    "NORP": "Nationalities or religious or political groups",
    "ORG": "Organizations - Companies, agencies, institutions, etc.",
    "PER": "People, including fictional characters",
    "PRODUCT": "Objects, vehicles, foods, etc. (Not services)",
    "Work of Art": "Titles of books, songs, movies, paintings, etc.",
}

# Colours handed out, in order, to user-defined (custom) entity types.
CUSTOM_COLOR_PALETTE = [
    "#FF9F43", "#10AC84", "#EE5A24", "#0FBC89",
    "#5F27CD", "#FF3838", "#2F3640", "#3742FA",
    "#2ED573", "#FFA502", "#FF6348", "#1E90FF",
    "#FF1493", "#32CD32", "#FFD700", "#FF4500",
    "#DA70D6", "#00CED1", "#FF69B4", "#7B68EE",
]
class HybridNERManager:
    """Load NER back-ends on demand and run entity extraction with them.

    Three back-ends are supported: spaCy, Flair and GLiNER. Models are loaded
    the first time they are needed and cached on the instance. Every
    ``extract_*`` method returns a list of dicts with the keys ``text``,
    ``label``, ``start``, ``end``, ``confidence`` and ``source``; failures are
    printed and reported as an empty list rather than raised.
    """

    # spaCy pipelines and the Flair OntoNotes tagger emit OntoNotes-style tags
    # (PERSON, LANGUAGE, WORK_OF_ART, ...), while callers select entity types
    # by the short names used in this file (PER, LANG, 'Work of Art', ...).
    # Tags are translated before filtering so that a selected 'PER' actually
    # matches a predicted 'PERSON'.
    _LABEL_ALIASES = {
        'PERSON': 'PER',
        'ORGANIZATION': 'ORG',
        'LOCATION': 'LOC',
        'MISCELLANEOUS': 'MISC',
        'LANGUAGE': 'LANG',
        'WORK_OF_ART': 'Work of Art',
    }

    def __init__(self):
        self.gliner_model = None      # cached GLiNER model, loaded lazily
        self.spacy_model = None       # cached spaCy pipeline, loaded lazily
        self.flair_models = {}        # model_name -> loaded Flair tagger
        self.all_entity_colors = {}   # upper-cased label -> hex colour
        # Identifiers offered in the model dropdown; the back-end is chosen by
        # substring ('spacy' / 'flair' / 'gliner').
        self.model_names = [
            'entities_flair_ner-large',
            'entities_spacy_en_core_web_trf',
            'entities_flair_ner-ontonotes-large',
            'entities_gliner_knowledgator/modern-gliner-bi-large-v1.0'
        ]

    @classmethod
    def _resolve_label(cls, raw_label, entity_types):
        """Return the requested entity type matching ``raw_label``, or None.

        The aliased short name is preferred; the raw tag is still accepted so
        callers that ask for a native tag (e.g. 'PERSON') keep working.
        """
        alias = cls._LABEL_ALIASES.get(raw_label, raw_label)
        if alias in entity_types:
            return alias
        if raw_label in entity_types:
            return raw_label
        return None

    def load_model(self, model_name):
        """Load the back-end selected by ``model_name``.

        Returns the loaded model, or None when the name matches no back-end or
        loading raised.
        """
        try:
            if 'spacy' in model_name:
                return self.load_spacy_model()
            if 'flair' in model_name:
                return self.load_flair_model(model_name)
            if 'gliner' in model_name:
                return self.load_gliner_model()
        except Exception as e:
            print(f"Error loading {model_name}: {str(e)}")
        return None

    def load_spacy_model(self):
        """Return the cached spaCy pipeline, loading it on first use.

        Tries ``en_core_web_trf`` first and ``en_core_web_sm`` second; returns
        None when neither is installed or spaCy itself cannot be imported.
        """
        if self.spacy_model is None:
            try:
                import spacy
                try:
                    # Try transformer model first, fallback to small model
                    self.spacy_model = spacy.load("en_core_web_trf")
                    print("✓ spaCy transformer model loaded successfully")
                except OSError:
                    try:
                        self.spacy_model = spacy.load("en_core_web_sm")
                        print("✓ spaCy common model loaded successfully")
                    except OSError:
                        print("spaCy model not found. Using GLiNER for all entity types.")
                        return None
            except Exception as e:
                print(f"Error loading spaCy model: {str(e)}")
                return None
        return self.spacy_model

    def load_flair_model(self, model_name):
        """Return the Flair tagger for ``model_name``, loading it on first use.

        Names containing 'ontonotes' load the OntoNotes tagger, all others the
        4-class English tagger. If loading fails, the GLiNER model (or None) is
        returned instead and nothing is cached in ``self.flair_models`` -
        callers can detect the fallback by checking that dict.
        """
        if model_name not in self.flair_models:
            try:
                from flair.models import SequenceTagger
                if 'ontonotes' in model_name:
                    model = SequenceTagger.load("flair/ner-english-ontonotes-large")
                    print("✓ Flair OntoNotes model loaded successfully")
                else:
                    model = SequenceTagger.load("flair/ner-english-large")
                    print("✓ Flair large model loaded successfully")
                self.flair_models[model_name] = model
            except Exception as e:
                print(f"Error loading {model_name}: {str(e)}")
                # Fallback to GLiNER
                return self.load_gliner_model()
        return self.flair_models[model_name]

    def load_gliner_model(self):
        """Return the cached GLiNER model, loading it on first use.

        Falls back to ``urchade/gliner_medium-v2.1`` when the primary
        checkpoint cannot be loaded; returns None if both fail.
        """
        if self.gliner_model is None:
            try:
                # NOTE(review): the dropdown entry is named
                # 'modern-gliner-bi-large-v1.0' but this loads
                # 'gliner-bi-large-v1.0' - confirm which checkpoint is intended.
                self.gliner_model = GLiNER.from_pretrained("knowledgator/gliner-bi-large-v1.0")
                print("✓ GLiNER knowledgator model loaded successfully")
            except Exception as e:
                print(f"Primary GLiNER model failed: {str(e)}")
                try:
                    # Fallback to stable model
                    self.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
                    print("✓ GLiNER fallback model loaded successfully")
                except Exception as e2:
                    print(f"Error loading GLiNER model: {str(e2)}")
                    return None
        return self.gliner_model

    def assign_colours(self, standard_entities, custom_entities):
        """Build and store the label -> colour map for the selected types.

        Keys are upper-cased entity names. Standard types take their colour
        from ``STANDARD_COLORS`` (grey if unknown); custom types take palette
        colours in order and random colours once the palette is exhausted.
        Returns the resulting dict (also kept in ``self.all_entity_colors``).
        """
        self.all_entity_colors = {}
        for entity in standard_entities:
            # STANDARD_COLORS is keyed by upper-cased names, which also covers
            # 'Work of Art' -> 'WORK OF ART'.
            key = entity.upper()
            self.all_entity_colors[key] = STANDARD_COLORS.get(key, '#CCCCCC')
        for i, entity in enumerate(custom_entities):
            if i < len(CUSTOM_COLOR_PALETTE):
                self.all_entity_colors[entity.upper()] = CUSTOM_COLOR_PALETTE[i]
            else:
                # Generate random colour if we run out
                self.all_entity_colors[entity.upper()] = f"#{random.randint(0, 0xFFFFFF):06x}"
        return self.all_entity_colors

    def extract_entities_by_model(self, text, entity_types, model_name, threshold=0.3):
        """Extract ``entity_types`` from ``text`` with the back-end named by ``model_name``.

        ``threshold`` is used by GLiNER (directly, or as the Flair fallback).
        Returns an empty list for an unrecognised model name.
        """
        if 'spacy' in model_name:
            return self.extract_spacy_entities(text, entity_types)
        if 'flair' in model_name:
            return self.extract_flair_entities(text, entity_types, model_name, threshold)
        if 'gliner' in model_name:
            return self.extract_gliner_entities(text, entity_types, threshold, is_custom=False)
        return []

    def extract_spacy_entities(self, text, entity_types):
        """Extract entities with spaCy, keeping only the requested types.

        spaCy's native tags are translated to the short names via
        ``_LABEL_ALIASES`` before filtering. Confidence is fixed at 1.0.
        """
        model = self.load_spacy_model()
        if model is None:
            return []
        try:
            doc = model(text)
            entities = []
            for ent in doc.ents:
                label = self._resolve_label(ent.label_, entity_types)
                if label is None:
                    continue
                entities.append({
                    'text': ent.text,
                    'label': label,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'confidence': 1.0,  # spaCy doesn't provide confidence scores
                    'source': 'spaCy'
                })
            return entities
        except Exception as e:
            print(f"Error with spaCy extraction: {str(e)}")
            return []

    def extract_flair_entities(self, text, entity_types, model_name, threshold=0.3):
        """Extract entities with Flair, keeping only the requested types.

        When the Flair tagger could not be loaded, ``load_flair_model`` hands
        back the GLiNER model; in that case extraction is delegated to
        ``extract_gliner_entities`` (using ``threshold``) instead of calling
        the Flair API on a GLiNER object.
        """
        model = self.load_flair_model(model_name)
        if model is None:
            return []
        if model_name not in self.flair_models:
            # Flair unavailable: ``model`` is the GLiNER fallback.
            return self.extract_gliner_entities(text, entity_types, threshold, is_custom=False)
        try:
            from flair.data import Sentence
            sentence = Sentence(text)
            model.predict(sentence)
            entities = []
            for entity in sentence.get_spans('ner'):
                # Map Flair labels to our common set
                label = self._resolve_label(entity.tag, entity_types)
                if label is None:
                    continue
                entities.append({
                    'text': entity.text,
                    'label': label,
                    'start': entity.start_position,
                    'end': entity.end_position,
                    'confidence': entity.score,
                    'source': f'Flair-{model_name.split("-")[-1]}'
                })
            return entities
        except Exception as e:
            print(f"Error with Flair extraction: {str(e)}")
            return []

    def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True):
        """Extract entities with GLiNER.

        Labels are upper-cased in the result. ``is_custom`` only controls the
        ``source`` tag ('GLiNER-Custom' vs 'GLiNER-Common').
        """
        model = self.load_gliner_model()
        if model is None:
            return []
        try:
            entities = model.predict_entities(text, entity_types, threshold=threshold)
            source = 'GLiNER-Custom' if is_custom else 'GLiNER-Common'
            return [
                {
                    'text': entity['text'],
                    'label': entity['label'].upper(),
                    'start': entity['start'],
                    'end': entity['end'],
                    'confidence': entity.get('score', 0.0),
                    'source': source,
                }
                for entity in entities
            ]
        except Exception as e:
            print(f"Error with GLiNER extraction: {str(e)}")
            return []
def find_overlapping_entities(entities):
    """Group overlapping entities and merge the common/custom ones.

    Entities are processed in order of ``start``. Each remaining entity
    becomes an anchor; every later entity whose start falls inside the
    anchor's span, or whose text equals the anchor's text case-insensitively,
    joins the anchor's group. A group is collapsed into one shared entity
    (via ``share_entities``) only when it mixes a common-model source
    ('spaCy', 'GLiNER-Common', 'Flair-*') with 'GLiNER-Custom'; otherwise its
    members are returned unchanged.
    """
    if not entities:
        return []

    common_sources = ['spaCy', 'GLiNER-Common']
    remaining = sorted(entities, key=lambda item: item['start'])
    resolved = []

    while remaining:
        anchor = remaining[0]
        cluster = [anchor]
        leftover = []
        for candidate in remaining[1:]:
            touches = (
                anchor['start'] <= candidate['start'] < anchor['end']
                or candidate['start'] <= anchor['start'] < anchor['end']
                or anchor['text'].lower() == candidate['text'].lower()
            )
            if touches:
                cluster.append(candidate)
            else:
                leftover.append(candidate)
        # Entities absorbed into this cluster are never revisited as anchors.
        remaining = leftover

        if len(cluster) == 1:
            resolved.append(anchor)
            continue

        origins = [member.get('source', '') for member in cluster]
        # Lists (not generators) so every source is inspected, as before.
        common_hits = [
            origin in common_sources or origin.startswith('Flair-')
            for origin in origins
        ]
        custom_hits = [origin == 'GLiNER-Custom' for origin in origins]

        if any(common_hits) and any(custom_hits):
            # Found by both a common model and the custom GLiNER pass.
            resolved.append(share_entities(cluster))
        else:
            # Same kind of source on every member: keep them separate.
            resolved.extend(cluster)

    return resolved
def share_entities(entity_list):
    """Collapse a group of overlapping entities into one shared record.

    A single-element list is returned as its only element, untouched.
    Otherwise the entity with the longest ``text`` (first one on ties)
    supplies ``text``/``start``/``end``, and the labels, sources and
    confidences of all members are collected in input order.
    """
    if len(entity_list) == 1:
        return entity_list[0]

    # The widest mention provides the span of the merged entity.
    widest = max(entity_list, key=lambda candidate: len(candidate['text']))

    all_labels = [member['label'] for member in entity_list]
    all_sources = [member['source'] for member in entity_list]
    all_confidences = [member['confidence'] for member in entity_list]

    return dict(
        text=widest['text'],
        start=widest['start'],
        end=widest['end'],
        labels=all_labels,
        sources=all_sources,
        confidences=all_confidences,
        is_shared=True,
        entity_count=len(entity_list),
    )
def create_highlighted_html(text, entities, entity_colors):
    """Render ``text`` with every detected entity replaced by its highlight.

    Overlapping entities are first merged with ``find_overlapping_entities``.
    The text is then walked left to right: plain text between entities is
    HTML-escaped, and each entity is rendered by ``create_shared_entity_html``
    or ``create_single_entity_html``. An entity that starts before the end of
    the previously rendered one is skipped, so overlapping spans that were not
    merged no longer cause parts of the text to be emitted twice.

    Args:
        text: the analysed text.
        entities: entity dicts with at least ``start`` and ``end`` offsets.
        entity_colors: upper-cased label -> colour, passed to the renderers.

    Returns:
        The highlighted block as a string.
    """
    from html import escape  # function-scope import keeps the block self-contained

    if not entities:
        # NOTE(review): this literal contains nothing but a newline, so the
        # input text is not shown when there are no entities - confirm that
        # this is intended.
        return "\n"

    merged = find_overlapping_entities(entities)
    ordered = sorted(merged, key=lambda item: item['start'])

    html_parts = []
    last_end = 0
    for entity in ordered:
        if entity['start'] < last_end:
            # Already covered by the previous highlight; rendering it too
            # would rewind the cursor and duplicate text.
            continue
        # Text before the entity
        html_parts.append(escape(text[last_end:entity['start']], quote=False))
        if entity.get('is_shared', False):
            html_parts.append(create_shared_entity_html(entity, entity_colors))
        else:
            html_parts.append(create_single_entity_html(entity, entity_colors))
        last_end = entity['end']
    # Remaining text after the last entity
    html_parts.append(escape(text[last_end:], quote=False))

    highlighted_text = ''.join(html_parts)
    return f"""
📝 Text with Highlighted Entities
{highlighted_text}
"""
def create_single_entity_html(entity, entity_colors):
    """Render one non-shared entity for the highlighted text.

    Looks up the entity's colour (by upper-cased label, grey by default), its
    confidence and its source, and returns the entity text.

    NOTE(review): colour, confidence and source are looked up but are not part
    of the returned string as written - confirm whether markup is missing.
    """
    tag = entity['label']
    fill = entity_colors.get(tag.upper(), '#CCCCCC')
    score = entity.get('confidence', 0.0)
    origin = entity.get('source', 'Unknown')
    return f'{entity["text"]}'
def create_shared_entity_html(entity, entity_colors):
    """Render a shared (multi-label) entity for the highlighted text.

    Builds one colour per label (grey by default), a left-to-right CSS
    gradient with one equal band per colour, and a tooltip of the form
    ``LABEL (source) - 0.00 | ...``. Returns the entity text followed by the
    handshake marker.

    NOTE(review): the gradient and tooltip are computed but are not part of
    the returned string as written - confirm whether markup is missing.
    """
    labels = entity['labels']
    sources = entity['sources']
    confidences = entity['confidences']

    # One colour per label, in label order.
    palette = [entity_colors.get(name.upper(), '#CCCCCC') for name in labels]

    if len(palette) == 2:
        gradient = f"linear-gradient(to right, {palette[0]} 50%, {palette[1]} 50%)"
    else:
        # Any other number of colours: split the width into equal bands.
        band = 100 / len(palette)
        stops = [
            f"{shade} {idx * band}%, {shade} {(idx + 1) * band}%"
            for idx, shade in enumerate(palette)
        ]
        gradient = f"linear-gradient(to right, {', '.join(stops)})"

    tooltip = " | ".join(
        [
            f"{name} ({sources[pos]}) - {confidences[pos]:.2f}"
            for pos, name in enumerate(labels)
        ]
    )

    return f'{entity["text"]} 🤝'
def create_entity_table_html(entities_of_type, entity_type, colour, is_shared=False):
    """Build the results table for one group of entities.

    For ``is_shared=True`` each row lists the entity text, all of its labels,
    all of its sources and the number of merged entities. Otherwise rows are
    ordered by descending confidence and list text, confidence (3 decimals),
    label and source.

    The input list is not modified: rows are ordered on a sorted copy rather
    than by sorting the caller's list in place. Text cells are HTML-escaped.

    Args:
        entities_of_type: entity dicts (shared or single, matching ``is_shared``).
        entity_type: name of the group; not used in the generated text.
        colour: colour of the group; not used in the generated text.
        is_shared: whether the entities are merged "shared" records.

    Returns:
        The table as a string.

    NOTE(review): ``entity_type`` and ``colour`` are accepted but never
    used, and the output contains no table markup as written - confirm
    whether markup is missing. A per-row confidence colour
    (>0.7 / >0.4 / else) that was computed but unused has been dropped.
    """
    from html import escape  # function-scope import keeps the block self-contained

    def cell(value):
        # Entity text and custom labels originate from user input.
        return escape(str(value), quote=False)

    if is_shared:
        parts = ["\nEntity Text |\nAll Labels |\nSources |\nCount |\n"]
        for entity in entities_of_type:
            labels_text = " | ".join(entity['labels'])
            sources_text = " | ".join(entity['sources'])
            parts.append(
                f"\n{cell(entity['text'])} |\n{cell(labels_text)} |\n"
                f"{cell(sources_text)} |\n{entity['entity_count']}\n|\n"
            )
    else:
        parts = ["\nEntity Text |\nConfidence |\nType |\nSource |\n"]
        # Highest confidence first; sorted() leaves the caller's list alone.
        ranked = sorted(
            entities_of_type,
            key=lambda item: item.get('confidence', 0),
            reverse=True,
        )
        for entity in ranked:
            confidence = entity.get('confidence', 0.0)
            source = entity.get('source', 'Unknown')
            parts.append(
                f"\n{cell(entity['text'])} |\n{confidence:.3f}\n|\n"
                f"{cell(entity['label'])} |\n{cell(source)} |\n"
            )

    parts.append("\n")
    return "".join(parts)
def create_all_entity_tables(entities, entity_colors):
    """Build the "detailed results" section: quick navigation plus all tables.

    Overlapping entities are merged with ``find_overlapping_entities`` and
    grouped: shared entities under 'SHARED_ENTITIES', all others under their
    label. The navigation and the tables both list the shared group first,
    followed by the remaining groups in alphabetical order of their label.

    Args:
        entities: entity dicts as produced by the extractors.
        entity_colors: upper-cased label -> colour, forwarded to the tables.

    Returns:
        The section as a string, or "No entities found.\\n" when there is
        nothing to show.

    NOTE(review): per-table icon / "Common NER" vs "Custom GLiNER" captions
    were computed but never emitted and have been dropped - confirm whether
    heading markup is missing.
    """
    from html import escape  # function-scope import keeps the block self-contained

    if not entities:
        return "No entities found.\n"

    # Group merged entities by type
    entity_groups = {}
    for entity in find_overlapping_entities(entities):
        key = 'SHARED_ENTITIES' if entity.get('is_shared', False) else entity['label']
        entity_groups.setdefault(key, []).append(entity)
    if not entity_groups:
        return "No entities found.\n"

    # Compute the display order once and use it for both navigation and tables.
    shared_group = entity_groups.get('SHARED_ENTITIES')
    typed_groups = [
        (entity_type, group)
        for entity_type, group in sorted(entity_groups.items())
        if entity_type != 'SHARED_ENTITIES'
    ]

    parts = ["\n", "\n", "\nQuick Navigation:"]
    if shared_group is not None:
        parts.append(f"\n🤝 Shared ({len(shared_group)})")
    for entity_type, group in typed_groups:
        icon = '🎯' if entity_type in STANDARD_ENTITIES else '✨'
        parts.append(f"\n{icon} {escape(str(entity_type), quote=False)} ({len(group)})")
    parts.append("\n")

    if shared_group is not None:
        shared_table = create_entity_table_html(
            shared_group, 'SHARED_ENTITIES', '#666666', is_shared=True
        )
        parts.append(f"\n{shared_table}\n")
    for entity_type, group in typed_groups:
        colour = entity_colors.get(entity_type.upper(), '#f0f0f0')
        parts.append(f"\n{create_entity_table_html(group, entity_type, colour)}\n")

    parts.append("\n")
    return "".join(parts)
def create_legend_html(entity_colors, standard_entities, custom_entities):
    """Build the legend listing the selected common and custom entity types.

    Args:
        entity_colors: upper-cased label -> colour; an empty mapping means
            there is nothing to show and "" is returned.
        standard_entities: selected common entity type names.
        custom_entities: user-defined entity type names.

    Returns:
        The legend as a string: a title, then one section per non-empty group
        with the (HTML-escaped) type names.

    NOTE(review): the per-type colour lookup that was computed but never
    emitted has been dropped; as written the legend carries no colours -
    confirm whether markup is missing.
    """
    from html import escape  # function-scope import keeps the block self-contained

    if not entity_colors:
        return ""

    sections = (
        ("\n🎯 Common Entities:\n", standard_entities),
        ("\n✨ Custom Entities:\n", custom_entities),
    )

    parts = ["\n🎨 Entity Type Legend\n"]
    for heading, names in sections:
        if not names:
            continue
        parts.extend(("\n", heading, "\n"))
        # Custom type names come straight from user input.
        parts.extend(escape(str(name), quote=False) for name in names)
        parts.append("\n")
    parts.append("\n")
    return "".join(parts)
# Module-level manager shared by process_text() and create_interface().
# Constructing it only sets up empty caches; models are loaded on first use.
ner_manager = HybridNERManager()
def process_text(text, standard_entities, custom_entities_str, confidence_threshold, selected_model, progress=gr.Progress()):
    """Run the full analysis for one click of the "Analyse Text" button.

    Validates the inputs, extracts the selected common entity types with the
    chosen model and the custom types with GLiNER, then renders the outputs.

    Returns a 3-tuple ``(summary_markdown, legend_and_highlight_html,
    results_html)``. On a validation failure, or when nothing is found, the
    first item is an error message and the other two are empty strings.
    """
    if not text.strip():
        return "❌ Please enter some text to analyse", "", ""

    progress(0.1, desc="Initialising...")

    # Comma-separated custom types -> clean list (blank pieces dropped).
    custom_entities = [
        piece.strip()
        for piece in custom_entities_str.split(',')
        if piece.strip()
    ]
    # Checkbox selection with empty values removed.
    selected_standard = list(filter(None, standard_entities))

    if not (selected_standard or custom_entities):
        return "❌ Please select at least one common entity type OR enter custom entity types", "", ""

    progress(0.2, desc="Loading models...")
    all_entities = []

    # Common entity types: handled by the model picked in the dropdown.
    if selected_standard and selected_model:
        progress(0.4, desc="Extracting common entities...")
        all_entities.extend(
            ner_manager.extract_entities_by_model(
                text, selected_standard, selected_model, confidence_threshold
            )
        )

    # Custom entity types: always handled by GLiNER.
    if custom_entities:
        progress(0.6, desc="Extracting custom entities...")
        all_entities.extend(
            ner_manager.extract_gliner_entities(
                text, custom_entities, confidence_threshold, is_custom=True
            )
        )

    if not all_entities:
        return "❌ No entities found. Try lowering the confidence threshold or using different entity types.", "", ""

    progress(0.8, desc="Processing results...")
    entity_colors = ner_manager.assign_colours(selected_standard, custom_entities)

    legend_html = create_legend_html(entity_colors, selected_standard, custom_entities)
    highlighted_html = create_highlighted_html(text, all_entities, entity_colors)
    results_html = create_all_entity_tables(all_entities, entity_colors)

    progress(0.9, desc="Creating summary...")
    # Raw count vs. count after overlapping entities have been merged.
    total_entities = len(all_entities)
    merged_entities = find_overlapping_entities(all_entities)
    final_count = len(merged_entities)
    shared_count = sum(1 for item in merged_entities if item.get('is_shared', False))
    average_confidence = sum(item.get('confidence', 0) for item in all_entities) / total_entities

    summary = f"""
## 📊 Analysis Summary
- **Total entities found:** {total_entities}
- **Final entities displayed:** {final_count}
- **Shared entities:** {shared_count}
- **Average confidence:** {average_confidence:.3f}
"""

    progress(1.0, desc="Complete!")
    return summary, legend_html + highlighted_html, results_html
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and wire it to ``process_text``.

    Layout, top to bottom: instructions and tip, text input with confidence
    slider, common-entity controls next to the custom-entity box, glossary,
    the analyse button, the three output areas, example inputs and model
    information.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Hybrid NER + GLiNER Tool", theme=gr.themes.Soft()) as demo:
        # Introductory instructions
        gr.Markdown("""
# Named Entity Recognition (NER) Explorer Tool
Combine common NER categories with your own custom entity types! This tool uses both traditional NER models and GLiNER for comprehensive entity extraction.
### How to use:
1. **📝 Enter your text** in the text area below
2. **🎯 Select a model** from the dropdown for common entities
3. **☑️ Select common entities** you want to find (PER, ORG, LOC, etc.)
4. **✨ Add custom entities** (comma-separated) like "relationships, occupations, skills" - powered by GLiNER
5. **⚙️ Adjust confidence threshold**
6. **🔍 Click "Analyse Text"** to see results with organized output
(Common/custom entities which overlap are shown with split-colour highlighting)
""")
        # Add tip box
        gr.HTML("""
💡 Top tip: All models can both miss entities and/or miss categorise entity types - so keep an eye out for this.
""")
        # Row 1: text to analyse (wide) and confidence threshold (narrow)
        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="📝 Text to Analyse",
                    placeholder="Enter your text here...",
                    lines=6,
                    max_lines=10
                )
            with gr.Column(scale=1):
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.1,
                    label="🎚️ Confidence Threshold"
                )
        # Row 2: common-entity controls (left) and custom-entity input (right)
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎯 Common Entity Types")
                # Model selector
                model_dropdown = gr.Dropdown(
                    choices=ner_manager.model_names,
                    value=ner_manager.model_names[0],
                    label="Select Model for Common Entities",
                    info="Choose which model to use for common NER"
                )
                # Common entities with select all functionality
                standard_entities = gr.CheckboxGroup(
                    choices=STANDARD_ENTITIES,
                    value=['PER', 'ORG', 'LOC', 'MISC'], # Default selection
                    label="Select Common Entities"
                )
                # Select/Deselect All button
                with gr.Row():
                    select_all_btn = gr.Button("🔘 Deselect All", size="sm")
                # Toggle handler: returns the new checkbox selection and the
                # new caption for the button.
                def toggle_all_entities(current_selection):
                    if len(current_selection) > 0:
                        # If any are selected, deselect all
                        return [], "☑️ Select All"
                    else:
                        # If none selected, select all
                        return STANDARD_ENTITIES, "🔘 Deselect All"
                select_all_btn.click(
                    fn=toggle_all_entities,
                    inputs=[standard_entities],
                    outputs=[standard_entities, select_all_btn]
                )
            with gr.Column():
                gr.Markdown("### ✨ Custom Entity Types (Powered by GLiNER)")
                custom_entities = gr.Textbox(
                    label="Custom Entities (comma-separated)",
                    placeholder="e.g. relationships, occupations, skills, emotions",
                    lines=3
                )
                gr.Markdown("""
**Examples:**
- relationships, occupations, skills
- emotions, actions, objects
- medical conditions, treatments
- financial terms, business roles
*GLiNER model will extract these custom entity types from your text*
""")
        # Glossary of the common entity types, placed below the selection
        # controls and above the analyse button
        gr.HTML("""
ℹ️ Entity Type Definitions (Click to expand)
- PER:
- People, including fictional characters
- ORG:
- Organizations - Companies, agencies, institutions, etc.
- LOC:
- Non-GPE locations - Mountain ranges, bodies of water
- GPE:
- Geopolitical entities - Countries, cities, states
- FAC:
- Facilities - Buildings, airports, highways, bridges, etc.
- DATE:
- Absolute or relative dates or periods
- EVENT:
- Named hurricanes, battles, wars, sports events, etc.
- NORP:
- Nationalities or religious or political groups
- LANG:
- Any named language
- MISC:
- Miscellaneous entities - Things that don't fit elsewhere
- PRODUCT:
- Objects, vehicles, foods, etc. (Not services)
- Work of Art:
- Titles of books, songs, movies, paintings, etc.
""")
        analyse_btn = gr.Button("🔍 Analyse Text", variant="primary", size="lg")
        # Output sections
        with gr.Row():
            summary_output = gr.Markdown(label="Summary")
        with gr.Row():
            highlighted_output = gr.HTML(label="Highlighted Text")
        # Results section
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 📋 Detailed Results")
                results_output = gr.HTML(label="Entity Results")
        # Connect the button to the processing function; the input order must
        # match the positional parameters of process_text, and the outputs
        # receive its 3-tuple in order.
        analyse_btn.click(
            fn=process_text,
            inputs=[
                text_input,
                standard_entities,
                custom_entities,
                confidence_threshold,
                model_dropdown
            ],
            outputs=[summary_output, highlighted_output, results_output]
        )
        # Example inputs; each row supplies one value per component listed in
        # `inputs` below, in the same order.
        gr.Examples(
            examples=[
                [
                    "John Smith works at Google in New York. He graduated from Stanford University in 2015 and specialises in artificial intelligence research. His wife Sarah is a doctor at Mount Sinai Hospital.",
                    ["PER", "ORG", "LOC", "DATE"],
                    "relationships, occupations, educational background",
                    0.3,
                    "entities_spacy_en_core_web_trf"
                ],
                [
                    "Dr. Emily Watson published a research paper on machine learning algorithms at MIT. She collaborates with her colleague Prof. David Chen on natural language processing projects.",
                    ["PER", "ORG", "Work of Art"],
                    "academic titles, research topics, collaborations",
                    0.3,
                    "entities_gliner_knowledgator/modern-gliner-bi-large-v1.0"
                ]
            ],
            inputs=[
                text_input,
                standard_entities,
                custom_entities,
                confidence_threshold,
                model_dropdown
            ]
        )
        # Add model information links
        gr.HTML("""
📚 Model Information & Documentation
Learn more about the models used in this tool:
""")
    return demo
# Script entry point: build the UI and start the Gradio server with its
# default launch settings. Nothing is launched when the module is imported.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()