"""Gradio app combining standard NER models (spaCy, Flair) with GLiNER custom entities.

NOTE(review): the inline HTML/CSS emitted by the rendering helpers below was
rebuilt, because the markup was missing from the copy of this file that was
reviewed. Visible text (headings, column names, messages) is unchanged; exact
styling should be confirmed against the deployed app.
"""

import html
import random
import re
import warnings

import gradio as gr
import pandas as pd
import torch
from gliner import GLiNER

warnings.filterwarnings('ignore')

# Standard NER entity types
STANDARD_ENTITIES = [
    'DATE', 'EVENT', 'FAC', 'GPE', 'LANG', 'LOC',
    'MISC', 'NORP', 'ORG', 'PER', 'PRODUCT', 'WORK_OF_ART',
]

# Color schemes
STANDARD_COLORS = {
    'DATE': '#FF6B6B',         # Red
    'EVENT': '#4ECDC4',        # Teal
    'FAC': '#45B7D1',          # Blue
    'GPE': '#F9CA24',          # Yellow
    'LANG': '#6C5CE7',         # Purple
    'LOC': '#A0E7E5',          # Light Cyan
    'MISC': '#FD79A8',         # Pink
    'NORP': '#8E8E93',         # Grey
    'ORG': '#55A3FF',          # Light Blue
    'PER': '#00B894',          # Green
    'PRODUCT': '#E17055',      # Orange-Red
    'WORK_OF_ART': '#DDA0DD',  # Plum
}

# Additional colors for custom entities
CUSTOM_COLOR_PALETTE = [
    '#FF9F43', '#10AC84', '#EE5A24', '#0FBC89', '#5F27CD',
    '#FF3838', '#2F3640', '#3742FA', '#2ED573', '#FFA502',
    '#FF6348', '#1E90FF', '#FF1493', '#32CD32', '#FFD700',
    '#FF4500', '#DA70D6', '#00CED1', '#FF69B4', '#7B68EE',
]

# Long-form labels emitted by some models, mapped onto the names used in
# STANDARD_ENTITIES. Shared by the spaCy and Flair extractors so that a UI
# selection such as 'PER' matches a model that reports 'PERSON'.
LABEL_ALIASES = {
    'PERSON': 'PER',
    'ORGANIZATION': 'ORG',
    'LOCATION': 'LOC',
    'MISCELLANEOUS': 'MISC',
    'LANGUAGE': 'LANG',
}

# Inline styles shared by the HTML helpers.
_ENTITY_SPAN_STYLE = (
    "padding: 2px 4px; margin: 0 1px; border-radius: 3px; "
    "font-weight: bold; color: #333;"
)
_CELL_STYLE = "padding: 10px; border: 1px solid #ddd;"
_HEADER_CELL_STYLE = "padding: 12px; text-align: left; border: 1px solid #ddd;"
_NO_ENTITIES_HTML = (
    "<div style='text-align: center; padding: 20px;'>"
    "<p>No entities found.</p>"
    "</div>"
)

INTRO_MARKDOWN = """
# 🎯 Hybrid NER + Custom GLiNER Entity Recognition Tool

Combine standard NER categories with your own custom entity types! This tool uses both traditional NER models and GLiNER for comprehensive entity extraction.

## 🔗 NEW: Overlapping entities are automatically shared with split-color highlighting!

### How to use:
1. **📝 Enter your text** in the text area below
2. **🎯 Select a model** from the dropdown for standard entities
3. **☑️ Select standard entities** you want to find (PER, ORG, LOC, etc.)
4. **✨ Add custom entities** (comma-separated) like "relationships, occupations, skills"
5. **⚙️ Adjust confidence threshold**
6. **🔍 Click "Analyze Text"** to see results with tabbed output
"""

CUSTOM_EXAMPLES_MARKDOWN = """
**Examples:**
- relationships, occupations, skills
- emotions, actions, objects
- medical conditions, treatments
- financial terms, business roles
"""


