# Source: ayushsinha, "Upload 2 files" (commit 5f41e26, verified)
import gradio as gr
from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast
import torch
# Load the fine-tuned DistilBERT token-classification checkpoint and its fast
# tokenizer once, at module import, so every request reuses the same objects.
model_name = "AventIQ-AI/distilbert-base-uncased_token_classification"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForTokenClassification.from_pretrained(model_name)
# Icon shown next to each entity type in the results.
# Emoji are written as \N{...} escapes rather than literal characters so that
# saving this file with the wrong encoding cannot garble them.
# NOTE(review): keys are assumed to match the checkpoint's label names with the
# "B-"/"I-" prefix stripped; confirm against model.config.id2label.
ICON_MAP = {
    "Corporation": "\N{OFFICE BUILDING}",
    "Person": "\N{BUST IN SILHOUETTE}",
    "Product": "\N{MOBILE PHONE}",
    "Location": "\N{ROUND PUSHPIN}",
    "Creative-Work": "\N{PERFORMING ARTS}",
    "Group": "\N{BUSTS IN SILHOUETTE}",
}
def _group_entities(text, offsets, word_ids, labels):
    """Merge per-token BIO labels into entity dicts whose text is sliced from *text*.

    Only the first sub-word token of each word contributes a label; later
    sub-word pieces just extend the end of the span they belong to. An ``I-``
    tag that cannot continue the open span (no span open, or a different
    entity type) starts a new span instead of being dropped or glued onto the
    wrong entity. Labels without a ``B-``/``I-`` prefix close the open span.

    Args:
        text: The original string that was tokenized.
        offsets: Per-token ``(start, end)`` character offsets into *text*.
        word_ids: Per-token word index; ``None`` marks special tokens.
        labels: Per-token label strings such as ``"B-Person"``, ``"I-Person"``, ``"O"``.

    Returns:
        A list of ``{"text": ..., "type": ...}`` dicts in reading order.
    """
    spans = []  # [entity_type, start_char, end_char]; mutable so the end can grow
    open_span = None
    last_word = None
    for (start, end), word_id, label in zip(offsets, word_ids, labels):
        if word_id is None:  # [CLS] / [SEP] / padding
            continue
        if word_id == last_word:
            # Continuation sub-word of the current word: its label is ignored.
            if open_span is not None:
                open_span[2] = end
            continue
        last_word = word_id
        tag, entity_type = label[:2], label[2:]
        if tag == "I-" and open_span is not None and open_span[0] == entity_type:
            open_span[2] = end
        elif tag in ("B-", "I-"):
            open_span = [entity_type, start, end]
            spans.append(open_span)
        else:
            open_span = None
    # Slicing the original text keeps the user's casing, accents and spacing,
    # which the lower-cased, accent-stripped WordPiece tokens have lost.
    return [{"text": text[s:e], "type": t} for t, s, e in spans]


def predict_entities(text):
    """Predict named entities in *text* and return a formatted report.

    The sentence is tokenized (truncated to the model's maximum input length),
    labelled token-by-token by the model, grouped into BIO entity spans, and
    passed to ``format_output`` for display.

    Args:
        text: The raw sentence from the UI. ``None`` is treated as ``""``.

    Returns:
        The string produced by ``format_output``.
    """
    text = text or ""
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        return_offsets_mapping=True,
    )
    # offset_mapping is not a model input, so remove it before the forward pass.
    offsets = encoding.pop("offset_mapping")[0].tolist()
    word_ids = encoding.word_ids(0)
    with torch.no_grad():
        logits = model(**encoding).logits
    label_ids = logits.argmax(dim=-1)[0].tolist()
    labels = [model.config.id2label[label_id] for label_id in label_ids]
    entities = _group_entities(text, offsets, word_ids, labels)
    return format_output(text, entities)
def format_output(text, entities, icon_map=None):
    """Render the input sentence and its detected entities as a Markdown report.

    Args:
        text: The sentence that was analysed; echoed back verbatim.
        entities: Sequence of dicts with ``"text"`` and ``"type"`` keys.
        icon_map: Optional mapping from entity type to an icon string. When
            omitted, the module-level ``ICON_MAP`` is used. Types missing from
            the mapping get a small blue diamond.

    Returns:
        A newline-terminated string: a header echoing *text*, then one bullet
        per entity, or a notice when *entities* is empty.
    """
    # Emoji are \N{...} escapes so they survive any re-encoding of this file.
    inbox = "\N{INBOX TRAY}"
    magnifier = "\N{LEFT-POINTING MAGNIFYING GLASS}"
    parts = [
        f"{inbox} **Input**: {text}\n\n",
        f"{magnifier} **Detected Entities**:\n",
    ]
    if not entities:
        parts.append(
            "\N{INFORMATION SOURCE}\N{VARIATION SELECTOR-16} "
            "No named entities detected. Try another sentence!\n"
        )
        return "".join(parts)

    icons = ICON_MAP if icon_map is None else icon_map
    arrow = "\N{RIGHTWARDS ARROW}"
    fallback = "\N{SMALL BLUE DIAMOND}"
    for entity in entities:
        icon = icons.get(entity["type"], fallback)
        parts.append(f"- {icon} **{entity['text']}** {arrow} `{entity['type']}`\n")
    return "".join(parts)
# Build the Gradio UI. format_output emits Markdown (bold text, bullet list,
# inline code), so the result is shown in a Markdown component; a plain
# Textbox would display the raw ** and ` characters.
demo = gr.Interface(
    fn=predict_entities,
    inputs=gr.Textbox(placeholder="Enter text here...", label="Input Text"),
    outputs=gr.Markdown(label="NER Output"),
    title="\N{MEMO} Named Entity Recognition (NER)",
    description=(
        "\N{LEFT-POINTING MAGNIFYING GLASS} Enter a sentence and the model will "
        "detect entities like **Person, Location, Product, etc.**"
    ),
    theme="default",
    allow_flagging="never",
)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()