"""Gradio app that combines spaCy NER with GLiNER zero-shot entity extraction.

Standard entity types can be extracted with spaCy and/or GLiNER, custom
(free-text) entity types are extracted with GLiNER, and overlapping hits are
merged and rendered as colour-coded HTML.
"""
import html
import random
import re
import warnings

import pandas as pd
import torch

# The UI and model packages are only needed when the app is actually run.
# Guarding them keeps the pure helpers below importable (scripts, tests) and
# mirrors the lazy, failure-tolerant way spaCy is loaded further down.
try:
    import gradio as gr
except ImportError:
    gr = None

try:
    from gliner import GLiNER
except ImportError:
    GLiNER = None

warnings.filterwarnings('ignore')

# Standard NER entity types
STANDARD_ENTITIES = [
    'DATE', 'EVENT', 'FAC', 'GPE', 'LANG', 'LOC',
    'MISC', 'NORP', 'ORG', 'PER', 'PRODUCT', 'WORK_OF_ART',
]

# spaCy's English pipelines use long label names for two of the types above;
# translate them so that selecting 'PER' / 'LANG' in the UI matches spaCy hits.
SPACY_LABEL_ALIASES = {
    'PERSON': 'PER',
    'LANGUAGE': 'LANG',
}

# Color schemes
STANDARD_COLORS = {
    'DATE': '#FF6B6B',         # Red
    'EVENT': '#4ECDC4',        # Teal
    'FAC': '#45B7D1',          # Blue
    'GPE': '#F9CA24',          # Yellow
    'LANG': '#6C5CE7',         # Purple
    'LOC': '#A0E7E5',          # Light Cyan
    'MISC': '#FD79A8',         # Pink
    'NORP': '#8E8E93',         # Grey
    'ORG': '#55A3FF',          # Light Blue
    'PER': '#00B894',          # Green
    'PRODUCT': '#E17055',      # Orange-Red
    'WORK_OF_ART': '#DDA0DD',  # Plum
}

# Additional colors for custom entities
CUSTOM_COLOR_PALETTE = [
    '#FF9F43', '#10AC84', '#EE5A24', '#0FBC89', '#5F27CD',
    '#FF3838', '#2F3640', '#3742FA', '#2ED573', '#FFA502',
    '#FF6348', '#1E90FF', '#FF1493', '#32CD32', '#FFD700',
    '#FF4500', '#DA70D6', '#00CED1', '#FF69B4', '#7B68EE',
]

# Fallback colour for labels that have no assigned colour.
DEFAULT_COLOR = '#CCCCCC'