class HybridNERManager:
    """Load NER back ends on first use, cache them, and run extraction."""

    def __init__(self):
        self.gliner_model = None
        self.spacy_model = None
        self.flair_models = {}       # model_name -> loaded Flair tagger
        self.all_entity_colors = {}  # upper-cased label -> hex color
        self.model_names = [
            'spacy_en_core_web_sm',
            'flair_ner-ontonotes-large',
            'flair_ner-large',
            'gliner_medium-v2.1',
        ]

    def load_gliner_model(self):
        """Return the cached GLiNER model, loading it first if needed.

        Returns None (after printing the error) when loading fails.
        """
        if self.gliner_model is None:
            try:
                # Use a more stable model for HF Spaces
                self.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
                print("✓ GLiNER model loaded successfully")
            except Exception as e:
                print(f"Error loading GLiNER model: {str(e)}")
                return None
        return self.gliner_model

    def load_model(self, model_name):
        """Load the model identified by ``model_name``.

        Returns the model object, or None when the name is not recognised
        or loading raises.
        """
        try:
            if model_name == 'spacy_en_core_web_sm':
                return self.load_spacy_model()
            if 'flair' in model_name:
                return self.load_flair_model(model_name)
            if 'gliner' in model_name:
                return self.load_gliner_model()
        except Exception as e:
            print(f"Error loading {model_name}: {str(e)}")
        return None

    def load_spacy_model(self):
        """Return the cached spaCy pipeline, loading ``en_core_web_sm`` if needed.

        Returns None when spaCy or the model package is unavailable.
        """
        if self.spacy_model is None:
            try:
                import spacy
                try:
                    self.spacy_model = spacy.load("en_core_web_sm")
                    print("✓ spaCy model loaded successfully")
                except OSError:
                    # No fallback happens here: callers receive None and
                    # extract_spacy_entities() then returns an empty list.
                    print("spaCy model not found. spaCy extraction will return no entities.")
                    return None
            except Exception as e:
                print(f"Error loading spaCy model: {str(e)}")
                return None
        return self.spacy_model

    def load_flair_model(self, model_name):
        """Return the cached Flair tagger for ``model_name``, loading it if needed.

        Names containing 'ontonotes' load "flair/ner-english-ontonotes-large";
        every other name loads "flair/ner-english-large". Returns None on failure.
        """
        if model_name not in self.flair_models:
            try:
                from flair.models import SequenceTagger
                if 'ontonotes' in model_name:
                    model = SequenceTagger.load("flair/ner-english-ontonotes-large")
                else:
                    model = SequenceTagger.load("flair/ner-english-large")
                self.flair_models[model_name] = model
                print(f"✓ {model_name} loaded successfully")
            except Exception as e:
                print(f"Error loading {model_name}: {str(e)}")
                return None
        return self.flair_models[model_name]

    def extract_spacy_entities(self, text, entity_types):
        """Extract entities with spaCy, keeping only labels in ``entity_types``.

        Model labels are normalised through LABEL_ALIASES before filtering, so
        a model-side 'PERSON' is reported (and matched) as 'PER'.
        Returns a list of entity dicts; empty when the model is unavailable
        or extraction raises.
        """
        model = self.load_spacy_model()
        if model is None:
            return []

        try:
            doc = model(text)
            entities = []
            for ent in doc.ents:
                label = LABEL_ALIASES.get(ent.label_, ent.label_)
                if label in entity_types:
                    entities.append({
                        'text': ent.text,
                        'label': label,
                        'start': ent.start_char,
                        'end': ent.end_char,
                        'confidence': 1.0,  # spaCy doesn't provide confidence scores
                        'source': 'spaCy',
                    })
            return entities
        except Exception as e:
            print(f"Error with spaCy extraction: {str(e)}")
            return []

    def assign_colors(self, standard_entities, custom_entities):
        """Rebuild and return the label -> color map for the selected types.

        Keys are upper-cased. Standard types use STANDARD_COLORS (grey when
        unknown); custom types take CUSTOM_COLOR_PALETTE entries in order and
        fall back to a random color once the palette is exhausted.
        """
        self.all_entity_colors = {}

        # Assign standard colors
        for entity in standard_entities:
            self.all_entity_colors[entity.upper()] = STANDARD_COLORS.get(entity, '#CCCCCC')

        # Assign custom colors
        for i, entity in enumerate(custom_entities):
            if i < len(CUSTOM_COLOR_PALETTE):
                self.all_entity_colors[entity.upper()] = CUSTOM_COLOR_PALETTE[i]
            else:
                # Generate random color if we run out
                self.all_entity_colors[entity.upper()] = f"#{random.randint(0, 0xFFFFFF):06x}"

        return self.all_entity_colors

    def extract_entities_by_model(self, text, entity_types, model_name, threshold=0.3):
        """Dispatch extraction to the back end named by ``model_name``.

        ``threshold`` is forwarded only to GLiNER; the spaCy and Flair paths
        do not filter by confidence. Unknown names yield an empty list.
        """
        if model_name == 'spacy_en_core_web_sm':
            return self.extract_spacy_entities(text, entity_types)
        if 'flair' in model_name:
            return self.extract_flair_entities(text, entity_types, model_name)
        if 'gliner' in model_name:
            return self.extract_gliner_entities(text, entity_types, threshold, is_custom=False)
        return []

    def extract_flair_entities(self, text, entity_types, model_name):
        """Extract entities with a Flair tagger, keeping labels in ``entity_types``.

        Labels are normalised through LABEL_ALIASES before filtering.
        Returns an empty list when the model is unavailable or prediction raises.
        """
        model = self.load_flair_model(model_name)
        if model is None:
            return []

        try:
            from flair.data import Sentence
            sentence = Sentence(text)
            model.predict(sentence)
            source = f'Flair-{model_name.split("-")[-1]}'
            entities = []
            for span in sentence.get_spans('ner'):
                # Map Flair labels to our standard set
                label = LABEL_ALIASES.get(span.tag, span.tag)
                if label in entity_types:
                    entities.append({
                        'text': span.text,
                        'label': label,
                        'start': span.start_position,
                        'end': span.end_position,
                        'confidence': span.score,
                        'source': source,
                    })
            return entities
        except Exception as e:
            print(f"Error with Flair extraction: {str(e)}")
            return []

    def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True):
        """Extract entities with GLiNER for the given label names.

        Returned labels are upper-cased; ``is_custom`` only selects the
        'source' tag ('GLiNER-Custom' vs 'GLiNER-Standard').
        Returns an empty list when the model is unavailable or prediction raises.
        """
        model = self.load_gliner_model()
        if model is None:
            return []

        try:
            source = 'GLiNER-Custom' if is_custom else 'GLiNER-Standard'
            predictions = model.predict_entities(text, entity_types, threshold=threshold)
            return [
                {
                    'text': item['text'],
                    'label': item['label'].upper(),
                    'start': item['start'],
                    'end': item['end'],
                    'confidence': item.get('score', 0.0),
                    'source': source,
                }
                for item in predictions
            ]
        except Exception as e:
            print(f"Error with GLiNER extraction: {str(e)}")
            return []


