import spacy
import gradio as gr
import json
from typing import Dict, List, Tuple, Any
from zshot import PipelineConfig
from zshot.linker import LinkerSMXM
from zshot.utils.data_models import Entity
from spacy.cli import download
# Install the small English spaCy model only when it is not already present,
# so that a restart does not re-run the package installation every time.
# NOTE(review): load_model() starts from spacy.blank("en") and never loads
# en_core_web_sm directly; presumably a dependency needs it -- confirm before
# removing this step entirely.
if not spacy.util.is_package("en_core_web_sm"):
    download("en_core_web_sm")
# Build the spaCy pipeline that performs zero-shot NER through zshot
def load_model(entity_data):
    """Create a blank English spaCy pipeline with a single zshot component.

    Args:
        entity_data: Iterable of mappings describing the entity types to
            recognise. Each mapping must provide "name" and "description";
            "vocabulary" is optional and is passed as None when absent.

    Returns:
        The spaCy pipeline, with "zshot" added as its last component.

    Raises:
        KeyError: If an item lacks "name" or "description".
    """

    def to_entity(spec):
        # Translate one plain mapping into the zshot Entity model.
        return Entity(
            name=spec["name"],
            description=spec["description"],
            vocabulary=spec.get("vocabulary"),
        )

    entity_models = list(map(to_entity, entity_data))

    pipeline = spacy.blank("en")
    zshot_config = PipelineConfig(
        linker=LinkerSMXM(model_name="disi-unibo-nlp/openbioner-base"),
        entities=entity_models,
        # Inference is pinned to the CPU; pick another device here if a GPU
        # is available.
        device="cpu",
    )
    pipeline.add_pipe("zshot", config=zshot_config, last=True)
    return pipeline
# Default entities - focusing on BACTERIUM example.
# Each item uses the same keys load_model() reads: "name" and "description"
# are required; "vocabulary" is omitted here, so load_model() passes None for it.
# This list is also serialised to JSON as the initial content of the entity
# editor in the UI.
default_entities = [
    {
        "name": "BACTERIUM",
        "description": "A bacterium refers to a type of microorganism that can exist as a single cell and may cause infections or play a role in various biological processes. Examples include species like Streptococcus pneumoniae and Streptomyces ahygroscopicus.",
    }
]
# Initialize model with default entities.
# This runs at import time; process_text() rebinds this global when it builds
# a new pipeline.
nlp = load_model(default_entities)
# Function to create HTML visualization of entities
def get_entity_html(doc) -> str:
    """Render a processed document as HTML with its entities highlighted.

    The document text is emitted in order; every entity span is wrapped in a
    coloured ``<span>`` followed by its label, and the whole result is wrapped
    in a dark-themed ``<div>``. All text taken from the document (including
    entity labels) is HTML-escaped, because both the analysed text and the
    entity names are supplied by the user.

    Args:
        doc: Object exposing ``text`` (str) and ``ents``; each entity must
            expose ``start_char``, ``end_char`` and ``label_``. The loop
            relies on ``doc.ents`` being in document order and
            non-overlapping -- presumably guaranteed by spaCy; confirm.

    Returns:
        An HTML fragment as a string.

    NOTE(review): the inline styles below are a reconstruction; adjust them
    to match the intended visual design.
    """
    from html import escape  # function-scope import keeps the block self-contained

    # Background colour per entity label; unknown labels fall back to grey.
    colors = {
        "BACTERIUM": "#8dd3c7",
        "CHEMICAL": "#fb8072",
        "DISEASE": "#80b1d3",
        "GENE": "#fdb462",
        "SPECIES": "#b3de69",
    }
    text = doc.text
    html_parts = []
    last_idx = 0
    # Display text with highlighted entities
    for ent in doc.ents:
        # Add text before the entity
        html_parts.append(escape(text[last_idx:ent.start_char]))
        # Add the highlighted entity followed by its label
        color = colors.get(ent.label_, "#ddd")
        html_parts.append(
            f'<span style="background-color: {color}; color: #000; '
            f'padding: 2px 4px; margin: 0 2px; border-radius: 4px;">'
            f'{escape(text[ent.start_char:ent.end_char])}'
            f'<span style="font-size: 0.75em; font-weight: bold; '
            f'margin-left: 4px;">{escape(ent.label_)}</span>'
            f'</span>'
        )
        # Update the last index
        last_idx = ent.end_char
    # Add any remaining text
    html_parts.append(escape(text[last_idx:]))
    body = "".join(html_parts)
    # Wrap the result in a div with dark theme styling
    return (
        '<div style="background-color: #222; color: #fff; padding: 12px; '
        'border-radius: 6px; line-height: 2;">'
        f'{body}'
        '</div>'
    )
# Function to get entity details including spans
def get_entity_details(doc) -> List[Dict[str, Any]]:
    """Return one plain dict per entity found in *doc*.

    Args:
        doc: Object exposing ``ents``; each entity must expose ``text``,
            ``label_``, ``start_char`` and ``end_char``.

    Returns:
        A list, in ``doc.ents`` order, of dicts with the keys "text",
        "type" (the entity label), "start" and "end" (character offsets as
        reported by the entity). Empty when the document has no entities.
    """
    return [
        {
            "text": ent.text,
            "type": ent.label_,
            "start": ent.start_char,
            "end": ent.end_char,
        }
        for ent in doc.ents
    ]
