"""Gradio demo: named-entity recognition with a DistilBERT token classifier."""

import gradio as gr
import torch
from transformers import DistilBertForTokenClassification, DistilBertTokenizerFast

# Load model & tokenizer once at import time; every request reuses them.
model_name = "AventIQ-AI/distilbert-base-uncased_token_classification"
model = DistilBertForTokenClassification.from_pretrained(model_name)
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

# Icon shown next to each entity type in the formatted output.
ICON_MAP = {
    "Corporation": "🏢",
    "Person": "👤",
    "Product": "📱",
    "Location": "📍",
    "Creative-Work": "🎭",
    "Group": "👥",
}


def _merge_entities(text, tokens, labels, offsets, special_tokens):
    """Merge BIO-tagged word pieces into a list of entity dicts.

    Args:
        text: The original input string the offsets refer to.
        tokens: Word-piece tokens, one per model position.
        labels: Predicted label string for each token ("O", "B-X", "I-X").
        offsets: ``(start, end)`` character span in ``text`` for each token.
        special_tokens: Token strings to ignore (CLS/SEP/PAD).

    Returns:
        A list of ``{"text": ..., "type": ...}`` dicts in order of appearance.
    """
    spans = []
    current = None  # entity currently being extended, or None
    for token, label, (start, end) in zip(tokens, labels, offsets):
        if token in special_tokens:
            continue
        if token.startswith("##"):
            # A continuation piece belongs to the word before it, so it extends
            # the open entity regardless of the label predicted for the piece.
            if current is not None:
                current["end"] = end
            continue
        if label == "O":
            if current is not None:
                spans.append(current)
                current = None
        elif label.startswith("B-"):
            if current is not None:
                spans.append(current)
            current = {"start": start, "end": end, "type": label[2:]}
        elif label.startswith("I-") and current is not None:
            # An "I-" tag with no open entity is dropped.
            current["end"] = end
    if current is not None:
        spans.append(current)
    # Slice the original string so casing and punctuation spacing are kept as
    # typed instead of being rebuilt from word pieces.
    return [
        {"text": text[span["start"]:span["end"]], "type": span["type"]}
        for span in spans
    ]


def predict_entities(text):
    """Predict named entities in ``text`` and return the formatted report.

    Args:
        text: The sentence to analyse. ``None`` is treated as an empty string.

    Returns:
        The string produced by ``format_output`` for the detected entities.
    """
    text = text or ""
    if not text.strip():
        # Nothing to tag; skip tokenisation and inference.
        return format_output(text, [])

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
    )
    # The offsets are only needed here; remove them so they are not passed to
    # the model's forward() as an unexpected keyword argument.
    offsets = inputs.pop("offset_mapping")[0].tolist()

    with torch.no_grad():
        outputs = model(**inputs)

    # Cast logits to float32 before taking the argmax over the label axis.
    predictions = torch.argmax(outputs.logits.float(), dim=2)
    predicted_labels = [model.config.id2label[i] for i in predictions[0].tolist()]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    special_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
    entities = _merge_entities(text, tokens, predicted_labels, offsets, special_tokens)
    return format_output(text, entities)


def format_output(text, entities):
    """Format the input text and detected entities for the Gradio UI.

    Args:
        text: The original input string.
        entities: List of dicts with ``"text"`` and ``"type"`` keys.

    Returns:
        A newline-terminated, Markdown-flavoured string listing each entity,
        or a hint message when ``entities`` is empty.
    """
    lines = [f"📥 **Input**: {text}", "", "🔍 **Detected Entities**:"]
    if not entities:
        lines.append("ℹ️ No named entities detected. Try another sentence!")
    else:
        for entity in entities:
            icon = ICON_MAP.get(entity["type"], "🔹")
            lines.append(f"- {icon} **{entity['text']}** → `{entity['type']}`")
    return "\n".join(lines) + "\n"


# Create the Gradio UI.
# NOTE(review): format_output emits Markdown syntax (**bold**, `code`) but the
# output component is a Textbox, which presumably shows it literally; consider
# gr.Markdown if rendered formatting is wanted.
demo = gr.Interface(
    fn=predict_entities,
    inputs=gr.Textbox(placeholder="Enter text here...", label="Input Text"),
    outputs=gr.Textbox(label="NER Output"),
    title="📝 Named Entity Recognition (NER)",
    description="🔍 Enter a sentence and the model will detect entities like **Person, Location, Product, etc.**",
    theme="default",
    allow_flagging="never",
)
demo.launch()