Spaces:
Sleeping
Sleeping
import spacy | |
import gradio as gr | |
import json | |
from typing import Dict, List, Tuple, Any | |
from zshot import PipelineConfig | |
from zshot.linker import LinkerSMXM | |
from zshot.utils.data_models import Entity | |
from spacy.cli import download | |
# Ensure the small English spaCy package is installed. Skipping the call when
# it is already present avoids shelling out to pip on every start-up.
# NOTE(review): load_model() builds its pipeline from spacy.blank("en"), so this
# package does not appear to be used in this file — confirm before removing.
if not spacy.util.is_package("en_core_web_sm"):
    download("en_core_web_sm")
# Function to load the NER model
def load_model(entity_data, *, device="cpu", linker_model="disi-unibo-nlp/openbioner-base"):
    """Build a spaCy pipeline that runs zero-shot NER through a zshot linker.

    Args:
        entity_data: Iterable of mappings, each with required keys ``"name"``
            and ``"description"`` and an optional ``"vocabulary"`` key.
        device: Device string forwarded to ``PipelineConfig``. Defaults to
            ``"cpu"``; presumably ``"cuda"`` selects a GPU — TODO confirm
            against the zshot documentation.
        linker_model: Model name handed to ``LinkerSMXM``.

    Returns:
        A blank English spaCy pipeline with the ``zshot`` component appended
        as its last stage.

    Raises:
        KeyError: If an entry lacks ``"name"`` or ``"description"``.
    """
    entities = [
        Entity(
            name=entity["name"],
            description=entity["description"],
            vocabulary=entity.get("vocabulary"),
        )
        for entity in entity_data
    ]
    nlp = spacy.blank("en")
    nlp_config = PipelineConfig(
        linker=LinkerSMXM(model_name=linker_model),
        entities=entities,
        device=device,
    )
    nlp.add_pipe("zshot", config=nlp_config, last=True)
    return nlp
# Entity configuration used at start-up and shown pre-filled in the UI:
# a single BACTERIUM type with a natural-language description.
default_entities = [
    dict(
        name="BACTERIUM",
        description=(
            "A bacterium refers to a type of microorganism that can exist as a single "
            "cell and may cause infections or play a role in various biological "
            "processes. Examples include species like Streptococcus pneumoniae and "
            "Streptomyces ahygroscopicus."
        ),
    )
]
# Build the module-level pipeline once at import time from the default
# configuration; process_text() may later replace it.
nlp = load_model(entity_data=default_entities)
# Function to create HTML visualization of entities
def get_entity_html(doc) -> str:
    """Render ``doc.text`` as an HTML snippet with its entities highlighted.

    Each span in ``doc.ents`` is wrapped in a coloured ``<span>`` followed by
    its label; the text between entities is emitted unchanged. The document
    text and the labels are HTML-escaped because both come from user input
    (the analysed text and the user-editable entity JSON) and the result is
    rendered as raw HTML.

    Args:
        doc: A spaCy ``Doc`` — anything with ``.text`` and an ``.ents``
            sequence whose items expose ``start_char``, ``end_char``, ``label_``.

    Returns:
        A ``<div>`` string styled for a dark background.
    """
    from html import escape  # local import keeps this block self-contained

    colors = {
        "BACTERIUM": "#8dd3c7",
        "CHEMICAL": "#fb8072",
        "DISEASE": "#80b1d3",
        "GENE": "#fdb462",
        "SPECIES": "#b3de69",
    }
    text = doc.text
    html_parts = []
    last_idx = 0
    for ent in doc.ents:
        # Plain text between the previous entity and this one.
        html_parts.append(escape(text[last_idx:ent.start_char]))
        # Unknown labels fall back to a neutral grey.
        color = colors.get(ent.label_, "#ddd")
        html_parts.append(
            f'<span style="background-color: {color}; padding: 0.2em 0.3em; '
            f'border-radius: 0.35em; margin: 0 0.1em; font-weight: bold; color: #000;">'
            f'{escape(text[ent.start_char:ent.end_char])}'
            f'<span style="font-size: 0.8em; font-weight: bold; margin-left: 0.5em">{escape(ent.label_)}</span>'
            f'</span>'
        )
        last_idx = ent.end_char
    # Whatever follows the last entity.
    html_parts.append(escape(text[last_idx:]))
    body = "".join(html_parts)
    return f'<div style="line-height: 1.5; padding: 10px; background: #222; color: #fff; border-radius: 5px;">{body}</div>'
# Function to get entity details including spans
def get_entity_details(doc) -> List[Dict[str, Any]]:
    """Describe every entity of *doc* as a JSON-friendly dict.

    Returns:
        One dict per item of ``doc.ents``, in the same order, with keys
        ``"text"`` (the span text), ``"type"`` (its label), ``"start"`` and
        ``"end"`` (its character offsets in the document).
    """
    return [
        {
            "text": span.text,
            "type": span.label_,
            "start": span.start_char,
            "end": span.end_char,
        }
        for span in doc.ents
    ]
# Canonical JSON of the entity configuration the global `nlp` was last built
# from by process_text(); None until the first call, which forces a build.
_loaded_entities_key = None


