import os
import sys
import random
import torch
import numpy as np
from flask import Flask, request, jsonify, render_template
import subprocess  # For calling gdown

os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'

import torch.nn as nn
import torch.nn.functional as F
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

# Get the directory of the current script (app.py)
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
MODEL_FILES_DIR = os.path.join(APP_ROOT, 'model_files')
TOKENIZER_FILES_DIR = os.path.join(MODEL_FILES_DIR, 'tokenizers')  # Path for tokenizers specifically

device = torch.device('cpu')


class CFG:
    model_name = 'csebuetnlp/banglat5'
    encoder_name = 'csebuetnlp/banglabert'
    batch_size = 1
    max_len = 512
    seed = 42
    device = device
    # Tokenizers are loaded in load_checkpoint from TOKENIZER_FILES_DIR, so they are not loaded
    # globally here. WordLSTMEncoder.__init__ may read CFG.encoder_tokenizer.pad_token_id, so the
    # attributes are kept defined; load_checkpoint overwrites them with the local versions.
    t5_tokenizer = None
    encoder_tokenizer = None


# --- MODEL CLASSES (CharCNNEncoder, WordLSTMEncoder, HybridEncoder, DualEncoderDecoder) ---

class CharCNNEncoder(nn.Module):
    def __init__(self, char_vocab_size, char_embedding_dim, char_cnn_output_dim,
                 kernel_sizes, num_filters, dropout=0.1):
        super(CharCNNEncoder, self).__init__()
        self.char_embedding = nn.Embedding(char_vocab_size, char_embedding_dim, padding_idx=0)
        self.conv_layers = nn.ModuleList()
        for ks, nf in zip(kernel_sizes, num_filters):
            self.conv_layers.append(
                nn.Sequential(
                    nn.Conv1d(char_embedding_dim, nf, kernel_size=ks, padding=ks // 2),
                    nn.ReLU(),
                    nn.AdaptiveMaxPool1d(1)
                )
            )
        self.dropout = nn.Dropout(dropout)
        self.output_projection = nn.Linear(sum(num_filters), char_cnn_output_dim)

    def forward(self, char_input):
        batch_size, seq_len, char_len = char_input.size()
        char_input = char_input.view(-1, char_len)
        char_emb = self.char_embedding(char_input)
        char_emb = char_emb.permute(0, 2, 1)
        conv_outputs = [conv(char_emb) for conv in self.conv_layers]
        concat_output = torch.cat(conv_outputs, dim=1)
        concat_output = concat_output.squeeze(-1)
        concat_output = self.dropout(concat_output)
        char_cnn_output = self.output_projection(concat_output)
        char_cnn_output = char_cnn_output.view(batch_size, seq_len, -1)
        return char_cnn_output


class WordLSTMEncoder(nn.Module):
    def __init__(self, word_vocab_size, word_embedding_dim, word_lstm_hidden_dim,
                 num_lstm_layers, dropout):
        super(WordLSTMEncoder, self).__init__()
        padding_idx_val = (
            CFG.encoder_tokenizer.pad_token_id
            if hasattr(CFG, 'encoder_tokenizer')
            and CFG.encoder_tokenizer is not None
            and CFG.encoder_tokenizer.pad_token_id is not None
            else 0
        )
        self.word_embedding = nn.Embedding(
            word_vocab_size, word_embedding_dim, padding_idx=padding_idx_val
        )
        self.lstm = nn.LSTM(
            word_embedding_dim,
            word_lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True
        )
        self.output_projection = nn.Linear(2 * word_lstm_hidden_dim, word_lstm_hidden_dim)

    def forward(self, word_input, sequence_lengths):
        batch_size = word_input.size(0)
        word_emb = self.word_embedding(word_input)
        sequence_lengths_cpu = sequence_lengths.cpu()

        if torch.all(sequence_lengths_cpu == 0):
            lstm_out = torch.zeros(batch_size, word_input.size(1),
                                   self.lstm.hidden_size * 2, device=word_input.device)
            hidden = torch.zeros(batch_size, self.lstm.hidden_size * 2,
                                 device=word_input.device)  # Corrected hidden dim
            return self.output_projection(hidden), lstm_out

        sorted_lengths, sort_idx = sequence_lengths_cpu.sort(0, descending=True)
        sorted_word_emb = word_emb[sort_idx]
        packed_word_emb = nn.utils.rnn.pack_padded_sequence(
            sorted_word_emb, sorted_lengths.clamp(min=1), batch_first=True, enforce_sorted=True
        )
        packed_lstm_out, (hidden_state, cell_state) = self.lstm(packed_word_emb)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_lstm_out, batch_first=True, total_length=word_input.size(1)
        )
        _, unsort_idx = sort_idx.sort(0)
        lstm_out = lstm_out[unsort_idx]

        hidden_state = hidden_state.view(self.lstm.num_layers, 2, batch_size, self.lstm.hidden_size)
        hidden_state_last_layer = hidden_state[-1]
        final_hidden = torch.cat((hidden_state_last_layer[0], hidden_state_last_layer[1]), dim=1)
        final_hidden = final_hidden[unsort_idx]
        return self.output_projection(final_hidden), lstm_out
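# Example of the length handling above, with illustrative lengths (not real data): for a
# batch with sequence_lengths = [5, 0, 3], the rows are sorted to [5, 3, 0], clamped to
# [5, 3, 1] before pack_padded_sequence, run through the BiLSTM, padded back to the full
# sequence length, and finally unsorted so the row order matches the original batch.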
class HybridEncoder(nn.Module):
    def __init__(self, char_cnn_encoder, word_lstm_encoder, hybrid_encoder_output_dim):
        super(HybridEncoder, self).__init__()
        self.char_cnn_encoder = char_cnn_encoder
        self.word_lstm_encoder = word_lstm_encoder
        self.char_hidden_size = char_cnn_encoder.output_projection.out_features
        self.lstm_sequence_output_size = word_lstm_encoder.lstm.hidden_size * 2
        self.output_projection = nn.Linear(
            self.char_hidden_size + self.lstm_sequence_output_size, hybrid_encoder_output_dim
        )

    def forward(self, char_input, word_input, sequence_lengths):
        batch_size = char_input.size(0)
        max_seq_len = word_input.size(1)
        char_cnn_output = self.char_cnn_encoder(char_input)
        sequence_lengths = sequence_lengths.to(word_input.device)
        _, lstm_sequence_output = self.word_lstm_encoder(word_input, sequence_lengths)

        if char_cnn_output.size(1) < max_seq_len:
            char_cnn_output = F.pad(
                char_cnn_output, (0, 0, 0, max_seq_len - char_cnn_output.size(1)), "constant", 0
            )
        elif char_cnn_output.size(1) > max_seq_len:
            char_cnn_output = char_cnn_output[:, :max_seq_len, :]

        if lstm_sequence_output.size(1) < max_seq_len:
            lstm_sequence_output = F.pad(
                lstm_sequence_output, (0, 0, 0, max_seq_len - lstm_sequence_output.size(1)),
                "constant", 0
            )
        elif lstm_sequence_output.size(1) > max_seq_len:
            lstm_sequence_output = lstm_sequence_output[:, :max_seq_len, :]

        hybrid_output_concat = torch.cat((char_cnn_output, lstm_sequence_output), dim=2)
        return self.output_projection(hybrid_output_concat)
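# Shape sketch for the hybrid path (B = batch size, S = word-level sequence length,
# C = max_char_len; dimension names follow the constructor arguments):
#   char_input (B, S, C)   -> CharCNNEncoder  -> (B, S, char_cnn_output_dim)
#   word_input (B, S)      -> WordLSTMEncoder -> (B, S, 2 * word_lstm_hidden_dim)
#   concat on the last dim -> Linear          -> (B, S, hybrid_encoder_output_dim)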
class DualEncoderDecoder(nn.Module):
    def __init__(self, t5_model_name, hybrid_encoder, t5_tokenizer, freeze_t5=False):
        super(DualEncoderDecoder, self).__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name)
        self.t5_tokenizer = t5_tokenizer
        self.hybrid_encoder = hybrid_encoder
        encoder_hidden_size = self.t5.config.d_model
        hybrid_hidden_size = hybrid_encoder.output_projection.out_features
        self.encoder_projection = nn.Linear(
            encoder_hidden_size + hybrid_hidden_size, encoder_hidden_size
        )
        if freeze_t5:
            for param in self.t5.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask, char_input, word_input, sequence_lengths, labels=None):
        t5_encoder_outputs_dict = self.t5.encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        t5_encoder_last_hidden_state = t5_encoder_outputs_dict.last_hidden_state

        sequence_lengths = sequence_lengths.to(char_input.device)
        hybrid_encoder_output = self.hybrid_encoder(char_input, word_input, sequence_lengths)

        common_seq_len = t5_encoder_last_hidden_state.size(1)
        if hybrid_encoder_output.size(1) < common_seq_len:
            hybrid_encoder_output = F.pad(
                hybrid_encoder_output,
                (0, 0, 0, common_seq_len - hybrid_encoder_output.size(1)), "constant", 0
            )
        elif hybrid_encoder_output.size(1) > common_seq_len:
            hybrid_encoder_output = hybrid_encoder_output[:, :common_seq_len, :]

        concat_encoder_output = torch.cat(
            (t5_encoder_last_hidden_state, hybrid_encoder_output), dim=2
        )
        projected_encoder_output = self.encoder_projection(concat_encoder_output)
        encoder_outputs_for_decoder = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=projected_encoder_output
        )
        return self.t5(
            encoder_outputs=encoder_outputs_for_decoder,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
            use_cache=False
        )

    def generate(self, input_ids, attention_mask, char_input, word_input, sequence_lengths,
                 max_length, num_beams):
        t5_encoder_outputs_dict = self.t5.encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        t5_encoder_last_hidden_state = t5_encoder_outputs_dict.last_hidden_state

        sequence_lengths = sequence_lengths.to(char_input.device)
        hybrid_encoder_output = self.hybrid_encoder(char_input, word_input, sequence_lengths)

        common_seq_len = t5_encoder_last_hidden_state.size(1)
        if hybrid_encoder_output.size(1) < common_seq_len:
            hybrid_encoder_output = F.pad(
                hybrid_encoder_output,
                (0, 0, 0, common_seq_len - hybrid_encoder_output.size(1)), "constant", 0
            )
        elif hybrid_encoder_output.size(1) > common_seq_len:
            hybrid_encoder_output = hybrid_encoder_output[:, :common_seq_len, :]

        concat_encoder_output = torch.cat(
            (t5_encoder_last_hidden_state, hybrid_encoder_output), dim=2
        )
        projected_encoder_output = self.encoder_projection(concat_encoder_output)
        encoder_outputs_for_generate = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=projected_encoder_output
        )
        generated_ids_dict = self.t5.generate(
            encoder_outputs=encoder_outputs_for_generate,
            attention_mask=attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            use_cache=True,
            return_dict_in_generate=True
        )
        return generated_ids_dict.sequences

# --- END MODEL CLASSES ---


def download_model_from_gdrive(file_id, destination):
    print(f"Attempting to download model from Google Drive (ID: {file_id}) to {destination}...")
    try:
        # Ensure gdown is installed (it is added in the Dockerfile).
        # The --fuzzy flag can sometimes help with problematic links or large files;
        # the -O flag specifies the output file path.
        subprocess.run(["gdown", "--id", file_id, "-O", destination, "--fuzzy"], check=True)
        print("Model downloaded successfully from Google Drive.")
        return True
    except subprocess.CalledProcessError as e:
        print(f"ERROR: Failed to download model using gdown: {e}")
        if e.stderr:
            print(f"gdown stderr: {e.stderr.decode()}")
        if e.stdout:
            print(f"gdown stdout: {e.stdout.decode()}")
        return False
    except FileNotFoundError:
        print("ERROR: gdown command not found. Ensure it is installed in the Docker image.")
        return False
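# Expected layout of the checkpoint read by load_checkpoint below (only the keys the code
# accesses are listed; the concrete values depend on how the model was trained):
#   {
#       'config':             dict of CFG overrides (e.g. model_name),
#       'char_to_id':         dict mapping characters to integer ids,
#       'id_to_char':         inverse mapping (present but unused at inference time),
#       'model_architecture': dict with char_vocab_size, char_embedding_dim,
#                             char_cnn_output_dim, kernel_sizes, num_filters,
#                             word_vocab_size, word_embedding_dim, word_lstm_hidden_dim,
#                             num_lstm_layers, hybrid_encoder_output_dim, dropout,
#                             max_char_len,
#       'model_state_dict':   weights for DualEncoderDecoder,
#   }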
def load_checkpoint(path_to_checkpoint_file):
    # path_to_checkpoint_file is /app/model_files/best_model.pth
    # Tokenizers are loaded first from the local TOKENIZER_FILES_DIR.
    t5_tokenizer_dir_path = os.path.join(TOKENIZER_FILES_DIR, 't5_tokenizer')
    encoder_tokenizer_dir_path = os.path.join(TOKENIZER_FILES_DIR, 'encoders_tokenizer')

    if not os.path.isdir(t5_tokenizer_dir_path):
        raise FileNotFoundError(f"T5 tokenizer directory not found: {t5_tokenizer_dir_path}")
    if not os.path.isdir(encoder_tokenizer_dir_path):
        raise FileNotFoundError(f"Encoder tokenizer directory not found: {encoder_tokenizer_dir_path}")

    print(f"Loading T5 tokenizer from: {t5_tokenizer_dir_path}")
    CFG.t5_tokenizer = T5Tokenizer.from_pretrained(
        t5_tokenizer_dir_path, legacy=False, model_max_length=CFG.max_len, local_files_only=True
    )
    if CFG.t5_tokenizer.pad_token is None:
        CFG.t5_tokenizer.pad_token = CFG.t5_tokenizer.eos_token
    if CFG.t5_tokenizer.bos_token is None:
        CFG.t5_tokenizer.bos_token = CFG.t5_tokenizer.eos_token

    print(f"Loading encoder tokenizer from: {encoder_tokenizer_dir_path}")
    CFG.encoder_tokenizer = AutoTokenizer.from_pretrained(
        encoder_tokenizer_dir_path, model_max_length=CFG.max_len, local_files_only=True
    )
    if CFG.encoder_tokenizer.pad_token is None:
        print(f"Warning: Loaded encoder tokenizer from {encoder_tokenizer_dir_path} has no pad_token.")
        # If WordLSTMEncoder relies on pad_token_id, ensure one is set if it is missing from the
        # saved config, e.g.:
        #   CFG.encoder_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        #   CFG.encoder_tokenizer.pad_token = '[PAD]'
        # This should ideally be handled when the tokenizer is saved.

    # Now load the model checkpoint, which contains the weights, char_to_id, etc.
    if not os.path.exists(path_to_checkpoint_file):
        raise FileNotFoundError(
            f"Model checkpoint file not found after download attempt: {path_to_checkpoint_file}"
        )

    print(f"Loading model checkpoint from: {path_to_checkpoint_file}")
    checkpoint = torch.load(path_to_checkpoint_file, map_location=CFG.device)

    loaded_config_from_checkpoint = checkpoint['config']
    loaded_char_to_id = checkpoint['char_to_id']
    # loaded_id_to_char = checkpoint['id_to_char']  # Not used in the inference code
    model_architecture = checkpoint['model_architecture']

    for key, value in loaded_config_from_checkpoint.items():
        setattr(CFG, key, value)
    CFG.device = device  # Ensure CPU

    loaded_max_char_len = model_architecture.get('max_char_len', 50)

    char_cnn_encoder = CharCNNEncoder(
        char_vocab_size=model_architecture['char_vocab_size'],
        char_embedding_dim=model_architecture['char_embedding_dim'],
        char_cnn_output_dim=model_architecture['char_cnn_output_dim'],
        kernel_sizes=model_architecture['kernel_sizes'],
        num_filters=model_architecture['num_filters'],
        dropout=model_architecture.get('dropout', 0.1)
    )

    current_encoder_vocab_size = len(CFG.encoder_tokenizer)
    if model_architecture['word_vocab_size'] != current_encoder_vocab_size:
        print(f"Warning: Word vocab size mismatch. Checkpoint: {model_architecture['word_vocab_size']}, "
              f"Loaded CFG.encoder_tokenizer: {current_encoder_vocab_size}. "
              "Using loaded tokenizer's size for WordLSTM.")

    word_lstm_encoder = WordLSTMEncoder(
        word_vocab_size=current_encoder_vocab_size,  # Use the loaded tokenizer's vocab size
        word_embedding_dim=model_architecture['word_embedding_dim'],
        word_lstm_hidden_dim=model_architecture['word_lstm_hidden_dim'],
        num_lstm_layers=model_architecture['num_lstm_layers'],
        dropout=model_architecture.get('dropout', 0.1)
    )

    hybrid_encoder = HybridEncoder(
        char_cnn_encoder,
        word_lstm_encoder,
        hybrid_encoder_output_dim=model_architecture['hybrid_encoder_output_dim']
    )

    model_base_name_for_t5 = loaded_config_from_checkpoint.get('model_name', CFG.model_name)
    print(f"Initializing DualEncoderDecoder with T5 base: {model_base_name_for_t5}")
    model = DualEncoderDecoder(
        t5_model_name=model_base_name_for_t5,
        hybrid_encoder=hybrid_encoder,
        t5_tokenizer=CFG.t5_tokenizer
    )
    model.t5.resize_token_embeddings(len(CFG.t5_tokenizer))

    print("Loading model state_dict...")
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.to(CFG.device)
    model.eval()
    print("Model loaded successfully.")

    # Return id_to_char as well if it is ever needed downstream.
    return model, loaded_char_to_id, loaded_max_char_len  # , loaded_id_to_char


# --- Helper methods (tokenize_characters, pad_sequence, process_input) ---

def tokenize_characters(word, char_to_id):
    unk_token_id = char_to_id.get("", 0)
    return [char_to_id.get(char, unk_token_id) for char in word]


def pad_sequence(sequence, max_length, pad_value):
    if len(sequence) > max_length:
        sequence = sequence[:max_length]
    return sequence + [pad_value] * (max_length - len(sequence))
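# Quick sanity check of the helpers above (the character ids are hypothetical; they depend on
# the checkpoint's char_to_id mapping):
#   tokenize_characters("বই", char_to_id)               -> e.g. [17, 42]; unknown characters
#                                                          fall back to the unk/default id
#   pad_sequence([17, 42], max_length=5, pad_value=0)   -> [17, 42, 0, 0, 0]
#   pad_sequence(list(range(8)), max_length=5, pad_value=0) -> [0, 1, 2, 3, 4]  (truncated)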
def process_input(text, t5_tokenizer, encoder_tokenizer, char_to_id, current_max_char_len, max_token_len):
    t5_inputs = t5_tokenizer(text, return_tensors='pt', padding='max_length', truncation=True,
                             max_length=max_token_len, add_special_tokens=True)
    encoder_inputs = encoder_tokenizer(text, return_tensors='pt', padding='max_length',
                                       truncation=True, max_length=max_token_len)

    t5_input_ids_squeezed = t5_inputs['input_ids'].squeeze(0)
    t5_attention_mask_squeezed = t5_inputs['attention_mask'].squeeze(0)
    encoder_input_ids_squeezed = encoder_inputs['input_ids'].squeeze(0)

    actual_max_seq_len = encoder_input_ids_squeezed.shape[0]
    char_input_tensor = torch.zeros((actual_max_seq_len, current_max_char_len), dtype=torch.long)
    char_pad_id = char_to_id.get("", 0)

    for j in range(actual_max_seq_len):
        token_id = encoder_input_ids_squeezed[j].item()
        if token_id in encoder_tokenizer.all_special_ids:
            word = ""
        else:
            word = encoder_tokenizer.decode([token_id], skip_special_tokens=True).strip()
        if not word:
            char_ids = [char_pad_id] * current_max_char_len
        else:
            char_ids = tokenize_characters(word, char_to_id)
            char_ids = pad_sequence(char_ids, current_max_char_len, char_pad_id)
        char_input_tensor[j, :] = torch.tensor(char_ids, dtype=torch.long)

    sequence_lengths_tensor = encoder_inputs['attention_mask'].sum(dim=1).long().squeeze()
    if sequence_lengths_tensor.ndim == 0:
        sequence_lengths_tensor = sequence_lengths_tensor.unsqueeze(0)

    return {
        't5_input_ids': t5_input_ids_squeezed.unsqueeze(0).to(CFG.device),
        't5_attention_mask': t5_attention_mask_squeezed.unsqueeze(0).to(CFG.device),
        'encoder_input_ids': encoder_input_ids_squeezed.unsqueeze(0).to(CFG.device),
        'char_input': char_input_tensor.unsqueeze(0).to(CFG.device),
        'sequence_lengths': sequence_lengths_tensor.to(CFG.device)
    }
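# For a single input string, assuming max_token_len = CFG.max_len (512) and padding to
# max_length as above, process_input returns CPU tensors with these shapes
# (max_char_len comes from the checkpoint):
#   't5_input_ids'      (1, 512)
#   't5_attention_mask' (1, 512)
#   'encoder_input_ids' (1, 512)
#   'char_input'        (1, 512, max_char_len)
#   'sequence_lengths'  (1,)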
# --- END HELPER METHODS ---

app = Flask(__name__)

# --- Global variables for the loaded model and components ---
loaded_model_global = None
loaded_char_to_id_global = None
loaded_max_char_len_global = None
# ---

# Model download and loading logic at startup
MODEL_GDRIVE_FILE_ID = "1XPICO1MAdf6OTKe_SVJfTW3G6gftVW5f"  # Google Drive file ID of the checkpoint
checkpoint_file_name = "best_model.pth"
local_checkpoint_path = os.path.join(MODEL_FILES_DIR, checkpoint_file_name)

# Create the model_files and tokenizer directories if they don't exist
os.makedirs(MODEL_FILES_DIR, exist_ok=True)
os.makedirs(TOKENIZER_FILES_DIR, exist_ok=True)

# Download the model only if it doesn't already exist
if not os.path.exists(local_checkpoint_path):
    print(f"'{checkpoint_file_name}' not found locally. Attempting download...")
    if not download_model_from_gdrive(MODEL_GDRIVE_FILE_ID, local_checkpoint_path):
        print("FATAL: Model download failed. The application might not work correctly.")
        # loaded_model_global remains None
    else:
        print("Model download successful.")
else:
    print(f"'{checkpoint_file_name}' found locally at {local_checkpoint_path}. Skipping download.")

# Load the model if the download succeeded or the file already existed
if os.path.exists(local_checkpoint_path):
    print("Initializing and loading model checkpoint...")
    try:
        # Store the loaded components in the module-level globals
        loaded_model_global, loaded_char_to_id_global, loaded_max_char_len_global = load_checkpoint(local_checkpoint_path)
        print("Model and components loaded into global variables successfully.")
    except Exception as e:
        print(f"FATAL: Could not load model on startup from {local_checkpoint_path}: {e}")
        # import traceback
        # traceback.print_exc()
        loaded_model_global = None  # Indicate that model loading failed
else:
    if loaded_model_global is None:  # Download failed and no local copy existed
        print("Model file is not available. Application will not function correctly.")


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/translate', methods=['POST'])
def translate_text():
    if loaded_model_global is None:
        return jsonify({"error": "Model is not available. Please check server logs."}), 500

    data = request.get_json()
    input_text = data.get('text', '')
    if not input_text:
        return jsonify({"error": "No text provided"}), 400

    try:
        inputs = process_input(
            input_text,
            CFG.t5_tokenizer,       # Loaded by load_checkpoint
            CFG.encoder_tokenizer,  # Loaded by load_checkpoint
            loaded_char_to_id_global,
            loaded_max_char_len_global,
            CFG.max_len
        )
        with torch.no_grad():
            generated_ids = loaded_model_global.generate(
                inputs['t5_input_ids'],
                inputs['t5_attention_mask'],
                inputs['char_input'],
                inputs['encoder_input_ids'],
                inputs['sequence_lengths'],
                max_length=CFG.max_len,
                num_beams=4
            )
        translation = CFG.t5_tokenizer.decode(
            generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).strip()
        return jsonify({"translation": translation})
    except Exception as e:
        print(f"Error during translation: {e}")
        # import traceback
        # traceback.print_exc()
        return jsonify({"error": "An error occurred during translation."}), 500


if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))  # HF Docker Spaces expects 7860
    app.run(host='0.0.0.0', port=port, debug=False)  # debug=False for production
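# Example request against a locally running instance (localhost:7860 assumes the default
# port above; on a deployed Space, use the Space URL instead):
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "..."}'
# A successful response looks like {"translation": "..."}; failures return an "error" key
# with HTTP status 400 or 500.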