def find_overlapping_entities(entities, merge_same_text=True):
    """Group overlapping entities and collapse each group into one entry.

    Entities are visited in order of ``start``. A candidate joins the current
    group when its span overlaps the group's span; the group's end grows as
    members are added, so chains such as A∩B, B∩C end up in one group. With
    ``merge_same_text`` (the default) a candidate also joins when its text
    equals the group's first entity's text, ignoring case, wherever it occurs.

    Args:
        entities: list of entity dicts with 'text', 'start', 'end', 'label',
            'source' and 'confidence' keys.
        merge_same_text: also merge equal-text mentions at other offsets.
            Pass False when each mention must keep its own position
            (e.g. for highlighting).

    Returns:
        A list in which single-member groups are the original dicts and
        larger groups are replaced by the result of share_entities().
    """
    if not entities:
        return []

    ordered = sorted(entities, key=lambda item: item['start'])
    consumed = [False] * len(ordered)
    grouped = []

    for i, seed in enumerate(ordered):
        if consumed[i]:
            continue

        group = [seed]
        group_end = seed['end']
        seed_text = seed['text'].lower()

        for j in range(i + 1, len(ordered)):
            if consumed[j]:
                continue
            candidate = ordered[j]
            # Candidates start at or after the seed, so overlap reduces to
            # "starts before the group ends"; the second clause keeps the
            # equal-start case for a zero-length seed span.
            positional = (
                candidate['start'] < group_end
                or candidate['start'] == seed['start'] < candidate['end']
            )
            textual = merge_same_text and candidate['text'].lower() == seed_text
            if not (positional or textual):
                continue
            group.append(candidate)
            consumed[j] = True
            if positional:
                # Only real overlaps widen the span; a same-text mention
                # elsewhere must not swallow everything in between.
                group_end = max(group_end, candidate['end'])

        grouped.append(group[0] if len(group) == 1 else share_entities(group))

    return grouped