def _validate_entity_config(entities) -> None:
    """Raise ValueError unless *entities* is a list of objects with "name" and "description"."""
    if not isinstance(entities, list):
        raise ValueError("expected a JSON list of entity objects")
    for index, entity in enumerate(entities):
        if not isinstance(entity, dict):
            raise ValueError(f"entry #{index} is not a JSON object")
        for key in ("name", "description"):
            if key not in entity:
                raise ValueError(f"entry #{index} is missing required key '{key}'")


# Main processing function
def process_text(text: str, entities_json: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Run NER on *text* using the entity types described by *entities_json*.

    The global pipeline ``nlp`` is rebuilt only when the (canonicalised) entity
    configuration differs from the one it was last built from, because
    ``load_model`` instantiates a new linker every time it is called.

    Args:
        text: Text to analyse; ``None`` is treated as an empty string.
        entities_json: JSON list of objects with ``"name"`` and
            ``"description"`` keys (optional ``"vocabulary"``).

    Returns:
        ``(html, details)``: the highlighted HTML and the per-entity span
        dicts. On an invalid configuration, ``html`` is an ``"Error: ..."``
        message, ``details`` is ``[]``, and the current pipeline is kept.
    """
    global nlp, _loaded_entities_key

    try:
        entities = json.loads(entities_json or "")
    except json.JSONDecodeError:
        return "Error: Invalid JSON in entity configuration", []
    try:
        _validate_entity_config(entities)
    except ValueError as err:
        return f"Error: Invalid entity configuration: {err}", []

    # Update model only if entities have changed (key order is irrelevant).
    config_key = json.dumps(entities, sort_keys=True)
    if config_key != _loaded_entities_key:
        nlp = load_model(entities)
        # Record the key only after a successful build.
        _loaded_entities_key = config_key

    doc = nlp(text or "")
    return get_entity_html(doc), get_entity_details(doc)
# Colour overrides that turn the Soft theme into a dark palette.
_DARK_THEME_OVERRIDES = dict(
    body_background_fill="#1a1a1a",
    background_fill_primary="#222",
    background_fill_secondary="#333",
    border_color_primary="#444",
    block_background_fill="#222",
    block_label_background_fill="#333",
    block_label_text_color="#fff",
    block_title_text_color="#fff",
    body_text_color="#fff",
    button_primary_background_fill="#2563eb",
    button_primary_text_color="#fff",
    input_background_fill="#333",
    input_border_color="#555",
    input_placeholder_color="#888",
    panel_background_fill="#222",
    slider_color="#2563eb",
)

# Soft base theme (blue primary, slate secondary/neutral) with the dark overrides.
_base_theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size=gr.themes.sizes.text_md,
)
theme = _base_theme.set(**_DARK_THEME_OVERRIDES)
# One-click example inputs as (button label, sample text). The first sample
# also pre-fills the input box.
_EXAMPLE_TEXTS = (
    (
        "E. coli research",
        "Impact of cofactor - binding loop mutations on thermotolerance and activity of E. coli transketolase",
    ),
    (
        "Bacterial infection case",
        "The patient was diagnosed with pneumonia caused by Streptococcus pneumoniae and treated with antibiotics for 7 days.",
    ),
    (
        "Multiple bacterial species",
        "We compared growth rates of E. coli, B. subtilis and S. aureus in various media containing different carbon sources.",
    ),
)


def _constant(value):
    """Return a zero-argument callable that always yields *value*."""
    return lambda: value


# Gradio interface using the dark theme defined above.
with gr.Blocks(title="Named Entity Recognition", theme=theme) as demo:
    gr.Markdown("# OpenBioNER - Demo")

    # Row 1: editable entity definitions.
    with gr.Row():
        entities_input = gr.Code(
            value=json.dumps(default_entities, indent=2),
            label="Entity Definitions (JSON)",
            language="json",
            lines=6,
        )

    # Row 2: input text (left) next to example shortcuts (right).
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to analyze",
                placeholder="Enter text to analyze...",
                value=_EXAMPLE_TEXTS[0][1],
                lines=3,
            )
            analyze_btn = gr.Button("Analyze Text", variant="primary")
        with gr.Column():
            gr.Markdown("### Quick Examples")
            example1_btn, example2_btn, example3_btn = [
                gr.Button(label) for label, _ in _EXAMPLE_TEXTS
            ]

    # Row 3: highlighted text (left) next to the raw span details (right).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Recognized Entities")
            result_html = gr.HTML()
        with gr.Column():
            gr.Markdown("### Entity Details with Spans")
            entity_details = gr.JSON()

    # Analyse the text with the current entity configuration.
    analyze_btn.click(
        fn=process_text,
        inputs=[text_input, entities_input],
        outputs=[result_html, entity_details],
    )

    # Each example button copies its sample sentence into the input box.
    for button, (_, sample) in zip(
        (example1_btn, example2_btn, example3_btn), _EXAMPLE_TEXTS
    ):
        button.click(fn=_constant(sample), inputs=None, outputs=text_input)
def main() -> None:
    """Start the Gradio server for the demo UI."""
    demo.launch()


if __name__ == "__main__":
    main()