# File size: 2,923 bytes; revision 5f41e26.
import gradio as gr
from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast
import torch
# Checkpoint identifier (an "org/name" Hugging Face repo id) for an uncased
# DistilBERT fine-tuned for token classification. The fast tokenizer and the
# model are loaded once, when this module runs, and reused by every request.
model_name = "AventIQ-AI/distilbert-base-uncased_token_classification"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForTokenClassification.from_pretrained(model_name)
# Emoji shown next to each detected entity, keyed by entity type (the model
# label with its "B-"/"I-" prefix removed). The key set presumably mirrors the
# WNUT-17 label set; verify against model.config.id2label. Types missing here
# fall back to a generic marker in the output formatter.
# Glyphs are written as \N{...} escapes so this file stays pure ASCII and the
# emoji cannot be garbled again by an encoding mismatch.
ICON_MAP = {
    "Corporation": "\N{OFFICE BUILDING}",
    "Person": "\N{BUST IN SILHOUETTE}",
    "Product": "\N{MOBILE PHONE}",
    "Location": "\N{ROUND PUSHPIN}",
    "Creative-Work": "\N{PERFORMING ARTS}",
    "Group": "\N{BUSTS IN SILHOUETTE}",
}
def predict_entities(text):
    """Detect named entities in *text* and return them formatted for the UI.

    Predictions are mapped back onto the input using the fast tokenizer's
    word alignment (``word_ids``) and character offsets, so every reported
    entity is an exact slice of *text*. Re-joining WordPiece tokens instead
    would lose the original casing with this uncased checkpoint and insert
    spaces around punctuation inside an entity.

    ``truncation=True`` cuts over-long input to the tokenizer's maximum
    length; entities in the truncated tail are not reported.

    Args:
        text: Raw input sentence from the UI.

    Returns:
        The string produced by ``format_output`` for *text* and the entities.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
    )
    # Offsets are only needed for post-processing; the model's forward()
    # does not accept them, so take them out before the call.
    offsets = inputs.pop("offset_mapping")[0].tolist()
    word_ids = inputs.word_ids(0)
    with torch.no_grad():
        logits = model(**inputs).logits
    label_ids = logits[0].float().argmax(dim=-1).tolist()
    labels = [model.config.id2label[label_id] for label_id in label_ids]
    entities = _group_entities(text, word_ids, offsets, labels)
    return format_output(text, entities)


def _group_entities(text, word_ids, offsets, labels):
    """Merge per-token IOB labels into ``{"text": ..., "type": ...}`` dicts.

    Only the label of a word's first sub-word token is consulted; later
    pieces of the same word just extend whatever entity that word joined.
    ``I-X`` continues the open entity only if it is also of type ``X``;
    otherwise (no open entity, or a type change) it starts a new entity,
    the usual lenient IOB decoding. ``O`` and unrecognised labels close the
    open entity.

    Args:
        text: Original input string the offsets refer to.
        word_ids: Word index per token; ``None`` for special/padding tokens.
        offsets: ``(start, end)`` character offsets per token.
        labels: Predicted label string per token.

    Returns:
        Entity dicts in order of appearance.
    """
    spans = []
    current = None  # [start, end, type] of the entity being built, or None
    previous_word = None
    for word_id, (start, end), label in zip(word_ids, offsets, labels):
        if word_id is None:  # [CLS], [SEP], padding
            continue
        if word_id == previous_word:
            # Later sub-word piece of the same word: if the word's first
            # piece opened/extended an entity, grow that entity's span.
            if current is not None:
                current[1] = end
            continue
        previous_word = word_id
        if label.startswith("I-") and current is not None and current[2] == label[2:]:
            current[1] = end
        elif label.startswith(("B-", "I-")):
            current = [start, end, label[2:]]
            spans.append(current)
        else:
            current = None
    return [{"text": text[s:e], "type": kind} for s, e, kind in spans]
def format_output(text, entities):
    """Render the input sentence and its detected entities as Markdown.

    Args:
        text: The original input sentence, echoed back in the header.
        entities: Sequence of dicts with ``"text"`` and ``"type"`` keys, listed
            in the given order.

    Returns:
        A Markdown string: an input line, a results heading, then either one
        bullet per entity (icon from ``ICON_MAP``, or a generic marker for
        unknown types) or a hint that nothing was found.
    """
    # Emoji are written as \N{...} escapes so the source stays pure ASCII and
    # the output cannot be garbled by a file-encoding mismatch.
    header = (
        "\N{INBOX TRAY} **Input**: " + text + "\n\n"
        + "\N{LEFT-POINTING MAGNIFYING GLASS} **Detected Entities**:\n"
    )
    if not entities:
        return header + (
            "\N{INFORMATION SOURCE}\N{VARIATION SELECTOR-16} "
            "No named entities detected. Try another sentence!\n"
        )
    bullets = [
        "- {} **{}** \N{RIGHTWARDS ARROW} `{}`\n".format(
            ICON_MAP.get(entity["type"], "\N{SMALL BLUE DIAMOND}"),
            entity["text"],
            entity["type"],
        )
        for entity in entities
    ]
    return header + "".join(bullets)
# Build and launch the web UI. The prediction function returns Markdown
# (bold text, inline code, bullet list), so the result goes to a Markdown
# component; a plain Textbox would show the asterisks and backticks literally.
# Emoji use \N{...} escapes so the source stays ASCII and cannot be garbled.
demo = gr.Interface(
    fn=predict_entities,
    inputs=gr.Textbox(placeholder="Enter text here...", label="Input Text"),
    outputs=gr.Markdown(label="NER Output"),
    title="\N{LEFT-POINTING MAGNIFYING GLASS} Named Entity Recognition (NER)",
    description=(
        "\N{MEMO} Enter a sentence and the model will detect entities like "
        "**Person, Location, Product, etc.**"
    ),
    theme="default",
    allow_flagging="never",
)
demo.launch()