# NERIndo/app.py
import sys

import gradio as gr
import pandas as pd
import torch
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load the dataset only to recover the label list used during training
try:
    dataset = load_dataset("indonlp/indonlu", "nergrit", trust_remote_code=True)
except Exception as e:
    print(f"Failed to load dataset: {e}")
    sys.exit(1)
# Verify the dataset structure
if "train" not in dataset or "test" not in dataset:
    print("Dataset is missing the expected train/test splits.")
    sys.exit(1)
if "tokens" not in dataset["train"].column_names or "ner_tags" not in dataset["train"].column_names:
    print("Dataset is missing the 'tokens' or 'ner_tags' column.")
    sys.exit(1)
# Build the label list and the id/label mappings
try:
    label_list = dataset["train"].features["ner_tags"].feature.names
    id2label = {i: label for i, label in enumerate(label_list)}
    label2id = {label: i for i, label in enumerate(label_list)}
except Exception as e:
    print(f"Failed to read labels: {e}")
    sys.exit(1)
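# For nergrit the tag set is BIO-style over PERSON, ORGANISATION and PLACE
# (plus "O"), so id2label maps small integers to tags such as "B-PERSON" or
# "I-PLACE"; the exact index order comes from the dataset's feature metadata.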
# Load the tokenizer and fine-tuned model from the saved directory
try:
    tokenizer = AutoTokenizer.from_pretrained("./ner_model")
    model = AutoModelForTokenClassification.from_pretrained(
        "./ner_model",
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id
    )
    model.to(device)
except Exception as e:
    print(f"Failed to load the model or tokenizer from './ner_model': {e}")
    print("Make sure the './ner_model' folder exists and contains the trained model.")
    sys.exit(1)
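# A directory written by save_pretrained() normally contains config.json, the
# weights (model.safetensors or pytorch_model.bin) and the tokenizer files;
# if any of these are missing, the except branch above will trigger.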
# Tokenize the test data and align NER labels to WordPiece tokens: only the
# first piece of each word keeps the word's label; every other piece (and the
# special tokens) gets -100 so it is ignored by the loss
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens ([CLS], [SEP], padding) have no word id
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First sub-token of a word: keep the word-level label
                label_ids.append(label[word_idx])
            else:
                # Remaining sub-tokens of the same word are masked out
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
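# Illustrative alignment (the exact pieces depend on the tokenizer's vocab):
#   words  ["Joko", "Widodo"]                  tags [B-PERSON, I-PERSON]
#   pieces ["[CLS]", "Joko", "Wido", "##do", "[SEP]"]
#   labels [-100,    B-PERSON, I-PERSON, -100, -100]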
# Tokenize the full dataset; the result is not used by the Gradio UI below,
# but it mirrors the preprocessing used during training/evaluation
try:
    tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)
except Exception as e:
    print(f"Failed to tokenize the dataset: {e}")
    sys.exit(1)
# Predict entities for a piece of input text
def predict_entities(input_text):
    if not input_text.strip():
        # The output component is a gr.Dataframe, so return an empty table
        # rather than a message string, which that component cannot render
        return pd.DataFrame(columns=["Token", "Entity"])
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    predictions = outputs.logits.argmax(dim=2)[0].cpu().numpy()
    # Map token ids and predicted label ids back to strings
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    labels = [id2label[int(pred)] for pred in predictions]
    # Drop the special tokens ([CLS], [SEP])
    result = []
    for token, label in zip(tokens, labels):
        if token not in ["[CLS]", "[SEP]"]:
            result.append({"Token": token, "Entity": label})
    # Convert to a DataFrame for display
    return pd.DataFrame(result)
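# Optional sketch (not wired into the UI): merge WordPiece continuations such
# as "Wido" + "##do" back into whole words, keeping the label of the first
# piece. This assumes the BERT-style "##" continuation prefix used by IndoBERT
# tokenizers; the function name is illustrative, not part of the original app.
def merge_subwords(df: pd.DataFrame) -> pd.DataFrame:
    merged = []
    for _, row in df.iterrows():
        if row["Token"].startswith("##") and merged:
            # Continuation piece: append its text to the previous word
            merged[-1]["Token"] += row["Token"][2:]
        else:
            merged.append({"Token": row["Token"], "Entity": row["Entity"]})
    return pd.DataFrame(merged)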
# Gradio interface (the example sentences stay in Indonesian, since that is
# the language the model was trained on)
with gr.Blocks() as demo:
    gr.Markdown("# Named Entity Recognition (NER) with IndoBERT")
    gr.Markdown("Enter Indonesian text to detect entities such as PERSON, ORGANISATION, PLACE, etc.")
    gr.Markdown("## Entity Label Legend")
    gr.Markdown("""
    - **O**: Token is not an entity (e.g. "dan", "mengunjungi").
    - **B-PERSON**: Beginning of a person's name (e.g. "Joko" in "Joko Widodo").
    - **I-PERSON**: Continuation of a person's name (e.g. "Widodo" or "##do" in "Joko Widodo").
    - **B-PLACE**: Beginning of a place name (e.g. "Bali").
    - **I-PLACE**: Continuation of a place name (e.g. "Indonesia" in "Bali, Indonesia").
    """)
    with gr.Row():
        text_input = gr.Textbox(
            label="Input Text",
            placeholder="Example: Joko Widodo menghadiri acara di Universitas Indonesia pada tanggal 14 Juni 2025",
            lines=3
        )
    submit_button = gr.Button("Predict")
    clear_button = gr.Button("Clear")
    output_table = gr.Dataframe(label="Prediction Results")
    gr.Markdown("## Example Texts")
    gr.Markdown("- SBY berkunjung ke Bali bersama Jokowi.\n- Universitas Gadjah Mada menyelenggarakan seminar pada 10 Maret 2025.")
    gr.Markdown("## Data Security, Privacy, and Ethics Considerations")
    gr.Markdown("""
    - **Data security**: The dataset comes from public news articles and contains no sensitive information such as addresses or ID numbers.
    - **Privacy**: User input is not stored, preserving privacy.
    - **AI ethics**: The dataset covers a range of news topics (politics, sports, culture), reducing the risk of bias toward particular entities.
    """)
    submit_button.click(fn=predict_entities, inputs=text_input, outputs=output_table)
    clear_button.click(fn=lambda: "", inputs=None, outputs=text_input)
# Launch Gradio interface
demo.launch()
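# Note: on Hugging Face Spaces, launch() with no arguments is the usual setup;
# when running locally you could pass e.g. share=True for a temporary public
# link (a standard Gradio option, not something this app requires).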