Spaces:
Runtime error
Runtime error
Create streamlit_app.py / pages/Information_Retrieval.py
Browse files
streamlit_app.py / pages/Information_Retrieval.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import path
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
import nltk, subprocess, sys
|
5 |
+
|
6 |
+
def install(package):
    """Install *package* into the current interpreter's environment via pip.

    NOTE(review): runtime pip installs are a workaround for a missing
    requirements entry; prefer declaring the dependency in requirements.txt.
    """
    command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(command)
|
8 |
+
|
9 |
+
# Ensure the Vietnamese word-segmentation library is available. Only shell
# out to pip when the import actually fails, instead of re-running the pip
# install on every Streamlit rerun as the original did.
try:
    import pyvi  # noqa: F401
except ImportError:
    install("pyvi")

# Resource files shipped alongside this script: Vietnamese stopwords and the
# punctuation tokens stripped during preprocessing.
stwfilename = "Vstopword_new.txt"
punfilename = "punctuation.txt"
STW_PATH = path.join(path.dirname(__file__), stwfilename)
PUNCT_PATH = path.join(path.dirname(__file__), punfilename)
|
15 |
+
|
16 |
+
|
17 |
+
from pyvi import ViTokenizer
|
18 |
+
@st.cache_resource
def open2list_vn(path):
    """Read a text file and return its lines as a list of strings.

    Returns an empty list when *path* is falsy — the original left ``line``
    unassigned in that case and raised UnboundLocalError at the ``return``.
    Cached by Streamlit so each resource file is read only once per session.
    """
    lines = []
    if path:
        # NOTE(review): assumes the stopword/punctuation files are UTF-8 —
        # confirm; the original relied on the platform default encoding.
        with open(path, encoding="utf-8") as f:
            lines = f.read().splitlines()  # splitlines() already yields a list
    return lines
|
24 |
+
def pre_progress(input):
    """Normalize Vietnamese text for the classifier.

    Word-segments the text with pyvi, lowercases it, then drops punctuation
    tokens and stopwords. Returns the surviving tokens joined by spaces.

    NOTE(review): the parameter shadows the ``input`` builtin; kept as-is to
    preserve the public signature for any keyword callers.
    """
    # Sets give O(1) membership tests; the original scanned both lists once
    # per token, which is quadratic on long inputs.
    stw = set(open2list_vn(STW_PATH))
    punctuations = set(open2list_vn(PUNCT_PATH))
    text = ViTokenizer.tokenize(input).lower()
    tokens = [t for t in nltk.wordpunct_tokenize(text) if t not in punctuations]
    all_tokens = [t for t in tokens if t not in stw]
    return " ".join(all_tokens)
|
39 |
+
|
40 |
+
|
41 |
+
# from tensorflow import keras
|
42 |
+
import tensorflow as tf
|
43 |
+
from transformers import ElectraTokenizer, TFElectraForSequenceClassification
|
44 |
+
|
45 |
+
# Tokenizer vocabulary comes from the base ELECTRA checkpoint; the fine-tuned
# classification weights come from the project's Hub repository.
MODEL_NAME = "google/electra-small-discriminator"
MODEL_PATH = 'nguyennghia0902/textming_proj01_electra'

# Loaded at module import time, so the first page load pays the download cost.
tokenizer = ElectraTokenizer.from_pretrained(MODEL_NAME)

# Binary label maps for the sequence classifier; predict() returns one of
# these label strings ("TRUE" / "FALSE").
id2label = {0: "FALSE", 1: "TRUE"}
label2id = {"FALSE": 0, "TRUE": 1}
loaded_model = TFElectraForSequenceClassification.from_pretrained(MODEL_PATH, id2label=id2label, label2id=label2id)
|
53 |
+
|
54 |
+
def predict(question, text):
    """Classify whether *text* is relevant to *question*.

    Concatenates question and text, preprocesses the pair, and returns the
    label string chosen by the fine-tuned ELECTRA sequence classifier.
    """
    preprocessed = pre_progress(question + ' ' + text)

    encoded = tokenizer(preprocessed, truncation=True, padding=True, return_tensors='tf')
    outputs = loaded_model(**encoded)
    best_class = int(tf.math.argmax(outputs.logits, axis=-1)[0])

    return loaded_model.config.id2label[best_class]
|
62 |
+
|
63 |
+
|
64 |
+
def main():
    """Render the Information Retrieval page and run a prediction on demand."""
    st.set_page_config(page_title="Information Retrieval", page_icon="📝")

    # giving a title to our page
    st.title("Information Retrieval")
    text = st.text_area(
        "Please enter a text:",
        placeholder="Enter your text here",
        height=200,
    )
    question = st.text_area(
        "Please enter a question:",
        placeholder="Enter your question here",
        height=200,
    )

    # Create a prediction button
    if st.button("Predict"):
        # Guard clauses: both inputs must be non-blank before predicting.
        if not text.strip():
            st.error("Please enter some text.")
            return
        if not question.strip():
            st.error("Please enter a question.")
            return
        # Join wrapped lines with a space — the original replaced "\n" with
        # "", which glued the last word of one line to the first of the next.
        text = text.replace("\n", " ")
        prediction = predict(question, text)
        if prediction == "TRUE":
            st.success("TRUE 😄")
        else:
            st.warning("FALSE 😟")
|