class HybridNERManager:
    """Lazily loads the spaCy / GLiNER models and runs extraction with them."""

    def __init__(self):
        self.gliner_model = None
        self.spacy_model = None
        self.all_entity_colors = {}

    def load_gliner_model(self):
        """Load GLiNER model for custom entities.

        Returns the cached model, or None when the package is missing or the
        model cannot be loaded (the reason is printed).
        """
        if self.gliner_model is None:
            if GLiNER is None:
                print("Error loading GLiNER model: the 'gliner' package is not installed")
                return None
            try:
                # Use a more stable model for HF Spaces
                self.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
                print("✓ GLiNER model loaded successfully")
            except Exception as e:
                print(f"Error loading GLiNER model: {str(e)}")
                return None
        return self.gliner_model

    def load_spacy_model(self):
        """Load spaCy model for standard NER.

        Returns the cached pipeline, or None when spaCy or the
        ``en_core_web_sm`` model is unavailable (the reason is printed).
        """
        if self.spacy_model is None:
            try:
                import spacy
                # Only the small English pipeline is attempted.
                try:
                    self.spacy_model = spacy.load("en_core_web_sm")
                    print("✓ spaCy model loaded successfully")
                except OSError:
                    print("spaCy model not found. Using GLiNER for all entity types.")
                    return None
            except Exception as e:
                print(f"Error loading spaCy model: {str(e)}")
                return None
        return self.spacy_model

    def assign_colors(self, standard_entities, custom_entities):
        """Assign colors to all entity types.

        Returns (and stores on ``self.all_entity_colors``) a dict mapping the
        upper-cased entity name to a hex colour string.
        """
        self.all_entity_colors = {}

        # Assign standard colors
        for entity in standard_entities:
            key = entity.upper()
            self.all_entity_colors[key] = STANDARD_COLORS.get(key, DEFAULT_COLOR)

        # Assign custom colors
        for i, entity in enumerate(custom_entities):
            if i < len(CUSTOM_COLOR_PALETTE):
                self.all_entity_colors[entity.upper()] = CUSTOM_COLOR_PALETTE[i]
            else:
                # Generate random color if we run out
                self.all_entity_colors[entity.upper()] = f"#{random.randint(0, 0xFFFFFF):06x}"

        return self.all_entity_colors

    def extract_spacy_entities(self, text, entity_types):
        """Extract entities using spaCy.

        ``entity_types`` may use either the app's short names ('PER') or
        spaCy's own names ('PERSON'); the returned ``label`` is whichever
        form the caller asked for. Returns a list of entity dicts, or an
        empty list when spaCy is unavailable or fails.
        """
        model = self.load_spacy_model()
        if model is None:
            return []

        wanted = {label.upper() for label in entity_types}
        try:
            doc = model(text)
            entities = []
            for ent in doc.ents:
                raw_label = ent.label_
                # Prefer the app's short name, fall back to spaCy's raw name.
                candidates = (SPACY_LABEL_ALIASES.get(raw_label, raw_label), raw_label)
                label = next((name for name in candidates if name in wanted), None)
                if label is None:
                    continue
                entities.append({
                    'text': ent.text,
                    'label': label,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'confidence': 1.0,  # spaCy doesn't provide confidence scores
                    'source': 'spaCy',
                })
            return entities
        except Exception as e:
            print(f"Error with spaCy extraction: {str(e)}")
            return []

    def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True):
        """Extract entities using GLiNER.

        Returns a list of entity dicts with upper-cased labels, or an empty
        list when the model is unavailable or prediction fails.
        """
        model = self.load_gliner_model()
        if model is None:
            return []

        source = 'GLiNER-Custom' if is_custom else 'GLiNER-Standard'
        try:
            predictions = model.predict_entities(text, entity_types, threshold=threshold)
            return [
                {
                    'text': entity['text'],
                    'label': entity['label'].upper(),
                    'start': entity['start'],
                    'end': entity['end'],
                    'confidence': entity.get('score', 0.0),
                    'source': source,
                }
                for entity in predictions
            ]
        except Exception as e:
            print(f"Error with GLiNER extraction: {str(e)}")
            return []


def find_overlapping_entities(entities):
    """Find and merge overlapping entities.

    Entities are grouped by character-span overlap only: two entities belong
    to the same group when one starts before the running end of the group.
    Equal text at different positions is NOT treated as an overlap, so every
    occurrence keeps its own highlight. The result is ordered by start
    position and its spans never overlap each other.
    """
    if not entities:
        return []

    # Stable sort keeps input order (e.g. spaCy before GLiNER) for equal starts.
    ordered = sorted(entities, key=lambda item: item['start'])

    merged_entities = []
    cluster = [ordered[0]]
    cluster_end = ordered[0]['end']
    for entity in ordered[1:]:
        if entity['start'] < cluster_end:
            cluster.append(entity)
            # Track the furthest end so chains A-B-C are merged transitively.
            cluster_end = max(cluster_end, entity['end'])
        else:
            merged_entities.append(merge_entities(cluster))
            cluster = [entity]
            cluster_end = entity['end']
    merged_entities.append(merge_entities(cluster))
    return merged_entities


def merge_entities(entity_list):
    """Merge multiple overlapping entities into one.

    A single-element list is returned unchanged. Otherwise the result uses
    the longest entity text as ``text`` and the union of all spans as
    ``start`` / ``end``, and carries parallel ``labels``, ``sources`` and
    ``confidences`` lists plus ``is_merged`` and ``entity_count``.
    """
    if len(entity_list) == 1:
        return entity_list[0]

    # Use the entity with the longest text span as the base
    base_entity = max(entity_list, key=lambda item: len(item['text']))

    return {
        'text': base_entity['text'],
        # Union span: guarantees the merged entity covers every member, so the
        # renderer never re-emits characters that a member already covered.
        'start': min(entity['start'] for entity in entity_list),
        'end': max(entity['end'] for entity in entity_list),
        'labels': [entity['label'] for entity in entity_list],
        'sources': [entity['source'] for entity in entity_list],
        'confidences': [entity['confidence'] for entity in entity_list],
        'is_merged': True,
        'entity_count': len(entity_list),
    }