# Entity definitions the global `nlp` was last built from by process_text().
# It stays None until the first successful call, so that call always builds a
# pipeline; later calls rebuild only when the definitions actually change.
_loaded_entities = None


# Main processing function
def process_text(text: str, entities_json: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Run NER over *text* using the entity definitions in *entities_json*.

    Args:
        text: The text to analyse.
        entities_json: JSON document that must decode to a list of objects,
            each containing at least the keys "name" and "description"
            (the keys load_model() indexes).

    Returns:
        A ``(html, details)`` tuple: the highlighted HTML from
        get_entity_html() and the span list from get_entity_details().
        If the configuration cannot be used, returns an ``"Error: ..."``
        message and an empty list instead of raising.
    """
    global nlp, _loaded_entities

    # Only the decoding step is guarded; TypeError covers a None payload.
    try:
        entities = json.loads(entities_json)
    except (json.JSONDecodeError, TypeError):
        return "Error: Invalid JSON in entity configuration", []

    # Well-formed JSON can still have the wrong shape; reject it here rather
    # than letting load_model() fail with KeyError/TypeError/AttributeError.
    if not isinstance(entities, list) or not all(
        isinstance(item, dict) and "name" in item and "description" in item
        for item in entities
    ):
        return (
            "Error: Entity configuration must be a JSON list of objects "
            'with "name" and "description" keys',
            [],
        )

    # Update model only if entities have changed; building it is the
    # expensive part of a request.
    if entities != _loaded_entities:
        nlp = load_model(entities)
        _loaded_entities = entities

    # Process the text with the NER model
    doc = nlp(text)
    # Generate visualization HTML and the detailed span information
    return get_entity_html(doc), get_entity_details(doc)
# Set theme to dark.
# Starts from Gradio's Soft theme (blue primary hue, slate secondary/neutral
# hues, medium text size) and then overrides individual variables with fixed
# dark hex colours via .set().
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size=gr.themes.sizes.text_md,
).set(
    # Page and container backgrounds
    body_background_fill="#1a1a1a",
    background_fill_primary="#222",
    background_fill_secondary="#333",
    border_color_primary="#444",
    # Blocks, labels and titles
    block_background_fill="#222",
    block_label_background_fill="#333",
    block_label_text_color="#fff",
    block_title_text_color="#fff",
    body_text_color="#fff",
    # Primary button
    button_primary_background_fill="#2563eb",
    button_primary_text_color="#fff",
    # Input fields
    input_background_fill="#333",
    input_border_color="#555",
    input_placeholder_color="#888",
    # Panels and sliders
    panel_background_fill="#222",
    slider_color="#2563eb",
)
# Quick-example presets as (button label, sentence) pairs. The first sentence
# also seeds the input box, so the default text and the first example are
# defined in exactly one place.
_EXAMPLES = [
    (
        "E. coli research",
        "Impact of cofactor - binding loop mutations on thermotolerance and activity of E. coli transketolase",
    ),
    (
        "Bacterial infection case",
        "The patient was diagnosed with pneumonia caused by Streptococcus pneumoniae and treated with antibiotics for 7 days.",
    ),
    (
        "Multiple bacterial species",
        "We compared growth rates of E. coli, B. subtilis and S. aureus in various media containing different carbon sources.",
    ),
]


def _example_callback(sentence):
    """Return a zero-argument handler that yields *sentence*.

    A factory is used instead of a lambda in the loop below so that each
    handler captures its own sentence rather than the loop variable.
    """

    def fill():
        return sentence

    return fill


# Create Gradio interface with dark theme
with gr.Blocks(title="Named Entity Recognition", theme=theme) as demo:
    gr.Markdown("# OpenBioNER - Demo")

    # First row: Entity Definitions (editable JSON, pre-filled with the defaults)
    with gr.Row():
        entities_input = gr.Code(
            label="Entity Definitions (JSON)",
            language="json",
            value=json.dumps(default_entities, indent=2),
            lines=6,
        )

    # Second row: Input text and examples side by side
    with gr.Row():
        # Left side - Input text
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to analyze",
                placeholder="Enter text to analyze...",
                value=_EXAMPLES[0][1],
                lines=3,
            )
            analyze_btn = gr.Button("Analyze Text", variant="primary")
        # Right side - Example texts, one button per preset
        with gr.Column():
            gr.Markdown("### Quick Examples")
            example_buttons = [gr.Button(label) for label, _ in _EXAMPLES]
            # Keep the individual button names available at module level.
            example1_btn, example2_btn, example3_btn = example_buttons

    # Third row: Output visualization and spans side by side
    with gr.Row():
        # Left side - Highlighted text output
        with gr.Column():
            gr.Markdown("### Recognized Entities")
            result_html = gr.HTML()
        # Right side - Entity spans details
        with gr.Column():
            gr.Markdown("### Entity Details with Spans")
            entity_details = gr.JSON()

    # Event handlers are registered inside the Blocks context.
    # The analyze button runs NER on the current text and entity definitions.
    analyze_btn.click(
        fn=process_text,
        inputs=[text_input, entities_input],
        outputs=[result_html, entity_details],
    )

    # Each example button only fills the input box; it does not run the analysis.
    for button, (_, sentence) in zip(example_buttons, _EXAMPLES):
        button.click(
            fn=_example_callback(sentence),
            inputs=None,
            outputs=text_input,
        )
# Start the Gradio server only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()