# File size: 4,132 Bytes
# 5a0110d 1a946a6 96ff0d7 5a0110d 1a946a6 5a0110d 1a946a6 96ff0d7 1a946a6 5a0110d 96ff0d7 5a0110d 96ff0d7 5a0110d 96ff0d7 5a0110d 96ff0d7 5a0110d |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import logging
import os
import shutil

import gdown
import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AlbertForSequenceClassification
# Configure root logging at INFO level and obtain this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
# Google Drive folder ID holding each task's model, keyed by task name.
# TODO: replace every placeholder value with the real folder ID before deploying.
model_drive_ids = dict(
    sentiment="your_sentiment_folder_id",
    emotion="your_emotion_folder_id",
    hate_speech="your_hate_speech_folder_id",
    sarcasm="your_sarcasm_folder_id",
)
# Local directory that holds one sub-directory per downloaded task model.
save_dir = "./saved_models"
os.makedirs(save_dir, exist_ok=True)

# Download each task's model folder from Google Drive unless its local
# directory already exists (i.e. it was fetched on an earlier run).
for task, folder_id in model_drive_ids.items():
    output_dir = os.path.join(save_dir, task)
    if os.path.exists(output_dir):
        continue
    logger.info(f"Downloading {task} model from Google Drive...")
    try:
        # Build the URL from this task's own folder ID. The previous URL put the
        # ID after a "?usp=sharing" query string on a hard-coded folder, so every
        # task requested that same hard-coded folder.
        downloaded = gdown.download_folder(
            url=f"https://drive.google.com/drive/folders/{folder_id}",
            output=output_dir,
            quiet=False,
        )
        # NOTE(review): gdown documents a None return value as "failed" rather
        # than raising; treat it as an error so it is not silently ignored.
        if downloaded is None:
            raise RuntimeError("gdown returned no files")
    except Exception as e:
        logger.error(f"Failed to download {task} model: {str(e)}")
        # Remove any partial download; otherwise the existence check above would
        # skip this task on every later run and loading would fail confusingly.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise
# Task names, derived from model_drive_ids (same order) so the list of tasks
# that are downloaded and the list that is loaded/predicted cannot drift apart.
tasks = list(model_drive_ids)
# Local model directory of each task, built the same way as in the download
# loop (os.path.join) so both steps agree on the layout.
model_paths = {task: os.path.join(save_dir, task) for task in tasks}
# Class names per task, in the order that predict_task pairs them with the
# model's output probabilities.
# NOTE(review): this order must match the label ids used when the models were
# trained — confirm against each model's config (id2label).
label_mappings = dict(
    sentiment=["negative", "neutral", "positive"],
    emotion=["happy", "sad", "angry", "fear"],
    hate_speech=["no", "yes"],
    sarcasm=["no", "yes"],
)
# One tokenizer shared by all task models. use_fast=False selects the slow
# (pure-Python) implementation, which the author chose to avoid problems seen
# with the fast tokenizer. Any loading failure is logged and aborts start-up.
logger.info("Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indic-bert", use_fast=False)
except Exception as err:
    logger.error(f"Failed to load tokenizer: {str(err)}")
    raise
# Load one ALBERT sequence-classification model per task from its local
# directory. A missing directory or any loading error aborts start-up.
models = {}
for task, task_dir in model_paths.items():
    if not os.path.exists(task_dir):
        raise FileNotFoundError(f"Model directory {task_dir} not found.")
    logger.info(f"Loading {task} model...")
    try:
        models[task] = AlbertForSequenceClassification.from_pretrained(task_dir)
    except Exception as err:
        logger.error(f"Failed to load {task} model: {str(err)}")
        raise
# Function to predict for a single task
def predict_task(text, task, model, tokenizer, max_length=128, labels=None):
    """Classify *text* with one task model and return per-label probabilities.

    Args:
        text: Input string to classify.
        task: Task key. Used in log messages and, when *labels* is not given,
            to look up the label names in ``label_mappings``.
        model: Sequence-classification model, called with the tokenizer's
            output; must return an object with a ``logits`` tensor.
        tokenizer: Tokenizer callable; it is invoked with ``return_tensors="pt"``.
        max_length: Maximum token count; longer inputs are truncated.
        labels: Optional label names in the model's output order. Defaults to
            ``label_mappings[task]``.

    Returns:
        dict mapping each label to its probability formatted as a percentage
        string such as ``"42.00%"``. If prediction fails, every label maps to
        ``"Error"`` and the exception is logged with its traceback.
    """
    if labels is None:
        labels = label_mappings[task]
    try:
        inputs = tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        # Take the single row of the batch explicitly; squeeze() would collapse
        # a one-label output into a 0-d array that cannot be iterated.
        probabilities = torch.softmax(logits, dim=-1)[0].cpu().numpy()
        if len(probabilities) != len(labels):
            # zip() below would otherwise silently drop the surplus scores or labels.
            raise ValueError(
                f"model returned {len(probabilities)} scores "
                f"but {len(labels)} labels are defined"
            )
        return {label: f"{prob*100:.2f}%" for label, prob in zip(labels, probabilities)}
    except Exception as e:
        # Best-effort for the UI: keep the traceback in the log, show "Error".
        logger.exception(f"Error predicting for {task}: {str(e)}")
        return {label: "Error" for label in labels}
# Gradio interface function
def predict_all_tasks(text):
    """Run every task model on *text* and format the results as plain text.

    Args:
        text: Raw text from the Gradio textbox; ``None`` is treated as empty.

    Returns:
        ``"Please enter some text."`` when *text* is None, empty or whitespace.
        Otherwise, for each task in ``tasks`` order: a blank line, a
        ``"<Task> Prediction:"`` header, then one ``" label: prob"`` line per
        label.
    """
    # Guard against None as well as blank input before calling str methods.
    if not text or not text.strip():
        return "Please enter some text."
    sections = []
    for task in tasks:
        probs = predict_task(text, task, models[task], tokenizer)
        lines = [f"\n{task.capitalize()} Prediction:\n"]
        lines.extend(f" {label}: {prob}\n" for label, prob in probs.items())
        sections.append("".join(lines))
    # One join instead of repeated string concatenation.
    return "".join(sections)
# Gradio UI: a two-line textbox as input, the formatted predictions as text output.
text_input = gr.Textbox(lines=2, placeholder="Enter Telugu text here...")
iface = gr.Interface(
    fn=predict_all_tasks,
    inputs=text_input,
    outputs="text",
    title="Telugu Text Analysis",
    description=(
        "Enter Telugu text to predict sentiment, emotion, hate speech, and sarcasm."
    ),
)
if __name__ == "__main__":
    logger.info("Launching Gradio interface...")
    # Listen on all network interfaces (not just localhost) on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)