def create_highlighted_html(text, entities, entity_colors):
    """Create HTML with highlighted entities.

    All document text is HTML-escaped; only the wrapper and highlight markup
    generated here is emitted as raw HTML.
    """
    box_style = (
        "padding: 15px; border: 1px solid #ddd; border-radius: 5px; "
        "background-color: #fafafa; line-height: 1.8; white-space: pre-wrap;"
    )
    if not entities:
        return f"<div style='{box_style}'><p>{html.escape(text)}</p></div>"

    # Merged entities come back sorted by start with disjoint spans.
    merged_entities = find_overlapping_entities(entities)

    html_parts = []
    last_end = 0
    for entity in merged_entities:
        start = max(entity['start'], last_end)
        end = max(entity['end'], start)

        # Add text before entity
        html_parts.append(html.escape(text[last_end:start]))

        # Show what the document actually contains in this span; fall back to
        # the entity's own text if the offsets do not select anything.
        shown = dict(entity, text=text[start:end] or entity['text'])
        if entity.get('is_merged', False):
            # Handle merged entity with multiple colors
            html_parts.append(create_merged_entity_html(shown, entity_colors))
        else:
            # Handle single entity
            html_parts.append(create_single_entity_html(shown, entity_colors))
        last_end = end

    # Add remaining text
    html_parts.append(html.escape(text[last_end:]))
    highlighted_text = ''.join(html_parts)

    return (
        "<div style='margin: 15px 0;'>"
        "<h4 style='margin: 0 0 15px 0;'>📝 Text with Highlighted Entities</h4>"
        f"<div style='{box_style}'>{highlighted_text}</div>"
        "</div>"
    )


def create_single_entity_html(entity, entity_colors):
    """Create HTML for a single entity.

    Renders a coloured ``<span>`` whose tooltip shows label, source and
    confidence. Entity text and tooltip are HTML-escaped.
    """
    label = entity['label']
    color = entity_colors.get(label.upper(), DEFAULT_COLOR)
    confidence = entity.get('confidence', 0.0)
    source = entity.get('source', 'Unknown')
    tooltip = html.escape(f"{label} ({source}) - {confidence:.2f}", quote=True)

    return (
        f'<span style="background-color: {color}; padding: 2px 4px; margin: 0 1px; '
        f'border-radius: 3px; color: white; font-weight: bold;" title="{tooltip}">'
        f'{html.escape(entity["text"])}</span>'
    )


def create_merged_entity_html(entity, entity_colors):
    """Create HTML for a merged entity with multiple colors.

    The background is a left-to-right gradient with one hard-edged segment
    per label; the tooltip lists every label with its source and confidence.
    """
    labels = entity['labels']
    sources = entity['sources']
    confidences = entity['confidences']

    # Get colors for each label
    colors = [entity_colors.get(label.upper(), DEFAULT_COLOR) for label in labels]

    # Create gradient background
    if len(colors) == 2:
        gradient = f"linear-gradient(to right, {colors[0]} 50%, {colors[1]} 50%)"
    else:
        # For more colors, create equal segments
        segment_size = 100 / len(colors)
        gradient_parts = [
            f"{color} {i * segment_size}%, {color} {(i + 1) * segment_size}%"
            for i, color in enumerate(colors)
        ]
        gradient = f"linear-gradient(to right, {', '.join(gradient_parts)})"

    # Create tooltip
    tooltip = " | ".join(
        f"{label} ({source}) - {confidence:.2f}"
        for label, source, confidence in zip(labels, sources, confidences)
    )
    tooltip = html.escape(tooltip, quote=True)

    return (
        f'<span style="background: {gradient}; padding: 2px 4px; margin: 0 1px; '
        f'border-radius: 3px; color: white; font-weight: bold;" title="{tooltip}">'
        f'{html.escape(entity["text"])} 🔗</span>'
    )


