Camille
commited on
Commit
·
8384a73
1
Parent(s):
8d252dd
fix: laod model
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ from typing import List
|
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
from transformers import BertTokenizer, TFAutoModelForMaskedLM
|
|
|
|
| 7 |
|
| 8 |
from rhyme_with_ai.utils import color_new_words, sanitize
|
| 9 |
from rhyme_with_ai.rhyme import query_rhyme_words
|
|
@@ -70,7 +71,7 @@ def start_rhyming(query, rhyme_words_options):
|
|
| 70 |
|
| 71 |
rhyme_words = rhyme_words_options[:N_RHYMES]
|
| 72 |
|
| 73 |
-
model, tokenizer = load_model(MODEL_PATH)
|
| 74 |
sentence_generator = RhymeGenerator(model, tokenizer)
|
| 75 |
sentence_generator.start(query, rhyme_words)
|
| 76 |
|
|
@@ -84,13 +85,18 @@ def start_rhyming(query, rhyme_words_options):
|
|
| 84 |
|
| 85 |
|
| 86 |
@st.cache(allow_output_mutation=True)
|
| 87 |
-
def load_model(model_path):
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
-
|
| 94 |
def display_output(status_text, query, current_sentences, previous_sentences):
|
| 95 |
print_sentences = []
|
| 96 |
for new, old in zip(current_sentences, previous_sentences):
|
|
|
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
from transformers import BertTokenizer, TFAutoModelForMaskedLM
|
| 7 |
+
from transformers import CamembertModel, CamembertTokenizer
|
| 8 |
|
| 9 |
from rhyme_with_ai.utils import color_new_words, sanitize
|
| 10 |
from rhyme_with_ai.rhyme import query_rhyme_words
|
|
|
|
| 71 |
|
| 72 |
rhyme_words = rhyme_words_options[:N_RHYMES]
|
| 73 |
|
| 74 |
+
model, tokenizer = load_model(MODEL_PATH, LANGUAGE)
|
| 75 |
sentence_generator = RhymeGenerator(model, tokenizer)
|
| 76 |
sentence_generator.start(query, rhyme_words)
|
| 77 |
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
@st.cache(allow_output_mutation=True)
|
| 88 |
+
def load_model(model_path, language):
|
| 89 |
+
if language != "french":
|
| 90 |
+
return (
|
| 91 |
+
TFAutoModelForMaskedLM.from_pretrained(model_path),
|
| 92 |
+
BertTokenizer.from_pretrained(model_path),
|
| 93 |
+
)
|
| 94 |
+
else :
|
| 95 |
+
return (
|
| 96 |
+
CamembertModel.from_pretrained(model_path),
|
| 97 |
+
CamembertTokenizer.from_pretrained(model_path),
|
| 98 |
)
|
| 99 |
|
|
|
|
| 100 |
def display_output(status_text, query, current_sentences, previous_sentences):
|
| 101 |
print_sentences = []
|
| 102 |
for new, old in zip(current_sentences, previous_sentences):
|