def share_entities(entity_list):
    """Combine several entities into one "shared" entity dict.

    The text and span come from the member with the longest text; labels,
    sources and confidences of all members are kept as parallel lists.
    A single-element list is returned as that element, unchanged.
    """
    if len(entity_list) == 1:
        return entity_list[0]

    # Use the entity with the longest text span as the base
    base_entity = max(entity_list, key=lambda item: len(item['text']))

    return {
        'text': base_entity['text'],
        'start': base_entity['start'],
        'end': base_entity['end'],
        'labels': [item['label'] for item in entity_list],
        'sources': [item['source'] for item in entity_list],
        'confidences': [item['confidence'] for item in entity_list],
        'is_shared': True,
        'entity_count': len(entity_list),
    }


def create_highlighted_html(text, entities, entity_colors):
    """Return HTML showing ``text`` with each entity wrapped in a colored span.

    Overlapping entities are merged by position only, so repeated mentions of
    the same text are each highlighted. All text is HTML-escaped.
    """
    if not entities:
        return (
            "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 5px;'>"
            f"<p>{html.escape(text)}</p>"
            "</div>"
        )

    # Find and share overlapping entities
    shared_entities = find_overlapping_entities(entities, merge_same_text=False)
    ordered = sorted(shared_entities, key=lambda item: item['start'])

    html_parts = []
    cursor = 0
    for entity in ordered:
        if entity['start'] < cursor:
            # Defensive: never re-emit characters that were already written.
            continue

        # Add text before entity
        html_parts.append(html.escape(text[cursor:entity['start']]))

        if entity.get('is_shared', False):
            html_parts.append(create_shared_entity_html(entity, entity_colors))
        else:
            html_parts.append(create_single_entity_html(entity, entity_colors))

        cursor = entity['end']

    # Add remaining text
    html_parts.append(html.escape(text[cursor:]))
    highlighted_text = ''.join(html_parts)

    return (
        "<div style='padding: 15px; border: 2px solid #ddd; border-radius: 8px; "
        "background-color: #fafafa; margin: 10px 0;'>"
        "<h4 style='margin: 0 0 15px 0; color: #333;'>📝 Text with Highlighted Entities</h4>"
        "<div style='line-height: 1.8; font-size: 16px; white-space: pre-wrap;'>"
        f"{highlighted_text}"
        "</div>"
        "</div>"
    )


def create_single_entity_html(entity, entity_colors):
    """Return a colored ``<span>`` for one entity, with label/source/confidence tooltip."""
    label = entity['label']
    color = entity_colors.get(label.upper(), '#CCCCCC')
    confidence = entity.get('confidence', 0.0)
    source = entity.get('source', 'Unknown')

    tooltip = html.escape(f"{label} ({source}) - confidence: {confidence:.2f}")
    body = html.escape(entity['text'])
    return (
        f'<span style="background-color: {color}; {_ENTITY_SPAN_STYLE}" '
        f'title="{tooltip}">{body}</span>'
    )


def create_shared_entity_html(entity, entity_colors):
    """Return a ``<span>`` for a shared entity, striped with one color per label."""
    labels = entity['labels']
    sources = entity['sources']
    confidences = entity['confidences']

    # Get colors for each label
    colors = [entity_colors.get(label.upper(), '#CCCCCC') for label in labels]

    # Hard-edged gradient: each color fills an equal share of the width.
    segment_size = 100 / len(colors)
    stops = [
        f"{color} {i * segment_size}%, {color} {(i + 1) * segment_size}%"
        for i, color in enumerate(colors)
    ]
    gradient = f"linear-gradient(to right, {', '.join(stops)})"

    # Create tooltip
    tooltip = html.escape(" | ".join(
        f"{label} ({source}) - {confidence:.2f}"
        for label, source, confidence in zip(labels, sources, confidences)
    ))
    body = html.escape(entity['text'])

    return (
        f'<span style="background: {gradient}; {_ENTITY_SPAN_STYLE}" '
        f'title="{tooltip}">{body} 🔗</span>'
    )


