import gradio as gr
import torch
from transformers import AlbertForSequenceClassification, AlbertTokenizer
import os
import gdown
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

model_file_ids = {
    "sentiment": {
        "config": "11jwMJmQMGkiVZWBRQ5BLFyot1520FYIQ",
        "model": "115N5yiu9lfw4uJE5YxHNoHauHeYSSusu"
    },
    "emotion": {
        "config": "1dSxK10jbZyRpMDCm6MCRf9Jy0weOzLP9",
        "model": "1Y3rTtPfo4zu28OhsRybdJF6czZN46I0Y"
    },
    "hate_speech": {
        "config": "1QTejES8BZQs3qnxom9ymiZkLRUAZ91NP",
        "model": "1ol2xO4XbdHwP_HHCYsnX8iVutA6javy_"
    },
    "sarcasm": {
        "config": "1ypl0j1Yp_-0szR4-P1-0CMyDYBwUn5Wz",
        "model": "1pbByLvTIHO_sT9HMeypvXbsdHsLVzTdk"
    }
}
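
# Note: these checkpoints must be shared on Drive as "Anyone with the link",
# since gdown cannot fetch private files without credentials.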

# Download each task's config and weights from Google Drive, skipping files already cached on disk
save_dir = "./saved_models"
os.makedirs(save_dir, exist_ok=True)

for task, files in model_file_ids.items():
    output_dir = os.path.join(save_dir, task)
    os.makedirs(output_dir, exist_ok=True)

    config_path = os.path.join(output_dir, "config.json")
    model_path = os.path.join(output_dir, "model.safetensors")

    if not os.path.exists(config_path):
        logger.info(f"Downloading {task} config.json from Google Drive...")
        gdown.download(f"https://drive.google.com/uc?id={files['config']}", config_path, quiet=False)
    else:
        logger.info(f"Config for {task} already exists, skipping download.")

    if not os.path.exists(model_path):
        logger.info(f"Downloading {task} model.safetensors from Google Drive...")
        gdown.download(f"https://drive.google.com/uc?id={files['model']}", model_path, quiet=False)
    else:
        logger.info(f"Model for {task} already exists, skipping download.")

# Each task's local directory doubles as a Hugging Face model path
tasks = ["sentiment", "emotion", "hate_speech", "sarcasm"]
model_paths = {task: f"{save_dir}/{task}" for task in tasks}

# Class index -> human-readable label, in the order used at fine-tuning time
label_mappings = {
    "sentiment": ["negative", "neutral", "positive"],
    "emotion": ["happy", "sad", "angry", "fear"],
    "hate_speech": ["no", "yes"],
    "sarcasm": ["no", "yes"]
}
logger.info("Loading tokenizer...") |
|
try: |
|
|
|
tokenizer = AlbertTokenizer.from_pretrained("ai4bharat/indic-bert", use_fast=False) |
|
except Exception as e: |
|
logger.error(f"Failed to load tokenizer: {str(e)}") |
|
raise |
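
# Optional smoke test on a hypothetical sample word: IndicBERT's sentencepiece vocab
# covers Telugu, so a Telugu string should tokenize into real subword pieces
logger.info(f"Tokenizer smoke test: {len(tokenizer.tokenize('తెలుగు'))} tokens for a sample Telugu word")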

# Load one fine-tuned classification head per task from its local directory
models = {}
for task in tasks:
    logger.info(f"Loading model for {task}...")
    if not os.path.exists(model_paths[task]):
        raise FileNotFoundError(f"Model directory {model_paths[task]} not found.")
    try:
        models[task] = AlbertForSequenceClassification.from_pretrained(model_paths[task])
    except Exception as e:
        logger.error(f"Failed to load model for {task}: {str(e)}")
        raise
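
# from_pretrained() returns models already in eval mode, but making that explicit guards
# against dropout during inference; the assertion is a sanity check that each checkpoint's
# head size matches the label list above
for task in tasks:
    models[task].eval()
    assert models[task].config.num_labels == len(label_mappings[task]), (
        f"{task}: checkpoint predicts {models[task].config.num_labels} classes, "
        f"but label_mappings lists {len(label_mappings[task])}"
    )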

def predict_task(text, task, model, tokenizer, max_length=128):
    """Run one classification head and return {label: probability} for the input text."""
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    labels = label_mappings[task]
    return {label: f"{prob*100:.2f}%" for label, prob in zip(labels, probabilities)}
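
# Hypothetical example of the shape this returns for a 3-way head:
#   predict_task("...", "sentiment", models["sentiment"], tokenizer)
#   -> {"negative": "12.34%", "neutral": "23.41%", "positive": "64.25%"}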

def predict_all_tasks(text):
    """Run every task head on the text and format the results as a plain-text report."""
    if not text.strip():
        return "Please enter some text."

    results = {}
    for task in tasks:
        results[task] = predict_task(text, task, models[task], tokenizer)

    output = ""
    for task, probs in results.items():
        output += f"\n{task.capitalize()} Prediction:\n"
        for label, prob in probs.items():
            output += f"  {label}: {prob}\n"

    return output
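
# Quick local smoke test (hypothetical Telugu input meaning "this movie is very good"),
# left commented out so it does not run on import in a hosted Space:
# print(predict_all_tasks("ఈ సినిమా చాలా బాగుంది"))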

iface = gr.Interface(
    fn=predict_all_tasks,
    inputs=gr.Textbox(lines=2, placeholder="Enter Telugu text here..."),
    outputs="text",
    title="Telugu Text Analysis",
    description="Enter Telugu text to predict sentiment, emotion, hate speech, and sarcasm."
)

if __name__ == "__main__":
    logger.info("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)