# BanglaFeel / app.py
import os
import sys
import random
import torch
import numpy as np
from flask import Flask, request, jsonify, render_template
import subprocess # For calling gdown
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
# Get the directory of the current script (app.py)
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
MODEL_FILES_DIR = os.path.join(APP_ROOT, 'model_files')
TOKENIZER_FILES_DIR = os.path.join(MODEL_FILES_DIR, 'tokenizers') # Path for tokenizers specifically
device = torch.device('cpu')
class CFG:
model_name = 'csebuetnlp/banglat5'
encoder_name = 'csebuetnlp/banglabert'
batch_size = 1
max_len = 512
seed = 42
device = device
    # Tokenizers are populated by load_checkpoint() from TOKENIZER_FILES_DIR.
    # They start as None here; WordLSTMEncoder falls back to padding_idx 0 when
    # CFG.encoder_tokenizer has not been loaded yet.
    t5_tokenizer = None
    encoder_tokenizer = None
# --- MODEL CLASSES (CharCNNEncoder, WordLSTMEncoder, HybridEncoder, DualEncoderDecoder) ---
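# Overview: CharCNNEncoder builds character-level features per token, WordLSTMEncoder builds
# word-level BiLSTM features, and HybridEncoder fuses the two streams. DualEncoderDecoder then
# concatenates the fused features with the banglat5 encoder states, projects the result back to
# d_model, and lets the T5 decoder generate from it.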
class CharCNNEncoder(nn.Module):
def __init__(self, char_vocab_size, char_embedding_dim, char_cnn_output_dim, kernel_sizes, num_filters, dropout=0.1):
super(CharCNNEncoder, self).__init__()
self.char_embedding = nn.Embedding(char_vocab_size, char_embedding_dim, padding_idx=0)
self.conv_layers = nn.ModuleList()
for ks, nf in zip(kernel_sizes, num_filters):
self.conv_layers.append(
nn.Sequential(
nn.Conv1d(char_embedding_dim, nf, kernel_size=ks, padding=ks // 2),
nn.ReLU(),
nn.AdaptiveMaxPool1d(1)
)
)
self.dropout = nn.Dropout(dropout)
self.output_projection = nn.Linear(sum(num_filters), char_cnn_output_dim)
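    # Shape flow in forward(): char_input (batch, seq_len, char_len) -> per-word character
    # embeddings -> one max-pooled feature per kernel size, concatenated and projected to
    # (batch, seq_len, char_cnn_output_dim).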
def forward(self, char_input):
batch_size, seq_len, char_len = char_input.size()
char_input = char_input.view(-1, char_len)
char_emb = self.char_embedding(char_input)
char_emb = char_emb.permute(0, 2, 1)
conv_outputs = [conv(char_emb) for conv in self.conv_layers]
concat_output = torch.cat(conv_outputs, dim=1)
concat_output = concat_output.squeeze(-1)
concat_output = self.dropout(concat_output)
char_cnn_output = self.output_projection(concat_output)
char_cnn_output = char_cnn_output.view(batch_size, seq_len, -1)
return char_cnn_output
class WordLSTMEncoder(nn.Module):
def __init__(self, word_vocab_size, word_embedding_dim, word_lstm_hidden_dim, num_lstm_layers, dropout):
super(WordLSTMEncoder, self).__init__()
        # Fall back to padding_idx 0 when the encoder tokenizer has not been loaded yet.
        if hasattr(CFG, 'encoder_tokenizer') and CFG.encoder_tokenizer is not None and CFG.encoder_tokenizer.pad_token_id is not None:
            padding_idx_val = CFG.encoder_tokenizer.pad_token_id
        else:
            padding_idx_val = 0
self.word_embedding = nn.Embedding(
word_vocab_size,
word_embedding_dim,
padding_idx=padding_idx_val
)
self.lstm = nn.LSTM(
word_embedding_dim,
word_lstm_hidden_dim,
num_layers=num_lstm_layers,
batch_first=True,
dropout=dropout,
bidirectional=True
)
self.output_projection = nn.Linear(2 * word_lstm_hidden_dim, word_lstm_hidden_dim)
def forward(self, word_input, sequence_lengths):
batch_size = word_input.size(0)
word_emb = self.word_embedding(word_input)
sequence_lengths_cpu = sequence_lengths.cpu()
if torch.all(sequence_lengths_cpu == 0):
lstm_out = torch.zeros(batch_size, word_input.size(1), self.lstm.hidden_size * 2, device=word_input.device)
hidden = torch.zeros(batch_size, self.lstm.hidden_size * 2, device=word_input.device) # Corrected hidden dim
return self.output_projection(hidden), lstm_out
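        # pack_padded_sequence with enforce_sorted=True needs lengths sorted in descending order
        # and >= 1, hence the sort, the clamp(min=1), and the unsort of the outputs afterwards.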
sorted_lengths, sort_idx = sequence_lengths_cpu.sort(0, descending=True)
sorted_word_emb = word_emb[sort_idx]
packed_word_emb = nn.utils.rnn.pack_padded_sequence(
sorted_word_emb,
sorted_lengths.clamp(min=1),
batch_first=True,
enforce_sorted=True
)
packed_lstm_out, (hidden_state, cell_state) = self.lstm(packed_word_emb)
lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
packed_lstm_out,
batch_first=True,
total_length=word_input.size(1)
)
_, unsort_idx = sort_idx.sort(0)
lstm_out = lstm_out[unsort_idx]
hidden_state = hidden_state.view(self.lstm.num_layers, 2, batch_size, self.lstm.hidden_size)
hidden_state_last_layer = hidden_state[-1]
final_hidden = torch.cat((hidden_state_last_layer[0], hidden_state_last_layer[1]), dim=1)
final_hidden = final_hidden[unsort_idx]
return self.output_projection(final_hidden), lstm_out
class HybridEncoder(nn.Module):
def __init__(self, char_cnn_encoder, word_lstm_encoder, hybrid_encoder_output_dim):
super(HybridEncoder, self).__init__()
self.char_cnn_encoder = char_cnn_encoder
self.word_lstm_encoder = word_lstm_encoder
self.char_hidden_size = char_cnn_encoder.output_projection.out_features
self.lstm_sequence_output_size = word_lstm_encoder.lstm.hidden_size * 2
self.output_projection = nn.Linear(self.char_hidden_size + self.lstm_sequence_output_size, hybrid_encoder_output_dim)
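    # forward() aligns both feature streams to the word-token sequence length (padding or
    # truncating as needed), concatenates them along the feature dimension, and projects the
    # result to hybrid_encoder_output_dim.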
def forward(self, char_input, word_input, sequence_lengths):
batch_size = char_input.size(0)
max_seq_len = word_input.size(1)
char_cnn_output = self.char_cnn_encoder(char_input)
sequence_lengths = sequence_lengths.to(word_input.device)
_, lstm_sequence_output = self.word_lstm_encoder(word_input, sequence_lengths)
if char_cnn_output.size(1) < max_seq_len:
char_cnn_output = F.pad(char_cnn_output, (0, 0, 0, max_seq_len - char_cnn_output.size(1)), "constant", 0)
elif char_cnn_output.size(1) > max_seq_len:
char_cnn_output = char_cnn_output[:, :max_seq_len, :]
if lstm_sequence_output.size(1) < max_seq_len:
lstm_sequence_output = F.pad(lstm_sequence_output, (0, 0, 0, max_seq_len - lstm_sequence_output.size(1)), "constant", 0)
elif lstm_sequence_output.size(1) > max_seq_len:
lstm_sequence_output = lstm_sequence_output[:, :max_seq_len, :]
hybrid_output_concat = torch.cat((char_cnn_output, lstm_sequence_output), dim=2)
return self.output_projection(hybrid_output_concat)
class DualEncoderDecoder(nn.Module):
def __init__(self, t5_model_name, hybrid_encoder, t5_tokenizer, freeze_t5=False):
super(DualEncoderDecoder, self).__init__()
self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name)
self.t5_tokenizer = t5_tokenizer
self.hybrid_encoder = hybrid_encoder
encoder_hidden_size = self.t5.config.d_model
hybrid_hidden_size = hybrid_encoder.output_projection.out_features
self.encoder_projection = nn.Linear(encoder_hidden_size + hybrid_hidden_size, encoder_hidden_size)
if freeze_t5:
for param in self.t5.parameters(): param.requires_grad = False
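    # forward() and generate() concatenate the T5 encoder states with the hybrid encoder output
    # along the hidden dimension, project back to d_model, and pass the result to the T5 decoder
    # wrapped as a standard encoder output (BaseModelOutputWithPastAndCrossAttentions).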
def forward(self, input_ids, attention_mask, char_input, word_input, sequence_lengths, labels=None):
t5_encoder_outputs_dict = self.t5.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
t5_encoder_last_hidden_state = t5_encoder_outputs_dict.last_hidden_state
sequence_lengths = sequence_lengths.to(char_input.device)
hybrid_encoder_output = self.hybrid_encoder(char_input, word_input, sequence_lengths)
common_seq_len = t5_encoder_last_hidden_state.size(1)
if hybrid_encoder_output.size(1) < common_seq_len:
hybrid_encoder_output = F.pad(hybrid_encoder_output, (0, 0, 0, common_seq_len - hybrid_encoder_output.size(1)), "constant", 0)
elif hybrid_encoder_output.size(1) > common_seq_len:
hybrid_encoder_output = hybrid_encoder_output[:, :common_seq_len, :]
concat_encoder_output = torch.cat((t5_encoder_last_hidden_state, hybrid_encoder_output), dim=2)
projected_encoder_output = self.encoder_projection(concat_encoder_output)
encoder_outputs_for_decoder = BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=projected_encoder_output)
return self.t5(encoder_outputs=encoder_outputs_for_decoder, attention_mask=attention_mask, labels=labels, return_dict=True, use_cache=False)
def generate(self, input_ids, attention_mask, char_input, word_input, sequence_lengths, max_length, num_beams):
t5_encoder_outputs_dict = self.t5.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
t5_encoder_last_hidden_state = t5_encoder_outputs_dict.last_hidden_state
sequence_lengths = sequence_lengths.to(char_input.device)
hybrid_encoder_output = self.hybrid_encoder(char_input, word_input, sequence_lengths)
common_seq_len = t5_encoder_last_hidden_state.size(1)
if hybrid_encoder_output.size(1) < common_seq_len:
hybrid_encoder_output = F.pad(hybrid_encoder_output, (0, 0, 0, common_seq_len - hybrid_encoder_output.size(1)), "constant", 0)
elif hybrid_encoder_output.size(1) > common_seq_len:
hybrid_encoder_output = hybrid_encoder_output[:, :common_seq_len, :]
concat_encoder_output = torch.cat((t5_encoder_last_hidden_state, hybrid_encoder_output), dim=2)
projected_encoder_output = self.encoder_projection(concat_encoder_output)
encoder_outputs_for_generate = BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=projected_encoder_output)
generated_ids_dict = self.t5.generate(encoder_outputs=encoder_outputs_for_generate, attention_mask=attention_mask, max_length=max_length, num_beams=num_beams, early_stopping=True, use_cache=True, return_dict_in_generate=True)
return generated_ids_dict.sequences
# --- END MODEL CLASSES ---
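# Illustrative only (not executed): a minimal sketch of how the encoders are wired together.
# The hyperparameter values below are placeholders; the real values are read from
# checkpoint['model_architecture'] in load_checkpoint().
#
#   char_enc = CharCNNEncoder(char_vocab_size=128, char_embedding_dim=64,
#                             char_cnn_output_dim=128, kernel_sizes=[3, 5], num_filters=[64, 64])
#   word_enc = WordLSTMEncoder(word_vocab_size=32000, word_embedding_dim=256,
#                              word_lstm_hidden_dim=256, num_lstm_layers=2, dropout=0.1)
#   hybrid = HybridEncoder(char_enc, word_enc, hybrid_encoder_output_dim=256)
#   model = DualEncoderDecoder('csebuetnlp/banglat5', hybrid, t5_tokenizer=CFG.t5_tokenizer)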
def download_model_from_gdrive(file_id, destination):
print(f"Attempting to download model from Google Drive (ID: {file_id}) to {destination}...")
try:
# Ensure gdown is installed (will be in Dockerfile)
# The --fuzzy flag can sometimes help with problematic links or large files
# The -O flag specifies the output file path
        # capture_output=True so that e.stderr / e.stdout below are populated on failure
        subprocess.run(["gdown", "--id", file_id, "-O", destination, "--fuzzy"], check=True, capture_output=True)
print("Model downloaded successfully from Google Drive.")
return True
except subprocess.CalledProcessError as e:
print(f"ERROR: Failed to download model using gdown: {e}")
if e.stderr:
print(f"gdown stderr: {e.stderr.decode()}")
if e.stdout:
print(f"gdown stdout: {e.stdout.decode()}")
return False
except FileNotFoundError:
print("ERROR: gdown command not found. Ensure it is installed in the Docker image.")
return False
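# Note: gdown also accepts a full Drive URL instead of --id, e.g.
#   gdown "https://drive.google.com/uc?id=<FILE_ID>" -O <destination> --fuzzy
# A common cause of download failures is a file that is not shared as "Anyone with the link".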
def load_checkpoint(path_to_checkpoint_file): # path_to_checkpoint_file is /app/model_files/best_model.pth
# Tokenizers are loaded first from the local TOKENIZER_FILES_DIR
t5_tokenizer_dir_path = os.path.join(TOKENIZER_FILES_DIR, 't5_tokenizer')
encoder_tokenizer_dir_path = os.path.join(TOKENIZER_FILES_DIR, 'encoders_tokenizer')
if not os.path.isdir(t5_tokenizer_dir_path):
raise FileNotFoundError(f"T5 tokenizer directory not found: {t5_tokenizer_dir_path}")
if not os.path.isdir(encoder_tokenizer_dir_path):
raise FileNotFoundError(f"Encoder tokenizer directory not found: {encoder_tokenizer_dir_path}")
print(f"Loading T5 tokenizer from: {t5_tokenizer_dir_path}")
CFG.t5_tokenizer = T5Tokenizer.from_pretrained(
t5_tokenizer_dir_path, legacy=False, model_max_length=CFG.max_len, local_files_only=True
)
if CFG.t5_tokenizer.pad_token is None: CFG.t5_tokenizer.pad_token = CFG.t5_tokenizer.eos_token
if CFG.t5_tokenizer.bos_token is None: CFG.t5_tokenizer.bos_token = CFG.t5_tokenizer.eos_token
print(f"Loading encoder tokenizer from: {encoder_tokenizer_dir_path}")
CFG.encoder_tokenizer = AutoTokenizer.from_pretrained(
encoder_tokenizer_dir_path, model_max_length=CFG.max_len, local_files_only=True
)
if CFG.encoder_tokenizer.pad_token is None:
print(f"Warning: Loaded encoder tokenizer from {encoder_tokenizer_dir_path} has no pad_token.")
        # WordLSTMEncoder uses pad_token_id as its embedding padding_idx, so add one here if
        # the saved tokenizer config lacks it, e.g.:
        #   CFG.encoder_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # Ideally this is handled when the tokenizer is saved.
# Now load the model checkpoint which contains weights, char_to_id etc.
if not os.path.exists(path_to_checkpoint_file):
raise FileNotFoundError(f"Model checkpoint file not found after download attempt: {path_to_checkpoint_file}")
print(f"Loading model checkpoint from: {path_to_checkpoint_file}")
checkpoint = torch.load(path_to_checkpoint_file, map_location=CFG.device)
loaded_config_from_checkpoint = checkpoint['config']
loaded_char_to_id = checkpoint['char_to_id']
    # loaded_id_to_char = checkpoint['id_to_char']  # not needed for inference
model_architecture = checkpoint['model_architecture']
for key, value in loaded_config_from_checkpoint.items():
setattr(CFG, key, value)
CFG.device = device # Ensure CPU
loaded_max_char_len = model_architecture.get('max_char_len', 50)
char_cnn_encoder = CharCNNEncoder(
char_vocab_size=model_architecture['char_vocab_size'],
char_embedding_dim=model_architecture['char_embedding_dim'],
char_cnn_output_dim=model_architecture['char_cnn_output_dim'],
kernel_sizes=model_architecture['kernel_sizes'],
num_filters=model_architecture['num_filters'],
dropout=model_architecture.get('dropout', 0.1)
)
current_encoder_vocab_size = len(CFG.encoder_tokenizer)
if model_architecture['word_vocab_size'] != current_encoder_vocab_size:
print(f"Warning: Word vocab size mismatch. Checkpoint: {model_architecture['word_vocab_size']}, "
f"Loaded CFG.encoder_tokenizer: {current_encoder_vocab_size}. Using loaded tokenizer's size for WordLSTM.")
word_lstm_encoder = WordLSTMEncoder(
word_vocab_size=current_encoder_vocab_size, # Use current vocab size
word_embedding_dim=model_architecture['word_embedding_dim'],
word_lstm_hidden_dim=model_architecture['word_lstm_hidden_dim'],
num_lstm_layers=model_architecture['num_lstm_layers'],
dropout=model_architecture.get('dropout', 0.1)
)
hybrid_encoder = HybridEncoder(
char_cnn_encoder, word_lstm_encoder,
hybrid_encoder_output_dim=model_architecture['hybrid_encoder_output_dim']
)
model_base_name_for_t5 = loaded_config_from_checkpoint.get('model_name', CFG.model_name)
print(f"Initializing DualEncoderDecoder with T5 base: {model_base_name_for_t5}")
model = DualEncoderDecoder(
t5_model_name=model_base_name_for_t5,
hybrid_encoder=hybrid_encoder,
t5_tokenizer=CFG.t5_tokenizer
)
model.t5.resize_token_embeddings(len(CFG.t5_tokenizer))
print("Loading model state_dict...")
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.to(CFG.device)
model.eval()
print("Model loaded successfully.")
# Return id_to_char if needed, otherwise can remove
return model, loaded_char_to_id, loaded_max_char_len #, loaded_id_to_char
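# Expected checkpoint layout (inferred from the keys accessed above):
#   {'config': dict of CFG overrides, 'char_to_id': dict, 'id_to_char': dict (unused here),
#    'model_architecture': dict of encoder hyperparameters, 'model_state_dict': weights}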
# --- Helper methods (tokenize_characters, pad_sequence, process_input) ---
def tokenize_characters(word, char_to_id):
unk_token_id = char_to_id.get("<UNK>", 0)
return [char_to_id.get(char, unk_token_id) for char in word]
def pad_sequence(sequence, max_length, pad_value):
if len(sequence) > max_length: sequence = sequence[:max_length]
return sequence + [pad_value] * (max_length - len(sequence))
def process_input(text, t5_tokenizer, encoder_tokenizer, char_to_id, current_max_char_len, max_token_len):
t5_inputs = t5_tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=max_token_len, add_special_tokens=True)
encoder_inputs = encoder_tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=max_token_len)
t5_input_ids_squeezed = t5_inputs['input_ids'].squeeze(0)
t5_attention_mask_squeezed = t5_inputs['attention_mask'].squeeze(0)
encoder_input_ids_squeezed = encoder_inputs['input_ids'].squeeze(0)
actual_max_seq_len = encoder_input_ids_squeezed.shape[0]
char_input_tensor = torch.zeros((actual_max_seq_len, current_max_char_len), dtype=torch.long)
char_pad_id = char_to_id.get("<PAD>", 0)
for j in range(actual_max_seq_len):
token_id = encoder_input_ids_squeezed[j].item()
if token_id in encoder_tokenizer.all_special_ids: word = ""
else: word = encoder_tokenizer.decode([token_id], skip_special_tokens=True).strip()
if not word: char_ids = [char_pad_id] * current_max_char_len
else:
char_ids = tokenize_characters(word, char_to_id)
char_ids = pad_sequence(char_ids, current_max_char_len, char_pad_id)
char_input_tensor[j, :] = torch.tensor(char_ids, dtype=torch.long)
sequence_lengths_tensor = encoder_inputs['attention_mask'].sum(dim=1).long().squeeze()
if sequence_lengths_tensor.ndim == 0: sequence_lengths_tensor = sequence_lengths_tensor.unsqueeze(0)
return {
't5_input_ids': t5_input_ids_squeezed.unsqueeze(0).to(CFG.device),
't5_attention_mask': t5_attention_mask_squeezed.unsqueeze(0).to(CFG.device),
'encoder_input_ids': encoder_input_ids_squeezed.unsqueeze(0).to(CFG.device),
'char_input': char_input_tensor.unsqueeze(0).to(CFG.device),
'sequence_lengths': sequence_lengths_tensor.to(CFG.device)
}
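# process_input() returns batch-size-1 tensors on CFG.device:
#   't5_input_ids', 't5_attention_mask', 'encoder_input_ids': (1, max_token_len)
#   'char_input': (1, max_token_len, current_max_char_len)
#   'sequence_lengths': (1,) count of non-padding encoder tokens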
# --- END HELPER METHODS ---
app = Flask(__name__)
# --- Global variables for the loaded model and components ---
loaded_model_global = None
loaded_char_to_id_global = None
loaded_max_char_len_global = None
# ---
# Model download and loading logic at startup
MODEL_GDRIVE_FILE_ID = "1XPICO1MAdf6OTKe_SVJfTW3G6gftVW5f" # Your Google Drive File ID
checkpoint_file_name = "best_model.pth"
local_checkpoint_path = os.path.join(MODEL_FILES_DIR, checkpoint_file_name)
# Create model_files directory if it doesn't exist
os.makedirs(MODEL_FILES_DIR, exist_ok=True)
os.makedirs(TOKENIZER_FILES_DIR, exist_ok=True) # Ensure tokenizer dir exists too
# Download the model only if it doesn't already exist
if not os.path.exists(local_checkpoint_path):
print(f"'{checkpoint_file_name}' not found locally. Attempting download...")
if not download_model_from_gdrive(MODEL_GDRIVE_FILE_ID, local_checkpoint_path):
        print("FATAL: Model download failed. The application will not be able to serve translations.")
# loaded_model_global remains None
else:
print("Model download successful.")
else:
print(f"'{checkpoint_file_name}' found locally at {local_checkpoint_path}. Skipping download.")
# Load model if download was successful or file already existed
if os.path.exists(local_checkpoint_path):
print("Initializing and loading model checkpoint...")
try:
# Use global variables to store loaded components
loaded_model_global, loaded_char_to_id_global, loaded_max_char_len_global = load_checkpoint(local_checkpoint_path)
print("Model and components loaded into global variables successfully.")
except Exception as e:
print(f"FATAL: Could not load model on startup from {local_checkpoint_path}: {e}")
# import traceback
# traceback.print_exc()
loaded_model_global = None # Indicate model loading failed
else:
    # Download failed and no local copy exists; the app starts, but /translate will return errors.
    print("Model file is not available. Application will not function correctly.")
@app.route('/')
def index():
return render_template('index.html')
@app.route('/translate', methods=['POST'])
def translate_text():
if loaded_model_global is None:
return jsonify({"error": "Model is not available. Please check server logs."}), 500
data = request.get_json()
input_text = data.get('text', '')
if not input_text:
return jsonify({"error": "No text provided"}), 400
try:
inputs = process_input(
input_text,
CFG.t5_tokenizer, # Assumes CFG.t5_tokenizer is loaded by load_checkpoint
CFG.encoder_tokenizer, # Assumes CFG.encoder_tokenizer is loaded by load_checkpoint
loaded_char_to_id_global,
loaded_max_char_len_global,
CFG.max_len
)
with torch.no_grad():
generated_ids = loaded_model_global.generate(
inputs['t5_input_ids'], inputs['t5_attention_mask'],
inputs['char_input'], inputs['encoder_input_ids'],
inputs['sequence_lengths'],
max_length=CFG.max_len, num_beams=4
)
translation = CFG.t5_tokenizer.decode(
generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
).strip()
return jsonify({"translation": translation})
except Exception as e:
print(f"Error during translation: {e}")
# import traceback
# traceback.print_exc()
return jsonify({"error": "An error occurred during translation."}), 500
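# Example request (illustrative, assuming the default port below):
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<Bangla input text>"}'
# A successful response has the form {"translation": "..."}.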
if __name__ == '__main__':
port = int(os.environ.get("PORT", 7860)) # HF Docker Spaces expects 7860
app.run(host='0.0.0.0', port=port, debug=False) # debug=False for production