def _table_html(columns, rows, color):
    """Return a ``<table>`` with the given header names and pre-rendered row cells."""
    head = ''.join(
        f'<th style="{_HEADER_CELL_STYLE} background-color: #f8f9fa; '
        f'border-bottom: 3px solid {color};">{html.escape(name)}</th>'
        for name in columns
    )
    body = ''.join(
        '<tr>' + ''.join(f'<td style="{_CELL_STYLE}">{cell}</td>' for cell in cells) + '</tr>'
        for cells in rows
    )
    return (
        '<table style="width: 100%; border-collapse: collapse; border: 1px solid #ddd;">'
        f'<thead><tr>{head}</tr></thead><tbody>{body}</tbody></table>'
    )


def _shared_rows(entities_of_type):
    """Return table cells (text, labels, sources, count) for shared entities."""
    return [
        (
            f"<strong>{html.escape(entity['text'])}</strong>",
            html.escape(" | ".join(entity['labels'])),
            html.escape(" | ".join(entity['sources'])),
            str(entity['entity_count']),
        )
        for entity in entities_of_type
    ]


def _single_rows(entities_of_type):
    """Return table cells (text, confidence, type, source), highest confidence first."""
    rows = []
    ranked = sorted(entities_of_type, key=lambda item: item.get('confidence', 0), reverse=True)
    for entity in ranked:
        confidence = entity.get('confidence', 0.0)
        if confidence > 0.7:
            confidence_color = "#28a745"
        elif confidence > 0.4:
            confidence_color = "#ffc107"
        else:
            confidence_color = "#dc3545"
        source = html.escape(entity.get('source', 'Unknown'))
        source_badge = (
            "<span style='background-color: #e9ecef; padding: 2px 6px; "
            f"border-radius: 3px; font-size: 12px;'>{source}</span>"
        )
        rows.append((
            f"<strong>{html.escape(entity['text'])}</strong>",
            f"<span style='color: {confidence_color}; font-weight: bold;'>{confidence:.3f}</span>",
            html.escape(entity['label']),
            source_badge,
        ))
    return rows


def create_entity_table_html(entities, entity_colors):
    """Return HTML listing the entities in one collapsible section per type.

    Entities merged by find_overlapping_entities() go into a "SHARED" section;
    the rest are grouped by label. Sections are native ``<details>`` elements
    (the first one open), so no script is required to switch between them.
    """
    if not entities:
        return _NO_ENTITIES_HTML

    # Share overlapping entities, then group by type
    entity_groups = {}
    for entity in find_overlapping_entities(entities):
        key = 'SHARED_ENTITIES' if entity.get('is_shared', False) else entity['label']
        entity_groups.setdefault(key, []).append(entity)

    if not entity_groups:
        return _NO_ENTITIES_HTML

    sections = []
    for i, entity_type in enumerate(sorted(entity_groups)):
        entities_of_type = entity_groups[entity_type]
        count = len(entities_of_type)

        if entity_type == 'SHARED_ENTITIES':
            color = '#666666'
            tab_title = f"🔗 SHARED ({count})"
            header_text = f"🔗 Shared Entities ({count} found)"
            table = _table_html(
                ('Entity Text', 'All Labels', 'Sources', 'Count'),
                _shared_rows(entities_of_type),
                color,
            )
        else:
            color = entity_colors.get(entity_type.upper(), '#f0f0f0')
            # Determine if it's standard or custom
            is_standard = entity_type in STANDARD_ENTITIES
            icon = "🎯" if is_standard else "✨"
            source_icon = "🎯 Standard NER" if is_standard else "✨ Custom GLiNER"
            tab_title = f"{icon} {entity_type} ({count})"
            header_text = f"{source_icon} - {entity_type} Entities ({count} found)"
            table = _table_html(
                ('Entity Text', 'Confidence', 'Type', 'Source'),
                _single_rows(entities_of_type),
                color,
            )

        open_attr = " open" if i == 0 else ""
        sections.append(
            f"<details{open_attr} style='margin: 0 0 8px 0; border: 1px solid #ddd; "
            "border-radius: 5px;'>"
            "<summary style='cursor: pointer; padding: 10px; font-weight: bold; "
            f"background-color: #f8f9fa; border-bottom: 3px solid {color};'>"
            f"{html.escape(tab_title)}</summary>"
            "<div style='padding: 15px;'>"
            f"<h4 style='margin: 0 0 15px 0;'>{html.escape(header_text)}</h4>"
            f"{table}"
            "</div>"
            "</details>"
        )

    return "<div style='margin: 20px 0;'>" + ''.join(sections) + "</div>"


