import os
import sys
import random
import torch
import numpy as np
from flask import Flask, request, jsonify, render_template
import subprocess  # For calling gdown

# Set before importing transformers so huggingface_hub picks it up.
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'

import torch.nn as nn
import torch.nn.functional as F
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
# Get the directory of the current script (app.py)
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
MODEL_FILES_DIR = os.path.join(APP_ROOT, 'model_files')
TOKENIZER_FILES_DIR = os.path.join(MODEL_FILES_DIR, 'tokenizers')  # Path for tokenizers specifically

device = torch.device('cpu')
class CFG:
    model_name = 'csebuetnlp/banglat5'
    encoder_name = 'csebuetnlp/banglabert'
    batch_size = 1
    max_len = 512
    seed = 42
    device = device
    # Tokenizers are loaded from TOKENIZER_FILES_DIR inside load_checkpoint,
    # which overwrites these placeholders with the local versions.
    # WordLSTMEncoder.__init__ falls back to padding_idx=0 if
    # CFG.encoder_tokenizer has not been set yet.
    t5_tokenizer = None
    encoder_tokenizer = None
# --- MODEL CLASSES (CharCNNEncoder, WordLSTMEncoder, HybridEncoder, DualEncoderDecoder) ---
class CharCNNEncoder(nn.Module):
    def __init__(self, char_vocab_size, char_embedding_dim, char_cnn_output_dim, kernel_sizes, num_filters, dropout=0.1):
        super(CharCNNEncoder, self).__init__()
        self.char_embedding = nn.Embedding(char_vocab_size, char_embedding_dim, padding_idx=0)
        self.conv_layers = nn.ModuleList()
        for ks, nf in zip(kernel_sizes, num_filters):
            self.conv_layers.append(
                nn.Sequential(
                    nn.Conv1d(char_embedding_dim, nf, kernel_size=ks, padding=ks // 2),
                    nn.ReLU(),
                    nn.AdaptiveMaxPool1d(1)
                )
            )
        self.dropout = nn.Dropout(dropout)
        self.output_projection = nn.Linear(sum(num_filters), char_cnn_output_dim)

    def forward(self, char_input):
        batch_size, seq_len, char_len = char_input.size()
        char_input = char_input.view(-1, char_len)
        char_emb = self.char_embedding(char_input)
        char_emb = char_emb.permute(0, 2, 1)
        conv_outputs = [conv(char_emb) for conv in self.conv_layers]
        concat_output = torch.cat(conv_outputs, dim=1)
        concat_output = concat_output.squeeze(-1)
        concat_output = self.dropout(concat_output)
        char_cnn_output = self.output_projection(concat_output)
        char_cnn_output = char_cnn_output.view(batch_size, seq_len, -1)
        return char_cnn_output
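
# Shape walkthrough for CharCNNEncoder.forward (illustrative; the dimensions
# below are hypothetical, e.g. batch_size=2, seq_len=512, char_len=50,
# char_embedding_dim=64):
#   char_input: (2, 512, 50)   -> view:    (1024, 50)
#   char_emb:   (1024, 50, 64) -> permute: (1024, 64, 50)
#   each conv + AdaptiveMaxPool1d(1): (1024, nf, 1); concat: (1024, sum(num_filters), 1)
#   squeeze + projection: (1024, char_cnn_output_dim) -> view: (2, 512, char_cnn_output_dim)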

class WordLSTMEncoder(nn.Module):
    def __init__(self, word_vocab_size, word_embedding_dim, word_lstm_hidden_dim, num_lstm_layers, dropout):
        super(WordLSTMEncoder, self).__init__()
        tok = getattr(CFG, 'encoder_tokenizer', None)
        padding_idx_val = tok.pad_token_id if tok is not None and tok.pad_token_id is not None else 0
        self.word_embedding = nn.Embedding(
            word_vocab_size,
            word_embedding_dim,
            padding_idx=padding_idx_val
        )
        self.lstm = nn.LSTM(
            word_embedding_dim,
            word_lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True
        )
        self.output_projection = nn.Linear(2 * word_lstm_hidden_dim, word_lstm_hidden_dim)

    def forward(self, word_input, sequence_lengths):
        batch_size = word_input.size(0)
        word_emb = self.word_embedding(word_input)
        sequence_lengths_cpu = sequence_lengths.cpu()
        if torch.all(sequence_lengths_cpu == 0):
            # Degenerate case: every sequence is empty, so skip the LSTM entirely.
            lstm_out = torch.zeros(batch_size, word_input.size(1), self.lstm.hidden_size * 2, device=word_input.device)
            hidden = torch.zeros(batch_size, self.lstm.hidden_size * 2, device=word_input.device)  # bidirectional => 2 * hidden_size
            return self.output_projection(hidden), lstm_out
        sorted_lengths, sort_idx = sequence_lengths_cpu.sort(0, descending=True)
        sorted_word_emb = word_emb[sort_idx]
        packed_word_emb = nn.utils.rnn.pack_padded_sequence(
            sorted_word_emb,
            sorted_lengths.clamp(min=1),
            batch_first=True,
            enforce_sorted=True
        )
        packed_lstm_out, (hidden_state, cell_state) = self.lstm(packed_word_emb)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_lstm_out,
            batch_first=True,
            total_length=word_input.size(1)
        )
        _, unsort_idx = sort_idx.sort(0)
        lstm_out = lstm_out[unsort_idx]
        hidden_state = hidden_state.view(self.lstm.num_layers, 2, batch_size, self.lstm.hidden_size)
        hidden_state_last_layer = hidden_state[-1]
        final_hidden = torch.cat((hidden_state_last_layer[0], hidden_state_last_layer[1]), dim=1)
        final_hidden = final_hidden[unsort_idx]
        return self.output_projection(final_hidden), lstm_out
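
# Note on the sort/unsort bookkeeping above: pack_padded_sequence with
# enforce_sorted=True requires lengths in descending order, so sequences are
# sorted before packing and restored afterwards via unsort_idx; lengths are
# clamped to a minimum of 1 so fully-padded rows cannot crash the pack call.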

class HybridEncoder(nn.Module):
    def __init__(self, char_cnn_encoder, word_lstm_encoder, hybrid_encoder_output_dim):
        super(HybridEncoder, self).__init__()
        self.char_cnn_encoder = char_cnn_encoder
        self.word_lstm_encoder = word_lstm_encoder
        self.char_hidden_size = char_cnn_encoder.output_projection.out_features
        self.lstm_sequence_output_size = word_lstm_encoder.lstm.hidden_size * 2
        self.output_projection = nn.Linear(self.char_hidden_size + self.lstm_sequence_output_size, hybrid_encoder_output_dim)

    def forward(self, char_input, word_input, sequence_lengths):
        batch_size = char_input.size(0)
        max_seq_len = word_input.size(1)
        char_cnn_output = self.char_cnn_encoder(char_input)
        sequence_lengths = sequence_lengths.to(word_input.device)
        _, lstm_sequence_output = self.word_lstm_encoder(word_input, sequence_lengths)
        if char_cnn_output.size(1) < max_seq_len:
            char_cnn_output = F.pad(char_cnn_output, (0, 0, 0, max_seq_len - char_cnn_output.size(1)), "constant", 0)
        elif char_cnn_output.size(1) > max_seq_len:
            char_cnn_output = char_cnn_output[:, :max_seq_len, :]
        if lstm_sequence_output.size(1) < max_seq_len:
            lstm_sequence_output = F.pad(lstm_sequence_output, (0, 0, 0, max_seq_len - lstm_sequence_output.size(1)), "constant", 0)
        elif lstm_sequence_output.size(1) > max_seq_len:
            lstm_sequence_output = lstm_sequence_output[:, :max_seq_len, :]
        hybrid_output_concat = torch.cat((char_cnn_output, lstm_sequence_output), dim=2)
        return self.output_projection(hybrid_output_concat)
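
# Both streams are padded/truncated to the word-level sequence length before
# the dim=2 concatenation, so character and word features line up token-for-token.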

class DualEncoderDecoder(nn.Module):
    def __init__(self, t5_model_name, hybrid_encoder, t5_tokenizer, freeze_t5=False):
        super(DualEncoderDecoder, self).__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name)
        self.t5_tokenizer = t5_tokenizer
        self.hybrid_encoder = hybrid_encoder
        encoder_hidden_size = self.t5.config.d_model
        hybrid_hidden_size = hybrid_encoder.output_projection.out_features
        self.encoder_projection = nn.Linear(encoder_hidden_size + hybrid_hidden_size, encoder_hidden_size)
        if freeze_t5:
            for param in self.t5.parameters():
                param.requires_grad = False

    def _fused_encoder_output(self, input_ids, attention_mask, char_input, word_input, sequence_lengths):
        # Run the T5 encoder and the hybrid encoder, align the two sequence
        # lengths, then concatenate and project back to T5's hidden size.
        # Shared by forward() and generate(), which previously duplicated it.
        t5_encoder_outputs = self.t5.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        t5_hidden = t5_encoder_outputs.last_hidden_state
        sequence_lengths = sequence_lengths.to(char_input.device)
        hybrid_output = self.hybrid_encoder(char_input, word_input, sequence_lengths)
        common_seq_len = t5_hidden.size(1)
        if hybrid_output.size(1) < common_seq_len:
            hybrid_output = F.pad(hybrid_output, (0, 0, 0, common_seq_len - hybrid_output.size(1)), "constant", 0)
        elif hybrid_output.size(1) > common_seq_len:
            hybrid_output = hybrid_output[:, :common_seq_len, :]
        concat_encoder_output = torch.cat((t5_hidden, hybrid_output), dim=2)
        projected_encoder_output = self.encoder_projection(concat_encoder_output)
        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=projected_encoder_output)

    def forward(self, input_ids, attention_mask, char_input, word_input, sequence_lengths, labels=None):
        encoder_outputs = self._fused_encoder_output(input_ids, attention_mask, char_input, word_input, sequence_lengths)
        return self.t5(encoder_outputs=encoder_outputs, attention_mask=attention_mask, labels=labels, return_dict=True, use_cache=False)

    def generate(self, input_ids, attention_mask, char_input, word_input, sequence_lengths, max_length, num_beams):
        encoder_outputs = self._fused_encoder_output(input_ids, attention_mask, char_input, word_input, sequence_lengths)
        generated_ids_dict = self.t5.generate(
            encoder_outputs=encoder_outputs, attention_mask=attention_mask,
            max_length=max_length, num_beams=num_beams, early_stopping=True,
            use_cache=True, return_dict_in_generate=True
        )
        return generated_ids_dict.sequences
# --- END MODEL CLASSES ---
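
# Minimal construction sketch (hypothetical dimensions, for orientation only;
# the real values come from the checkpoint's model_architecture dict):
#   char_enc = CharCNNEncoder(char_vocab_size=200, char_embedding_dim=64,
#                             char_cnn_output_dim=128, kernel_sizes=[3, 5],
#                             num_filters=[64, 64])
#   word_enc = WordLSTMEncoder(word_vocab_size=32000, word_embedding_dim=256,
#                              word_lstm_hidden_dim=256, num_lstm_layers=2, dropout=0.1)
#   hybrid = HybridEncoder(char_enc, word_enc, hybrid_encoder_output_dim=256)
#   model = DualEncoderDecoder('csebuetnlp/banglat5', hybrid, t5_tokenizer=None)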

def download_model_from_gdrive(file_id, destination):
    print(f"Attempting to download model from Google Drive (ID: {file_id}) to {destination}...")
    try:
        # gdown must be installed in the image (see Dockerfile).
        # --fuzzy helps with problematic links or large files; -O sets the output path.
        # capture_output=True so stdout/stderr are available for the error report below.
        subprocess.run(
            ["gdown", "--id", file_id, "-O", destination, "--fuzzy"],
            check=True, capture_output=True, text=True
        )
        print("Model downloaded successfully from Google Drive.")
        return True
    except subprocess.CalledProcessError as e:
        print(f"ERROR: Failed to download model using gdown: {e}")
        if e.stderr:
            print(f"gdown stderr: {e.stderr}")
        if e.stdout:
            print(f"gdown stdout: {e.stdout}")
        return False
    except FileNotFoundError:
        print("ERROR: gdown command not found. Ensure it is installed in the Docker image.")
        return False
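
# Untested alternative sketch: recent gdown versions also expose a Python API,
# which avoids the subprocess round-trip (keyword names below are assumptions):
#   import gdown
#   gdown.download(id=file_id, output=destination, fuzzy=True)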

def load_checkpoint(path_to_checkpoint_file):  # e.g. /app/model_files/best_model.pth
    # Tokenizers are loaded first, from the local TOKENIZER_FILES_DIR.
    t5_tokenizer_dir_path = os.path.join(TOKENIZER_FILES_DIR, 't5_tokenizer')
    encoder_tokenizer_dir_path = os.path.join(TOKENIZER_FILES_DIR, 'encoders_tokenizer')
    if not os.path.isdir(t5_tokenizer_dir_path):
        raise FileNotFoundError(f"T5 tokenizer directory not found: {t5_tokenizer_dir_path}")
    if not os.path.isdir(encoder_tokenizer_dir_path):
        raise FileNotFoundError(f"Encoder tokenizer directory not found: {encoder_tokenizer_dir_path}")
    print(f"Loading T5 tokenizer from: {t5_tokenizer_dir_path}")
    CFG.t5_tokenizer = T5Tokenizer.from_pretrained(
        t5_tokenizer_dir_path, legacy=False, model_max_length=CFG.max_len, local_files_only=True
    )
    if CFG.t5_tokenizer.pad_token is None:
        CFG.t5_tokenizer.pad_token = CFG.t5_tokenizer.eos_token
    if CFG.t5_tokenizer.bos_token is None:
        CFG.t5_tokenizer.bos_token = CFG.t5_tokenizer.eos_token
    print(f"Loading encoder tokenizer from: {encoder_tokenizer_dir_path}")
    CFG.encoder_tokenizer = AutoTokenizer.from_pretrained(
        encoder_tokenizer_dir_path, model_max_length=CFG.max_len, local_files_only=True
    )
    if CFG.encoder_tokenizer.pad_token is None:
        print(f"Warning: Loaded encoder tokenizer from {encoder_tokenizer_dir_path} has no pad_token.")
        # WordLSTMEncoder falls back to padding_idx=0 in that case. Ideally the
        # pad token is saved with the tokenizer; otherwise it can be added here,
        # e.g. CFG.encoder_tokenizer.add_special_tokens({'pad_token': '[PAD]'}).
    # Now load the model checkpoint, which contains the weights, char_to_id, etc.
    if not os.path.exists(path_to_checkpoint_file):
        raise FileNotFoundError(f"Model checkpoint file not found after download attempt: {path_to_checkpoint_file}")
    print(f"Loading model checkpoint from: {path_to_checkpoint_file}")
    # weights_only=False: the checkpoint stores plain Python objects (config
    # dict, char maps) alongside the state_dict, and newer torch versions
    # default to weights_only=True.
    checkpoint = torch.load(path_to_checkpoint_file, map_location=CFG.device, weights_only=False)
    loaded_config_from_checkpoint = checkpoint['config']
    loaded_char_to_id = checkpoint['char_to_id']
    # loaded_id_to_char = checkpoint['id_to_char']  # Not used at inference time
    model_architecture = checkpoint['model_architecture']
    for key, value in loaded_config_from_checkpoint.items():
        setattr(CFG, key, value)
    CFG.device = device  # Ensure CPU
    loaded_max_char_len = model_architecture.get('max_char_len', 50)
    char_cnn_encoder = CharCNNEncoder(
        char_vocab_size=model_architecture['char_vocab_size'],
        char_embedding_dim=model_architecture['char_embedding_dim'],
        char_cnn_output_dim=model_architecture['char_cnn_output_dim'],
        kernel_sizes=model_architecture['kernel_sizes'],
        num_filters=model_architecture['num_filters'],
        dropout=model_architecture.get('dropout', 0.1)
    )
    current_encoder_vocab_size = len(CFG.encoder_tokenizer)
    if model_architecture['word_vocab_size'] != current_encoder_vocab_size:
        print(f"Warning: Word vocab size mismatch. Checkpoint: {model_architecture['word_vocab_size']}, "
              f"Loaded CFG.encoder_tokenizer: {current_encoder_vocab_size}. Using loaded tokenizer's size for WordLSTM.")
    word_lstm_encoder = WordLSTMEncoder(
        word_vocab_size=current_encoder_vocab_size,  # Use the loaded tokenizer's vocab size
        word_embedding_dim=model_architecture['word_embedding_dim'],
        word_lstm_hidden_dim=model_architecture['word_lstm_hidden_dim'],
        num_lstm_layers=model_architecture['num_lstm_layers'],
        dropout=model_architecture.get('dropout', 0.1)
    )
    hybrid_encoder = HybridEncoder(
        char_cnn_encoder, word_lstm_encoder,
        hybrid_encoder_output_dim=model_architecture['hybrid_encoder_output_dim']
    )
    model_base_name_for_t5 = loaded_config_from_checkpoint.get('model_name', CFG.model_name)
    print(f"Initializing DualEncoderDecoder with T5 base: {model_base_name_for_t5}")
    model = DualEncoderDecoder(
        t5_model_name=model_base_name_for_t5,
        hybrid_encoder=hybrid_encoder,
        t5_tokenizer=CFG.t5_tokenizer
    )
    model.t5.resize_token_embeddings(len(CFG.t5_tokenizer))
    print("Loading model state_dict...")
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.to(CFG.device)
    model.eval()
    print("Model loaded successfully.")
    # id_to_char could also be returned here if ever needed at inference time.
    return model, loaded_char_to_id, loaded_max_char_len

# --- Helper methods (tokenize_characters, pad_sequence, process_input) ---
def tokenize_characters(word, char_to_id):
    unk_token_id = char_to_id.get("<UNK>", 0)
    return [char_to_id.get(char, unk_token_id) for char in word]

def pad_sequence(sequence, max_length, pad_value):
    if len(sequence) > max_length:
        sequence = sequence[:max_length]
    return sequence + [pad_value] * (max_length - len(sequence))
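
# e.g. pad_sequence([5, 9], 4, 0) -> [5, 9, 0, 0]; pad_sequence([1, 2, 3], 2, 0) -> [1, 2]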

def process_input(text, t5_tokenizer, encoder_tokenizer, char_to_id, current_max_char_len, max_token_len):
    t5_inputs = t5_tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=max_token_len, add_special_tokens=True)
    encoder_inputs = encoder_tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=max_token_len)
    t5_input_ids_squeezed = t5_inputs['input_ids'].squeeze(0)
    t5_attention_mask_squeezed = t5_inputs['attention_mask'].squeeze(0)
    encoder_input_ids_squeezed = encoder_inputs['input_ids'].squeeze(0)
    actual_max_seq_len = encoder_input_ids_squeezed.shape[0]
    char_input_tensor = torch.zeros((actual_max_seq_len, current_max_char_len), dtype=torch.long)
    char_pad_id = char_to_id.get("<PAD>", 0)
    for j in range(actual_max_seq_len):
        token_id = encoder_input_ids_squeezed[j].item()
        if token_id in encoder_tokenizer.all_special_ids:
            word = ""
        else:
            word = encoder_tokenizer.decode([token_id], skip_special_tokens=True).strip()
        if not word:
            char_ids = [char_pad_id] * current_max_char_len
        else:
            char_ids = tokenize_characters(word, char_to_id)
            char_ids = pad_sequence(char_ids, current_max_char_len, char_pad_id)
        char_input_tensor[j, :] = torch.tensor(char_ids, dtype=torch.long)
    sequence_lengths_tensor = encoder_inputs['attention_mask'].sum(dim=1).long().squeeze()
    if sequence_lengths_tensor.ndim == 0:
        sequence_lengths_tensor = sequence_lengths_tensor.unsqueeze(0)
    return {
        't5_input_ids': t5_input_ids_squeezed.unsqueeze(0).to(CFG.device),
        't5_attention_mask': t5_attention_mask_squeezed.unsqueeze(0).to(CFG.device),
        'encoder_input_ids': encoder_input_ids_squeezed.unsqueeze(0).to(CFG.device),
        'char_input': char_input_tensor.unsqueeze(0).to(CFG.device),
        'sequence_lengths': sequence_lengths_tensor.to(CFG.device)
    }
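
# Returned tensor shapes (batch of 1): t5_input_ids, t5_attention_mask and
# encoder_input_ids are (1, max_token_len); char_input is
# (1, max_token_len, current_max_char_len); sequence_lengths is (1,).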
# --- END HELPER METHODS ---

app = Flask(__name__)

# --- Global variables for the loaded model and components ---
loaded_model_global = None
loaded_char_to_id_global = None
loaded_max_char_len_global = None
# ---

# Model download and loading logic at startup
MODEL_GDRIVE_FILE_ID = "1XPICO1MAdf6OTKe_SVJfTW3G6gftVW5f"  # Your Google Drive File ID
checkpoint_file_name = "best_model.pth"
local_checkpoint_path = os.path.join(MODEL_FILES_DIR, checkpoint_file_name)

# Create the model and tokenizer directories if they don't exist
os.makedirs(MODEL_FILES_DIR, exist_ok=True)
os.makedirs(TOKENIZER_FILES_DIR, exist_ok=True)

# Download the model only if it doesn't already exist
if not os.path.exists(local_checkpoint_path):
    print(f"'{checkpoint_file_name}' not found locally. Attempting download...")
    if not download_model_from_gdrive(MODEL_GDRIVE_FILE_ID, local_checkpoint_path):
        print("FATAL: Model download failed. The application might not work correctly.")
        # loaded_model_global remains None
    else:
        print("Model download successful.")
else:
    print(f"'{checkpoint_file_name}' found locally at {local_checkpoint_path}. Skipping download.")

# Load the model if the download succeeded or the file already existed
if os.path.exists(local_checkpoint_path):
    print("Initializing and loading model checkpoint...")
    try:
        loaded_model_global, loaded_char_to_id_global, loaded_max_char_len_global = load_checkpoint(local_checkpoint_path)
        print("Model and components loaded into global variables successfully.")
    except Exception as e:
        print(f"FATAL: Could not load model on startup from {local_checkpoint_path}: {e}")
        # import traceback; traceback.print_exc()  # Uncomment for a full stack trace
        loaded_model_global = None  # Signal that model loading failed
else:
    # Download failed and no local copy exists; the /translate route returns an error.
    print("Model file is not available. Application will not function correctly.")

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/translate', methods=['POST'])  # endpoint path assumed; match the frontend's fetch URL
def translate_text():
    if loaded_model_global is None:
        return jsonify({"error": "Model is not available. Please check server logs."}), 500
    data = request.get_json()
    input_text = data.get('text', '')
    if not input_text:
        return jsonify({"error": "No text provided"}), 400
    try:
        inputs = process_input(
            input_text,
            CFG.t5_tokenizer,       # loaded by load_checkpoint
            CFG.encoder_tokenizer,  # loaded by load_checkpoint
            loaded_char_to_id_global,
            loaded_max_char_len_global,
            CFG.max_len
        )
        with torch.no_grad():
            generated_ids = loaded_model_global.generate(
                inputs['t5_input_ids'], inputs['t5_attention_mask'],
                inputs['char_input'], inputs['encoder_input_ids'],
                inputs['sequence_lengths'],
                max_length=CFG.max_len, num_beams=4
            )
        translation = CFG.t5_tokenizer.decode(
            generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).strip()
        return jsonify({"translation": translation})
    except Exception as e:
        print(f"Error during translation: {e}")
        # import traceback; traceback.print_exc()  # Uncomment for a full stack trace
        return jsonify({"error": "An error occurred during translation."}), 500

if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))  # HF Docker Spaces expects 7860
    app.run(host='0.0.0.0', port=port, debug=False)  # debug=False for production