Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import pickle
|
|
6 |
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
|
7 |
from huggingface_hub import login, hf_hub_download
|
8 |
import time
|
9 |
-
from ch09util import subsequent_mask, create_model # Import from ch09util.py in the repo
|
10 |
|
11 |
# Device setup
|
12 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -18,8 +18,8 @@ st.set_page_config(
|
|
18 |
layout="centered"
|
19 |
)
|
20 |
|
21 |
-
# Model name
|
22 |
-
MODEL_NAME = "amiguel/
|
23 |
|
24 |
# Title with rocket emojis
|
25 |
st.title("🚀 English to French Translator 🚀")
|
@@ -72,7 +72,7 @@ def load_model_and_resources(hf_token):
|
|
72 |
|
73 |
login(token=hf_token)
|
74 |
|
75 |
-
# Load tokenizer
|
76 |
tokenizer = AutoTokenizer.from_pretrained(
|
77 |
MODEL_NAME,
|
78 |
token=hf_token
|
@@ -108,7 +108,7 @@ def load_model_and_resources(hf_token):
|
|
108 |
def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
|
109 |
return self.model(src, tgt, src_mask, tgt_mask)
|
110 |
|
111 |
-
# Load config with validation
|
112 |
config_dict = TransformerConfig.from_pretrained(MODEL_NAME, token=hf_token).to_dict()
|
113 |
if "src_vocab_size" not in config_dict or "tgt_vocab_size" not in config_dict:
|
114 |
st.warning(
|
@@ -135,7 +135,7 @@ def load_model_and_resources(hf_token):
|
|
135 |
|
136 |
model.eval()
|
137 |
|
138 |
-
# Load dictionaries from
|
139 |
dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=hf_token)
|
140 |
with open(dict_path, "rb") as fb:
|
141 |
en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
|
|
|
6 |
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
|
7 |
from huggingface_hub import login, hf_hub_download
|
8 |
import time
|
9 |
+
from ch09util import subsequent_mask, create_model # Import from ch09util.py in the Space repo
|
10 |
|
11 |
# Device setup
|
12 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
18 |
layout="centered"
|
19 |
)
|
20 |
|
21 |
+
# Model repository name (corrected to the actual model repo)
|
22 |
+
MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"
|
23 |
|
24 |
# Title with rocket emojis
|
25 |
st.title("🚀 English to French Translator 🚀")
|
|
|
72 |
|
73 |
login(token=hf_token)
|
74 |
|
75 |
+
# Load tokenizer from the model repo
|
76 |
tokenizer = AutoTokenizer.from_pretrained(
|
77 |
MODEL_NAME,
|
78 |
token=hf_token
|
|
|
108 |
def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
|
109 |
return self.model(src, tgt, src_mask, tgt_mask)
|
110 |
|
111 |
+
# Load config with validation from the model repo
|
112 |
config_dict = TransformerConfig.from_pretrained(MODEL_NAME, token=hf_token).to_dict()
|
113 |
if "src_vocab_size" not in config_dict or "tgt_vocab_size" not in config_dict:
|
114 |
st.warning(
|
|
|
135 |
|
136 |
model.eval()
|
137 |
|
138 |
+
# Load dictionaries from the model repo
|
139 |
dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=hf_token)
|
140 |
with open(dict_path, "rb") as fb:
|
141 |
en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
|