def create_entity_table_html(entities, entity_colors):
    """Create HTML table of entities.

    Entities are merged first, then grouped by label (merged entities get
    their own group) and rendered as one table per group.
    """
    if not entities:
        return "<div style='padding: 15px;'><p>No entities found.</p></div>"

    # Merge overlapping entities
    merged_entities = find_overlapping_entities(entities)

    # Group entities by type
    entity_groups = {}
    for entity in merged_entities:
        key = 'MERGED_ENTITIES' if entity.get('is_merged', False) else entity['label']
        entity_groups.setdefault(key, []).append(entity)

    cell_style = "padding: 8px; border: 1px solid #ddd; text-align: left;"
    column_titles = ("Entity Text", "Label(s)", "Source(s)", "Confidence")

    # Create HTML table
    parts = ["<div style='margin: 15px 0;'>"]
    for entity_type, entities_of_type in entity_groups.items():
        if entity_type == 'MERGED_ENTITIES':
            color = '#666666'
            header = f"🔗 Merged Entities ({len(entities_of_type)})"
        else:
            color = entity_colors.get(entity_type.upper(), DEFAULT_COLOR)
            header = f"{entity_type} ({len(entities_of_type)})"

        parts.append(f"<h4 style='color: {color}; margin: 20px 0 10px 0;'>{html.escape(header)}</h4>")
        parts.append("<table style='width: 100%; border-collapse: collapse; margin-bottom: 10px;'>")
        parts.append(f"<thead><tr style='background-color: {color}; color: white;'>")
        parts.extend(f"<th style='{cell_style}'>{title}</th>" for title in column_titles)
        parts.append("</tr></thead><tbody>")

        for entity in entities_of_type:
            if entity.get('is_merged', False):
                labels_text = " | ".join(entity['labels'])
                sources_text = " | ".join(entity['sources'])
                conf_text = " | ".join([f"{c:.2f}" for c in entity['confidences']])
            else:
                labels_text = entity['label']
                sources_text = entity['source']
                conf_text = f"{entity['confidence']:.2f}"

            cells = (entity['text'], labels_text, sources_text, conf_text)
            parts.append("<tr>")
            parts.extend(f"<td style='{cell_style}'>{html.escape(str(cell))}</td>" for cell in cells)
            parts.append("</tr>")

        parts.append("</tbody></table>")

    parts.append("</div>")
    return "".join(parts)


def _legend_section_html(title, entity_types, entity_colors):
    """Render one titled row of colour chips for the legend."""
    chip_style = (
        "padding: 4px 8px; margin: 2px; border-radius: 12px; "
        "color: white; font-size: 12px; display: inline-block;"
    )
    chips = "".join(
        f"<span style='background-color: {entity_colors.get(name.upper(), '#ccc')}; "
        f"{chip_style}'>{html.escape(name)}</span>"
        for name in entity_types
    )
    return (
        "<div style='margin-bottom: 10px;'>"
        f"<strong>{title}</strong>"
        f"<div style='margin-top: 5px;'>{chips}</div>"
        "</div>"
    )


def create_legend_html(entity_colors, standard_entities, custom_entities):
    """Create a legend showing entity colors.

    Returns an empty string when no colours have been assigned.
    """
    if not entity_colors:
        return ""

    parts = [
        "<div style='margin: 15px 0; padding: 15px; border: 1px solid #ddd; border-radius: 5px;'>",
        "<h4 style='margin: 0 0 15px 0;'>🎨 Entity Type Legend</h4>",
    ]
    if standard_entities:
        parts.append(_legend_section_html("🎯 Standard Entities:", standard_entities, entity_colors))
    if custom_entities:
        parts.append(_legend_section_html("✨ Custom Entities:", custom_entities, entity_colors))
    parts.append("</div>")
    return "".join(parts)


# Initialize the NER manager
ner_manager = HybridNERManager()


def process_text(text, standard_entities, custom_entities_str, confidence_threshold,
                 use_spacy, use_gliner_standard):
    """Main processing function for Gradio interface.

    Returns a ``(summary_markdown, highlighted_html, table_html)`` tuple. On
    invalid input or when nothing is found, the first element is an error
    message and the other two are empty strings.
    """
    # Gradio components may deliver None for untouched inputs.
    text = text or ""
    if not text.strip():
        return "❌ Please enter some text to analyze", "", ""

    # Parse custom entities
    custom_entities = [
        name.strip() for name in (custom_entities_str or "").split(',') if name.strip()
    ]

    # Parse standard entities
    selected_standard = [entity for entity in (standard_entities or []) if entity]

    if not selected_standard and not custom_entities:
        return "❌ Please select at least one standard entity type OR enter custom entity types", "", ""

    all_entities = []

    # Extract standard entities
    if selected_standard:
        run_gliner_standard = use_gliner_standard
        if use_spacy:
            if ner_manager.load_spacy_model() is None:
                # spaCy is unavailable: fall back to GLiNER for the standard
                # types, as the load failure message announces.
                run_gliner_standard = True
            else:
                all_entities.extend(
                    ner_manager.extract_spacy_entities(text, selected_standard)
                )

        if run_gliner_standard:
            all_entities.extend(
                ner_manager.extract_gliner_entities(
                    text, selected_standard, confidence_threshold, is_custom=False
                )
            )

    # Extract custom entities
    if custom_entities:
        all_entities.extend(
            ner_manager.extract_gliner_entities(
                text, custom_entities, confidence_threshold, is_custom=True
            )
        )

    if not all_entities:
        return "❌ No entities found. Try lowering the confidence threshold or using different entity types.", "", ""

    # Assign colors
    entity_colors = ner_manager.assign_colors(selected_standard, custom_entities)

    # Create outputs
    legend_html = create_legend_html(entity_colors, selected_standard, custom_entities)
    highlighted_html = create_highlighted_html(text, all_entities, entity_colors)
    table_html = create_entity_table_html(all_entities, entity_colors)

    # Create summary
    total_entities = len(all_entities)
    merged_entities = find_overlapping_entities(all_entities)
    final_count = len(merged_entities)
    merged_count = sum(1 for e in merged_entities if e.get('is_merged', False))
    average_confidence = sum(e.get('confidence', 0) for e in all_entities) / total_entities

    summary = "\n".join([
        "",
        "## 📊 Analysis Summary",
        f"- **Total entities found:** {total_entities}",
        f"- **Final entities displayed:** {final_count}",
        f"- **Merged entities:** {merged_count}",
        f"- **Average confidence:** {average_confidence:.3f}",
        "",
    ])

    return summary, legend_html + highlighted_html, table_html


# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks app.

    Raises:
        ImportError: if the ``gradio`` package is not installed.
    """
    if gr is None:
        raise ImportError("The 'gradio' package is required to build the interface")

    with gr.Blocks(title="Hybrid NER + GLiNER Tool", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎯 Hybrid NER + Custom GLiNER Entity Recognition Tool

        Combine standard NER categories with your own custom entity types! This tool uses both traditional NER models and GLiNER for comprehensive entity extraction.

        ## 🔗 NEW: Overlapping entities are automatically merged with split-color highlighting!

        ### How to use:
        1. **📝 Enter your text** in the text area below
        2. **🎯 Select standard entities** you want to find (PER, ORG, LOC, etc.)
        3. **✨ Add custom entities** (comma-separated) like "relationships, occupations, skills"
        4. **⚙️ Choose models** and adjust confidence threshold
        5. **🔍 Click "Analyze Text"** to see results
        """)

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="📝 Text to Analyze",
                    placeholder="Enter your text here...",
                    lines=6,
                    max_lines=10
                )
            with gr.Column(scale=1):
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.1,
                    label="🎚️ Confidence Threshold"
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎯 Standard Entity Types")
                standard_entities = gr.CheckboxGroup(
                    choices=STANDARD_ENTITIES,
                    value=['PER', 'ORG', 'LOC', 'MISC'],  # Default selection
                    label="Select Standard Entities"
                )
                with gr.Row():
                    use_spacy = gr.Checkbox(label="Use spaCy", value=True)
                    use_gliner_standard = gr.Checkbox(label="Use GLiNER for Standard", value=False)
            with gr.Column():
                gr.Markdown("### ✨ Custom Entity Types")
                custom_entities = gr.Textbox(
                    label="Custom Entities (comma-separated)",
                    placeholder="e.g. relationships, occupations, skills, emotions",
                    lines=3
                )
                gr.Markdown("""
                **Examples:**
                - relationships, occupations, skills
                - emotions, actions, objects
                - medical conditions, treatments
                """)

        analyze_btn = gr.Button("🔍 Analyze Text", variant="primary", size="lg")

        # Output sections
        with gr.Row():
            summary_output = gr.Markdown(label="Summary")
        with gr.Row():
            highlighted_output = gr.HTML(label="Highlighted Text")
        with gr.Row():
            table_output = gr.HTML(label="Detailed Results")

        # Connect the button to the processing function
        analyze_btn.click(
            fn=process_text,
            inputs=[
                text_input,
                standard_entities,
                custom_entities,
                confidence_threshold,
                use_spacy,
                use_gliner_standard
            ],
            outputs=[summary_output, highlighted_output, table_output]
        )

        # Add examples
        gr.Examples(
            examples=[
                [
                    "John Smith works at Google in New York. He graduated from Stanford University in 2015 and specializes in artificial intelligence research. His wife Sarah is a doctor at Mount Sinai Hospital.",
                    ["PER", "ORG", "LOC", "DATE"],
                    "relationships, occupations, educational background",
                    0.3,
                    True,
                    False
                ],
                [
                    "The meeting between CEO Jane Doe and the board of directors at Microsoft headquarters in Seattle discussed the Q4 financial results and the new AI strategy for 2024.",
                    ["PER", "ORG", "LOC", "DATE"],
                    "corporate roles, business events, financial terms",
                    0.4,
                    True,
                    True
                ]
            ],
            inputs=[
                text_input,
                standard_entities,
                custom_entities,
                confidence_threshold,
                use_spacy,
                use_gliner_standard
            ]
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()