amiguel committed on
Commit 9a4c5ac · verified · 1 Parent(s): 36256aa

Update app.py

Files changed (1)
  app.py  +97 -54
app.py CHANGED
@@ -1,10 +1,13 @@
 import streamlit as st
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextStreamer
-from huggingface_hub import login
-import PyPDF2
-import pandas as pd
 import torch
+import pandas as pd
+import PyPDF2
+import pickle
+import os
+from transformers import AutoTokenizer
+from huggingface_hub import login
 import time
+from utils.ch09util import subsequent_mask  # Ensure ch09util.py is available
 
 # Device setup
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -17,7 +20,7 @@ st.set_page_config(
 )
 
 # Model name
-MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"  # "Helsinki-NLP/opus-mt-en-fr"
+MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"
 
 # Title with rocket emojis
 st.title("🚀 English to French Translator 🚀")
@@ -60,9 +63,9 @@ def process_file(uploaded_file):
         st.error(f"📄 Error processing file: {str(e)}")
         return ""
 
-# Model loading function
+# Custom model loading function
 @st.cache_resource
-def load_model(hf_token):
+def load_model_and_resources(hf_token):
     try:
         if not hf_token:
             st.error("🔐 Authentication required! Please provide a Hugging Face token.")
@@ -76,49 +79,86 @@ def load_model(hf_token):
             token=hf_token
         )
 
-        # Load the model with appropriate dtype for CPU/GPU compatibility
-        dtype = torch.float16 if DEVICE == "cuda" else torch.float32
-        model = AutoModelForSeq2SeqLM.from_pretrained(
+        # Load model
+        from transformers import PreTrainedModel, PretrainedConfig
+        class TransformerConfig(PretrainedConfig):
+            model_type = "custom_transformer"
+            def __init__(self, src_vocab_size, tgt_vocab_size, d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
+                super().__init__(**kwargs)
+                self.src_vocab_size = src_vocab_size
+                self.tgt_vocab_size = tgt_vocab_size
+                self.d_model = d_model
+                self.d_ff = d_ff
+                self.h = h
+                self.N = N
+                self.dropout = dropout
+
+        class CustomTransformer(PreTrainedModel):
+            config_class = TransformerConfig
+            def __init__(self, config):
+                super().__init__(config)
+                from utils.ch09util import create_model
+                self.model = create_model(
+                    config.src_vocab_size,
+                    config.tgt_vocab_size,
+                    N=config.N,
+                    d_model=config.d_model,
+                    d_ff=config.d_ff,
+                    h=config.h,
+                    dropout=config.dropout
+                )
+            def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
+                return self.model(src, tgt, src_mask, tgt_mask)
+
+        config = TransformerConfig.from_pretrained(MODEL_NAME, token=hf_token)
+        model = CustomTransformer.from_pretrained(
             MODEL_NAME,
-            token=hf_token,
-            torch_dtype=dtype,
-            device_map="auto"  # Automatically maps to CPU or GPU
-        )
+            config=config,
+            token=hf_token
+        ).to(DEVICE)
 
-        return model, tokenizer
+        # Load dictionaries (assumes dict.p was uploaded to the model repo)
+        dict_path = "dict.p"
+        if not os.path.exists(dict_path):
+            st.error("Dictionary file (dict.p) not found. Please ensure it was uploaded to the model repository.")
+            return None
+        with open(dict_path, "rb") as fb:
+            en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
+
+        return model, tokenizer, en_word_dict, fr_word_dict, en_idx_dict, fr_idx_dict
 
     except Exception as e:
         st.error(f"🤖 Model loading failed: {str(e)}")
         return None
 
-# Generation function for translation with streaming
-def generate_translation(input_text, model, tokenizer):
+# Custom streaming generation function
+def custom_streaming_generate(input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict):
     try:
-        # Tokenize the input (no prompt needed for seq2seq translation models)
-        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-        inputs = inputs.to(DEVICE)
-
-        # Set up the streamer for real-time output
-        streamer = TextStreamer(tokenizer, skip_special_tokens=True)
-
-        # Generate translation with streaming (disable beam search)
         model.eval()
-        with torch.no_grad():
-            outputs = model.generate(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                max_length=512,
-                num_beams=1,  # Set to 1 to disable beam search for streaming
-                length_penalty=1.0,
-                early_stopping=True,
-                streamer=streamer,
-                return_dict_in_generate=True,
-                output_scores=True
-            )
-
-        # Decode the full output for storage and metrics
-        translation = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
-        return translation, streamer
+        PAD, UNK = 0, 1
+        tokenized_en = ["BOS"] + tokenizer.tokenize(input_text) + ["EOS"]
+        enidx = [en_word_dict.get(i, UNK) for i in tokenized_en]
+        src = torch.tensor(enidx).long().to(DEVICE).unsqueeze(0)
+        src_mask = (src != 0).unsqueeze(-2)
+        memory = model.model.encode(src, src_mask)
+        start_symbol = fr_word_dict["BOS"]
+        ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
+        for _ in range(100):
+            out = model.model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
+            prob = model.model.generator(out[:, -1])
+            _, next_word = torch.max(prob, dim=1)
+            next_word = next_word.data[0]
+            sym = fr_idx_dict.get(next_word, "UNK")
+            if sym != "EOS":
+                token = sym.replace("</w>", " ")
+                for x in '''?:;.,'("-!&)%''':
+                    token = token.replace(f" {x}", f"{x}")
+                yield token
+            else:
+                break
+            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
+        # Yield a final empty token to ensure completion
+        yield ""
 
     except Exception as e:
         raise Exception(f"Generation error: {str(e)}")
@@ -139,17 +179,22 @@ if prompt := st.chat_input("Enter text to translate into French..."):
         st.error("🔑 Authentication required!")
         st.stop()
 
-    # Load model if not already loaded
+    # Load model and resources if not already loaded
    if "model" not in st.session_state:
-        model_data = load_model(hf_token)
+        model_data = load_model_and_resources(hf_token)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()
 
-        st.session_state.model, st.session_state.tokenizer = model_data
+        st.session_state.model, st.session_state.tokenizer, \
+        st.session_state.en_word_dict, st.session_state.fr_word_dict, \
+        st.session_state.en_idx_dict, st.session_state.fr_idx_dict = model_data
 
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
+    en_word_dict = st.session_state.en_word_dict
+    fr_word_dict = st.session_state.fr_word_dict
+    fr_idx_dict = st.session_state.fr_idx_dict
 
    # Add user message
    with st.chat_message("user", avatar=USER_AVATAR):
@@ -170,21 +215,19 @@ if prompt := st.chat_input("Enter text to translate into French..."):
        response_container = st.empty()
        full_response = ""
 
-        # Generate translation and stream output
-        translation, streamer = generate_translation(input_text, model, tokenizer)
-
-        # Streamlit will automatically display the streamed output via the TextStreamer
-        # Collect the full response for metrics and storage
-        full_response = translation
-
-        # Update the placeholder with the final response
-        response_container.markdown(full_response)
+        # Stream translation tokens
+        for token in custom_streaming_generate(
+            input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict
+        ):
+            if token:  # Only append non-empty tokens
+                full_response += token
+                response_container.markdown(full_response)
 
        # Calculate performance metrics
        end_time = time.time()
        input_tokens = len(tokenizer(input_text)["input_ids"])
        output_tokens = len(tokenizer(full_response)["input_ids"])
-        speed = output_tokens / (end_time - start_time)
+        speed = output_tokens / (end_time - start_time) if (end_time - start_time) > 0 else 0
 
        # Calculate costs (hypothetical pricing model)
        input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
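
For context, below is a minimal standalone sketch of the greedy decoding loop this commit introduces, collected into an ordinary function so it can be exercised outside Streamlit. It is an illustration, not part of the commit: it assumes the objects returned by load_model_and_resources and the subsequent_mask helper from the repo's utils/ch09util.py, and the name greedy_translate is hypothetical. Note the .item() call, which converts the predicted index tensor to a plain int before the fr_idx_dict lookup.

import torch
from utils.ch09util import subsequent_mask  # repo helper, assumed available

def greedy_translate(text, model, tokenizer, en_word_dict,
                     fr_word_dict, fr_idx_dict, max_len=100, device="cpu"):
    # Hypothetical helper mirroring custom_streaming_generate from the diff.
    UNK = 1  # index the pickled dictionaries reserve for unknown tokens
    tokens = ["BOS"] + tokenizer.tokenize(text) + ["EOS"]
    src = torch.tensor([[en_word_dict.get(t, UNK) for t in tokens]],
                       dtype=torch.long, device=device)
    src_mask = (src != 0).unsqueeze(-2)  # mask out padding positions
    pieces = []
    with torch.no_grad():
        memory = model.model.encode(src, src_mask)  # encode the source once
        ys = torch.full((1, 1), fr_word_dict["BOS"],
                        dtype=torch.long, device=device)
        for _ in range(max_len):
            tgt_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = model.model.decode(memory, src_mask, ys, tgt_mask)
            prob = model.model.generator(out[:, -1])  # scores over the French vocab
            next_word = prob.argmax(dim=1).item()     # plain int, usable as a dict key
            sym = fr_idx_dict.get(next_word, "UNK")
            if sym == "EOS":
                break
            pieces.append(sym.replace("</w>", " "))   # undo the BPE end-of-word marker
            ys = torch.cat([ys, torch.full((1, 1), next_word,
                                           dtype=torch.long, device=device)], dim=1)
    return "".join(pieces).strip()

Unlike the committed generator, this sketch wraps decoding in torch.no_grad(); model.eval() alone switches layer behavior such as dropout but does not disable gradient tracking.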