Update app.py
app.py CHANGED
@@ -1,10 +1,13 @@
 import streamlit as st
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextStreamer
-from huggingface_hub import login
-import PyPDF2
-import pandas as pd
 import torch
+import pandas as pd
+import PyPDF2
+import pickle
+import os
+from transformers import AutoTokenizer
+from huggingface_hub import login
 import time
+from utils.ch09util import subsequent_mask  # Ensure ch09util.py is available
 
 # Device setup
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
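Note: the new `from utils.ch09util import subsequent_mask` line assumes `utils/ch09util.py` ships with the Space; the helper itself is not shown in this diff. Based on how the decode loop below calls it (`subsequent_mask(ys.size(1)).type_as(src.data)`), a minimal sketch of the standard causal-mask helper it is expected to match:

import torch

def subsequent_mask(size):
    # Mask of shape (1, size, size): True where a target position may attend,
    # i.e. on or below the diagonal, so the decoder cannot peek at future tokens
    attn_shape = (1, size, size)
    upper = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return upper == 0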
@@ -17,7 +20,7 @@ st.set_page_config(
 )
 
 # Model name
-MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"
+MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"
 
 # Title with rocket emojis
 st.title("🚀 English to French Translator 🚀")
@@ -60,9 +63,9 @@ def process_file(uploaded_file):
         st.error(f"📄 Error processing file: {str(e)}")
         return ""
 
-#
+# Custom model loading function
 @st.cache_resource
-def load_model(hf_token):
+def load_model_and_resources(hf_token):
     try:
         if not hf_token:
             st.error("🔐 Authentication required! Please provide a Hugging Face token.")
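Note: the loader in the next hunk delegates to `create_model` from `utils.ch09util`, which this diff never defines. Judging only from how the wrapped model is used later (`model.model.encode(...)`, `model.model.decode(...)`, `model.model.generator(...)`), the returned object is assumed to expose an annotated-transformer-style interface along these lines (a sketch of the assumed contract, not the actual ch09util code):

import torch.nn as nn

class EncoderDecoder(nn.Module):
    # Assumed return type of create_model(src_vocab, tgt_vocab, N=..., d_model=..., d_ff=..., h=..., dropout=...)
    def encode(self, src, src_mask):
        # (batch, src_len) token ids -> encoder memory of shape (batch, src_len, d_model)
        ...

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # -> decoder hidden states of shape (batch, tgt_len, d_model)
        ...

    # plus a .generator head mapping d_model -> tgt_vocab_size log-probabilities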
@@ -76,49 +79,86 @@ def load_model(hf_token):
             token=hf_token
         )
 
-        # Load
-
-
+        # Load model
+        from transformers import PreTrainedModel, PretrainedConfig
+        class TransformerConfig(PretrainedConfig):
+            model_type = "custom_transformer"
+            def __init__(self, src_vocab_size, tgt_vocab_size, d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
+                super().__init__(**kwargs)
+                self.src_vocab_size = src_vocab_size
+                self.tgt_vocab_size = tgt_vocab_size
+                self.d_model = d_model
+                self.d_ff = d_ff
+                self.h = h
+                self.N = N
+                self.dropout = dropout
+
+        class CustomTransformer(PreTrainedModel):
+            config_class = TransformerConfig
+            def __init__(self, config):
+                super().__init__(config)
+                from utils.ch09util import create_model
+                self.model = create_model(
+                    config.src_vocab_size,
+                    config.tgt_vocab_size,
+                    N=config.N,
+                    d_model=config.d_model,
+                    d_ff=config.d_ff,
+                    h=config.h,
+                    dropout=config.dropout
+                )
+            def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
+                return self.model(src, tgt, src_mask, tgt_mask)
+
+        config = TransformerConfig.from_pretrained(MODEL_NAME, token=hf_token)
+        model = CustomTransformer.from_pretrained(
             MODEL_NAME,
-
-
-
-        )
+            config=config,
+            token=hf_token
+        ).to(DEVICE)
 
-
+        # Load dictionaries (assumes dict.p was uploaded to the model repo)
+        dict_path = "dict.p"
+        if not os.path.exists(dict_path):
+            st.error("Dictionary file (dict.p) not found. Please ensure it was uploaded to the model repository.")
+            return None
+        with open(dict_path, "rb") as fb:
+            en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
+
+        return model, tokenizer, en_word_dict, fr_word_dict, en_idx_dict, fr_idx_dict
 
     except Exception as e:
         st.error(f"🤖 Model loading failed: {str(e)}")
         return None
 
-#
-def
+# Custom streaming generation function
+def custom_streaming_generate(input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict):
     try:
-        # Tokenize the input (no prompt needed for seq2seq translation models)
-        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-        inputs = inputs.to(DEVICE)
-
-        # Set up the streamer for real-time output
-        streamer = TextStreamer(tokenizer, skip_special_tokens=True)
-
-        # Generate translation with streaming (disable beam search)
         model.eval()
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-
+        PAD, UNK = 0, 1
+        tokenized_en = ["BOS"] + tokenizer.tokenize(input_text) + ["EOS"]
+        enidx = [en_word_dict.get(i, UNK) for i in tokenized_en]
+        src = torch.tensor(enidx).long().to(DEVICE).unsqueeze(0)
+        src_mask = (src != 0).unsqueeze(-2)
+        memory = model.model.encode(src, src_mask)
+        start_symbol = fr_word_dict["BOS"]
+        ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
+        for _ in range(100):
+            out = model.model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
+            prob = model.model.generator(out[:, -1])
+            _, next_word = torch.max(prob, dim=1)
+            next_word = int(next_word.data[0])  # plain int, so the fr_idx_dict lookup below can match its integer keys
+            sym = fr_idx_dict.get(next_word, "UNK")
+            if sym != "EOS":
+                token = sym.replace("</w>", " ")
+                for x in '''?:;.,'("-!&)%''':
+                    token = token.replace(f" {x}", f"{x}")
+                yield token
+            else:
+                break
+            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
+        # Yield a final empty token to ensure completion
+        yield ""
 
     except Exception as e:
         raise Exception(f"Generation error: {str(e)}")
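Note: the dictionary check above looks for dict.p in the Space's working directory, while its own comment says the file was uploaded to the model repository. If dict.p really lives in the model repo, fetching it explicitly would be more robust; a sketch using the hub client, with the same repo and filename the code already assumes:

from huggingface_hub import hf_hub_download

# Fetch dict.p from the model repo instead of relying on a local copy
dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=hf_token)
with open(dict_path, "rb") as fb:
    en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)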
@@ -139,17 +179,22 @@ if prompt := st.chat_input("Enter text to translate into French..."):
        st.error("🔑 Authentication required!")
        st.stop()
 
-    # Load model if not already loaded
+    # Load model and resources if not already loaded
    if "model" not in st.session_state:
-        model_data = load_model(hf_token)
+        model_data = load_model_and_resources(hf_token)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()
 
-        st.session_state.model, st.session_state.tokenizer = model_data
+        st.session_state.model, st.session_state.tokenizer, \
+        st.session_state.en_word_dict, st.session_state.fr_word_dict, \
+        st.session_state.en_idx_dict, st.session_state.fr_idx_dict = model_data
 
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
+    en_word_dict = st.session_state.en_word_dict
+    fr_word_dict = st.session_state.fr_word_dict
+    fr_idx_dict = st.session_state.fr_idx_dict
 
    # Add user message
    with st.chat_message("user", avatar=USER_AVATAR):
@@ -170,21 +215,19 @@ if prompt := st.chat_input("Enter text to translate into French..."):
            response_container = st.empty()
            full_response = ""
 
-            #
-
-
-
-
-
-
-            # Update the placeholder with the final response
-            response_container.markdown(full_response)
+            # Stream translation tokens
+            for token in custom_streaming_generate(
+                input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict
+            ):
+                if token:  # Only append non-empty tokens
+                    full_response += token
+                    response_container.markdown(full_response)
 
            # Calculate performance metrics
            end_time = time.time()
            input_tokens = len(tokenizer(input_text)["input_ids"])
            output_tokens = len(tokenizer(full_response)["input_ids"])
-            speed = output_tokens / (end_time - start_time)
+            speed = output_tokens / (end_time - start_time) if (end_time - start_time) > 0 else 0
 
            # Calculate costs (hypothetical pricing model)
            input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
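Note: for a sense of scale under the hypothetical pricing above, with illustrative numbers rather than measurements:

# 50 input tokens at $5 per million tokens -> $0.00025
input_cost = (50 / 1_000_000) * 5
# 60 output tokens generated in 2.0 s -> 30.0 tokens/s
speed = 60 / 2.0 if 2.0 > 0 else 0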