|
import gradio as gr
|
|
from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast
|
|
import torch
|
|
|
|
|
|
# Hugging Face Hub checkpoint used for token classification (NER).
model_name = "AventIQ-AI/distilbert-base-uncased_token_classification"

# Load the model first, then its matching fast tokenizer, from the same
# checkpoint. Both are created once at import time and reused by every call
# to predict_entities().
model, tokenizer = (
    DistilBertForTokenClassification.from_pretrained(model_name),
    DistilBertTokenizerFast.from_pretrained(model_name),
)
|
|
|
|
|
|
# Emoji shown next to each entity type in the formatted output. Keys are the
# entity types predict_entities() extracts from the model's BIO labels (the
# part after "B-"/"I-"); types missing here get format_output()'s default
# marker.
# NOTE(review): confirm these keys match the spelling/capitalization used in
# model.config.id2label for this checkpoint.
ICON_MAP = {
    "Corporation": "🏢",
    "Person": "👤",
    "Product": "📱",
    "Location": "📍",
    "Creative-Work": "🎭",
    "Group": "👥",
}
|
|
|
|
def predict_entities(text):
    """Predict named entities in ``text`` and return a formatted summary.

    The text is tokenized with the fast tokenizer, every token is labelled by
    the model (argmax over the logits), and the BIO labels are merged into
    entity spans. Each word takes the label predicted for its first
    sub-token; the word's remaining sub-tokens follow that decision.

    Entity text is sliced from the ORIGINAL input using the tokenizer's
    character offsets. It therefore keeps the user's casing, accents and
    punctuation spacing, instead of being rebuilt from (possibly lowercased,
    given the "uncased" checkpoint name) tokens joined with spaces.

    Merging rules:
      * "O" closes any open entity.
      * "B-X" always starts a new entity of type X.
      * "I-X" extends the open entity when it is also of type X; otherwise
        (no open entity, or a different type) it starts a new one instead of
        being dropped or glued onto an unrelated entity.
      * A label without a B-/I- prefix is treated as its own entity type.

    Args:
        text: Raw input sentence from the Gradio textbox. Input longer than
            the tokenizer's maximum length is truncated, so entities past
            that point are not reported.

    Returns:
        The string produced by ``format_output(text, entities)``, where
        ``entities`` is a list of dicts with "text" and "type" keys.
    """
    if not text or not text.strip():
        # Nothing to tag; skip the model call.
        return format_output(text or "", [])

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
    )
    # forward() does not accept offset_mapping, so take it out of the batch.
    offsets = inputs.pop("offset_mapping")[0].tolist()
    # Source-word index per token; None for special tokens ([CLS]/[SEP]/[PAD]).
    word_ids = inputs.word_ids(0)

    with torch.no_grad():
        logits = model(**inputs).logits
    label_ids = logits.float().argmax(dim=-1)[0].tolist()
    labels = [model.config.id2label[label_id] for label_id in label_ids]

    spans = []        # finished entities as [char_start, char_end, type]
    current = None    # entity being extended, same layout
    previous_word = None
    for word_id, (start, end), label in zip(word_ids, offsets, labels):
        if word_id is None:
            continue
        if word_id == previous_word:
            # Continuation sub-token: belongs to whatever its word started.
            if current is not None:
                current[1] = end
            continue
        previous_word = word_id

        if label == "O":
            if current is not None:
                spans.append(current)
                current = None
            continue

        tag, _, entity_type = label.partition("-")
        if tag not in ("B", "I") or not entity_type:
            # No BIO prefix: the whole label is the entity type.
            tag, entity_type = "I", label

        if tag == "I" and current is not None and current[2] == entity_type:
            current[1] = end
        else:
            if current is not None:
                spans.append(current)
            current = [start, end, entity_type]

    if current is not None:
        spans.append(current)

    entities = [
        {"text": text[span_start:span_end], "type": span_type}
        for span_start, span_end, span_type in spans
    ]
    return format_output(text, entities)
|
|
|
|
def format_output(text, entities):
    """Render the input and its detected entities as Markdown for the UI.

    Args:
        text: The input sentence, echoed back at the top of the output.
        entities: Sequence of dicts with "text" and "type" keys.

    Returns:
        A Markdown string with one bullet per entity. Each bullet holds the
        type's icon from ICON_MAP (or a default marker), the bold entity
        text, and the type in inline code. When ``entities`` is empty, a
        "no entities" notice replaces the bullets.
    """
    # Look icons up case-insensitively: entity types come from the model's
    # label names, whose capitalization may not match ICON_MAP's keys.
    icons = {key.lower(): icon for key, icon in ICON_MAP.items()}

    parts = [f"📥 **Input**: {text}\n\n🔍 **Detected Entities**:\n"]
    if not entities:
        parts.append("ℹ️ No named entities detected. Try another sentence!\n")
    for entity in entities:
        icon = icons.get(entity["type"].lower(), "🔹")
        parts.append(f"- {icon} **{entity['text']}** → `{entity['type']}`\n")
    return "".join(parts)
|
|
|
|
|
|
# Web UI. format_output() builds Markdown (bold text, inline code, a bullet
# list), so the result is shown in a Markdown component; a Textbox would
# display the asterisks and backticks literally.
demo = gr.Interface(
    fn=predict_entities,
    inputs=gr.Textbox(placeholder="Enter text here...", label="Input Text"),
    outputs=gr.Markdown(label="NER Output"),
    title="🔍 Named Entity Recognition (NER)",
    description="📝 Enter a sentence and the model will detect entities like **Person, Location, Product, etc.**",
    theme="default",
    allow_flagging="never",
)

# Start the server as soon as this script runs.
demo.launch()
|
|
|