def _legend_chips(entity_types, entity_colors):
    """Return one colored chip ``<span>`` per entity type, concatenated."""
    return ''.join(
        f"<span style='background-color: {entity_colors.get(entity_type.upper(), '#ccc')}; "
        "padding: 2px 8px; margin: 2px; border-radius: 3px; display: inline-block; "
        f"font-size: 12px;'>{html.escape(entity_type)}</span>"
        for entity_type in entity_types
    )


def create_legend_html(entity_colors, standard_entities, custom_entities):
    """Return an HTML legend mapping each selected entity type to its color.

    Returns an empty string when ``entity_colors`` is empty.
    """
    if not entity_colors:
        return ""

    parts = [
        "<div style='margin: 15px 0; padding: 15px; background-color: #f8f9fa; "
        "border-radius: 8px;'>",
        "<h4 style='margin: 0 0 15px 0;'>🎨 Entity Type Legend</h4>",
    ]

    if standard_entities:
        parts.append("<div style='margin-bottom: 15px;'>")
        parts.append("<h5 style='margin: 0 0 8px 0;'>🎯 Standard Entities:</h5>")
        parts.append(f"<div>{_legend_chips(standard_entities, entity_colors)}</div>")
        parts.append("</div>")

    if custom_entities:
        parts.append("<div>")
        parts.append("<h5 style='margin: 0 0 8px 0;'>✨ Custom Entities:</h5>")
        parts.append(f"<div>{_legend_chips(custom_entities, entity_colors)}</div>")
        parts.append("</div>")

    parts.append("</div>")
    return ''.join(parts)


# Initialize the NER manager
ner_manager = HybridNERManager()


def process_text(text, standard_entities, custom_entities_str, confidence_threshold, selected_model):
    """Run extraction for the Gradio interface.

    Args:
        text: the text to analyze.
        standard_entities: selected standard entity type names.
        custom_entities_str: comma-separated custom entity type names.
        confidence_threshold: threshold forwarded to the extractors.
        selected_model: model name used for the standard entity types.

    Returns:
        A ``(summary_markdown, legend_and_highlight_html, table_html)`` tuple.
        On invalid input or when nothing is found, the first element is an
        error message and the other two are empty strings.
    """
    if not text or not text.strip():
        return "❌ Please enter some text to analyze", "", ""

    # Parse custom entities; dict.fromkeys drops repeats but keeps the order.
    custom_entities = list(dict.fromkeys(
        name.strip() for name in (custom_entities_str or "").split(',') if name.strip()
    ))

    # Parse standard entities
    selected_standard = [entity for entity in (standard_entities or []) if entity]

    if not selected_standard and not custom_entities:
        return "❌ Please select at least one standard entity type OR enter custom entity types", "", ""

    all_entities = []

    # Extract standard entities using selected model
    if selected_standard and selected_model:
        all_entities.extend(ner_manager.extract_entities_by_model(
            text, selected_standard, selected_model, confidence_threshold
        ))

    # Extract custom entities using GLiNER
    if custom_entities:
        all_entities.extend(ner_manager.extract_gliner_entities(
            text, custom_entities, confidence_threshold, is_custom=True
        ))

    if not all_entities:
        return (
            "❌ No entities found. Try lowering the confidence threshold or using different entity types.",
            "",
            "",
        )

    # Assign colors
    entity_colors = ner_manager.assign_colors(selected_standard, custom_entities)

    # Create outputs
    legend_html = create_legend_html(entity_colors, selected_standard, custom_entities)
    highlighted_html = create_highlighted_html(text, all_entities, entity_colors)
    table_html = create_entity_table_html(all_entities, entity_colors)

    # Create summary with shared entities terminology
    total_entities = len(all_entities)
    shared_entities = find_overlapping_entities(all_entities)
    final_count = len(shared_entities)
    shared_count = sum(1 for item in shared_entities if item.get('is_shared', False))
    # total_entities is non-zero here: the empty case returned above.
    average_confidence = sum(item.get('confidence', 0) for item in all_entities) / total_entities

    summary = "\n".join([
        "## 📊 Analysis Summary",
        f"- **Total entities found:** {total_entities}",
        f"- **Final entities displayed:** {final_count}",
        f"- **Shared entities:** {shared_count}",
        f"- **Average confidence:** {average_confidence:.3f}",
    ])

    return summary, legend_html + highlighted_html, table_html


# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks app (inputs, outputs and examples)."""
    with gr.Blocks(title="Hybrid NER + GLiNER Tool", theme=gr.themes.Soft()) as demo:
        gr.Markdown(INTRO_MARKDOWN)

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="📝 Text to Analyze",
                    placeholder="Enter your text here...",
                    lines=6,
                    max_lines=10,
                )
            with gr.Column(scale=1):
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.1,
                    label="🎚️ Confidence Threshold",
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎯 Standard Entity Types")

                # Model selector
                model_dropdown = gr.Dropdown(
                    choices=ner_manager.model_names,
                    value=ner_manager.model_names[0],
                    label="Select Model for Standard Entities",
                    info="Choose which model to use for standard NER",
                )

                # Standard entities with select all functionality
                standard_entities = gr.CheckboxGroup(
                    choices=STANDARD_ENTITIES,
                    value=['PER', 'ORG', 'LOC', 'MISC'],  # Default selection
                    label="Select Standard Entities",
                )

                # Select/Deselect All button
                with gr.Row():
                    select_all_btn = gr.Button("🔘 Deselect All", size="sm")

                def toggle_all_entities(current_selection):
                    """Clear the selection if anything is ticked, else tick everything.

                    Returns the new selection and the new button caption.
                    """
                    if current_selection:
                        return [], "☑️ Select All"
                    return STANDARD_ENTITIES, "🔘 Deselect All"

                select_all_btn.click(
                    fn=toggle_all_entities,
                    inputs=[standard_entities],
                    outputs=[standard_entities, select_all_btn],
                )

            with gr.Column():
                gr.Markdown("### ✨ Custom Entity Types")
                custom_entities = gr.Textbox(
                    label="Custom Entities (comma-separated)",
                    placeholder="e.g. relationships, occupations, skills, emotions",
                    lines=3,
                )
                gr.Markdown(CUSTOM_EXAMPLES_MARKDOWN)

        analyze_btn = gr.Button("🔍 Analyze Text", variant="primary", size="lg")

        # Output sections
        with gr.Row():
            summary_output = gr.Markdown(label="Summary")

        with gr.Row():
            highlighted_output = gr.HTML(label="Highlighted Text")

        with gr.Row():
            table_output = gr.HTML(label="Detailed Results (Tabbed)")

        # Same component order as process_text's parameters.
        input_components = [
            text_input,
            standard_entities,
            custom_entities,
            confidence_threshold,
            model_dropdown,
        ]

        # Connect the button to the processing function
        analyze_btn.click(
            fn=process_text,
            inputs=input_components,
            outputs=[summary_output, highlighted_output, table_output],
        )

        # Add examples
        gr.Examples(
            examples=[
                [
                    "John Smith works at Google in New York. He graduated from Stanford University in 2015 and specializes in artificial intelligence research. His wife Sarah is a doctor at Mount Sinai Hospital.",
                    ["PER", "ORG", "LOC", "DATE"],
                    "relationships, occupations, educational background",
                    0.3,
                    "spacy_en_core_web_sm",
                ],
                [
                    "The meeting between CEO Jane Doe and the board of directors at Microsoft headquarters in Seattle discussed the Q4 financial results and the new AI strategy for 2024.",
                    ["PER", "ORG", "LOC", "DATE"],
                    "corporate roles, business events, financial terms",
                    0.4,
                    "flair_ner-ontonotes-large",
                ],
                [
                    "Dr. Emily Watson published a research paper on machine learning algorithms at MIT. She collaborates with her colleague Prof. David Chen on natural language processing projects.",
                    ["PER", "ORG", "WORK_OF_ART"],
                    "academic titles, research topics, collaborations",
                    0.3,
                    "gliner_medium-v2.1",
                ],
            ],
            inputs=input_components,
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()