Spaces:
Paused
Paused
DHRUV SHEKHAWAT
commited on
Commit
·
896bfe3
1
Parent(s):
6707fb0
Update app.py
Browse files
app.py
CHANGED
|
@@ -40,7 +40,7 @@ class TransformerChatbot(Model):
|
|
| 40 |
def create_padding_mask(self, seq):
|
| 41 |
mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
|
| 42 |
return mask[:, tf.newaxis, tf.newaxis, :]
|
| 43 |
-
def completion_model(vocab_size, max_len, d_model, n_head, ff_dim, dropout_rate,weights,datafile,dict,
|
| 44 |
|
| 45 |
with open(datafile,"r") as f:
|
| 46 |
text = f.read()
|
|
@@ -77,7 +77,7 @@ def completion_model(vocab_size, max_len, d_model, n_head, ff_dim, dropout_rate,
|
|
| 77 |
given_X1 = other_num1
|
| 78 |
input_sequence1 = pad_sequences([given_X1], maxlen=max_len, padding='post')
|
| 79 |
output_sentence = ""
|
| 80 |
-
for _ in range(
|
| 81 |
predicted_token = np.argmax(chatbot.predict(input_sequence1), axis=-1)
|
| 82 |
predicted_token = predicted_token.item()
|
| 83 |
out = num_to_word[predicted_token]
|
|
@@ -94,14 +94,14 @@ def completion_model(vocab_size, max_len, d_model, n_head, ff_dim, dropout_rate,
|
|
| 94 |
st.title("UniGLM TEXT completion Model")
|
| 95 |
st.subheader("Next Word Prediction AI Model by Webraft-AI")
|
| 96 |
#Picking what NLP task you want to do
|
| 97 |
-
option = st.selectbox('Model',('
|
| 98 |
#Textbox for text user is entering
|
| 99 |
st.subheader("Enter a word from which a sentence / word would be predicted")
|
| 100 |
|
| 101 |
text2 = st.text_input('Enter word: ') #text is stored in this variable
|
| 102 |
|
| 103 |
|
| 104 |
-
if option == '
|
| 105 |
option2 = st.selectbox('Type',('word','sentence'))
|
| 106 |
if option2 == 'word':
|
| 107 |
len = 1
|
|
@@ -121,13 +121,24 @@ if option == '13M':
|
|
| 121 |
st.write(out2)
|
| 122 |
|
| 123 |
|
| 124 |
-
elif option=="
|
| 125 |
option2 = st.selectbox('Type',('word','sentence'))
|
| 126 |
if option2 == 'word':
|
| 127 |
len = 1
|
| 128 |
else:
|
| 129 |
len = 13
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
else:
|
| 132 |
out2 = "Error: Wrong Model Selected"
|
| 133 |
|
|
|
|
| 40 |
def create_padding_mask(self, seq):
|
| 41 |
mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
|
| 42 |
return mask[:, tf.newaxis, tf.newaxis, :]
|
| 43 |
+
def completion_model(vocab_size, max_len, d_model, n_head, ff_dim, dropout_rate,weights,datafile,dict,len2,text2):
|
| 44 |
|
| 45 |
with open(datafile,"r") as f:
|
| 46 |
text = f.read()
|
|
|
|
| 77 |
given_X1 = other_num1
|
| 78 |
input_sequence1 = pad_sequences([given_X1], maxlen=max_len, padding='post')
|
| 79 |
output_sentence = ""
|
| 80 |
+
for _ in range(len2):
|
| 81 |
predicted_token = np.argmax(chatbot.predict(input_sequence1), axis=-1)
|
| 82 |
predicted_token = predicted_token.item()
|
| 83 |
out = num_to_word[predicted_token]
|
|
|
|
| 94 |
st.title("UniGLM TEXT completion Model")
|
| 95 |
st.subheader("Next Word Prediction AI Model by Webraft-AI")
|
| 96 |
#Picking what NLP task you want to do
|
| 97 |
+
option = st.selectbox('Model',('13M_OLD','26M_OLD')) #option is stored in this variable
|
| 98 |
#Textbox for text user is entering
|
| 99 |
st.subheader("Enter a word from which a sentence / word would be predicted")
|
| 100 |
|
| 101 |
text2 = st.text_input('Enter word: ') #text is stored in this variable
|
| 102 |
|
| 103 |
|
| 104 |
+
if option == '13M_OLD':
|
| 105 |
option2 = st.selectbox('Type',('word','sentence'))
|
| 106 |
if option2 == 'word':
|
| 107 |
len = 1
|
|
|
|
| 121 |
st.write(out2)
|
| 122 |
|
| 123 |
|
| 124 |
+
elif option=="26M_OLD":
|
| 125 |
option2 = st.selectbox('Type',('word','sentence'))
|
| 126 |
if option2 == 'word':
|
| 127 |
len = 1
|
| 128 |
else:
|
| 129 |
len = 13
|
| 130 |
+
vocab_size = 100000
|
| 131 |
+
max_len = 1
|
| 132 |
+
d_model = 128 # 64 , 1024
|
| 133 |
+
n_head = 4 # 8 , 16
|
| 134 |
+
ff_dim = 256 # 256 , 2048
|
| 135 |
+
dropout_rate = 0.1 # 0.5 , 0.2
|
| 136 |
+
weights = "predict1"
|
| 137 |
+
datafile = "data2.txt"
|
| 138 |
+
dict = "dict_predict1.bin.npz"
|
| 139 |
+
out2 = completion_model(vocab_size, max_len, d_model, n_head, ff_dim, dropout_rate,weights,datafile,dict,len,text2)
|
| 140 |
+
st.write("Predicted Text: ")
|
| 141 |
+
st.write(out2)
|
| 142 |
else:
|
| 143 |
out2 = "Error: Wrong Model Selected"
|
| 144 |
|