nguyennghia0902 commited on
Commit
e8b04e7
·
verified ·
1 Parent(s): 89802d6

Create streamlit_app.py / pages/Information_Retrieval.py

Browse files
streamlit_app.py / pages/Information_Retrieval.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path
2
+ import streamlit as st
3
+
4
+ import nltk, subprocess, sys
5
+
6
def install(package):
    """Install *package* with pip into the environment of the running interpreter."""
    # sys.executable guarantees pip targets this exact interpreter/venv.
    cmd = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(cmd)
8
+
9
+ install("pyvi")
10
+
11
# Resource files shipped alongside this script; resolved relative to the
# script's own directory so the app works regardless of the CWD.
stwfilename = "Vstopword_new.txt"  # stopword list (Vietnamese, per usage below)
punfilename = "punctuation.txt"    # punctuation tokens to strip
STW_PATH = path.join(path.dirname(__file__), stwfilename)
PUNCT_PATH = path.join(path.dirname(__file__), punfilename)
15
+
16
+
17
+ from pyvi import ViTokenizer
18
@st.cache_resource
def open2list_vn(path):
    """Read a text resource file and return its lines as a list of strings.

    Cached by Streamlit so each file is read only once per server process.

    Parameters
    ----------
    path : str
        Filesystem path to the resource file; may be falsy.

    Returns
    -------
    list[str]
        The file's lines. Returns an empty list for a falsy *path* (the
        original fell through and returned None, which crashed downstream
        membership tests).
    """
    if not path:
        return []
    # Explicit utf-8: the resource files contain Vietnamese text.
    with open(path, encoding="utf-8") as f:
        return f.read().splitlines()
24
def pre_progress(input):
    """Normalize text: Vietnamese word-segmentation, lowercasing, then removal
    of punctuation tokens and stopwords.

    Parameters
    ----------
    input : str
        Raw text to normalize. (Parameter name kept for interface
        compatibility even though it shadows the builtin.)

    Returns
    -------
    str
        Space-joined surviving tokens.
    """
    # Sets give O(1) membership tests; the original probed lists per token
    # (quadratic overall). `or []` also guards a None from open2list_vn.
    stw = set(open2list_vn(STW_PATH) or [])
    punctuations = set(open2list_vn(PUNCT_PATH) or [])
    # ViTokenizer segments Vietnamese multi-word units before lowercasing.
    text = ViTokenizer.tokenize(input).lower()
    raw = nltk.wordpunct_tokenize(text)
    # Single pass: dropping punctuation then stopwords is order-independent.
    kept = [tok for tok in raw if tok not in punctuations and tok not in stw]
    return " ".join(kept)
39
+
40
+
41
+ # from tensorflow import keras
42
+ import tensorflow as tf
43
+ from transformers import ElectraTokenizer, TFElectraForSequenceClassification
44
+
45
# Tokenizer comes from the base ELECTRA checkpoint; classifier weights come
# from the fine-tuned hub repo below.
MODEL_NAME = "google/electra-small-discriminator"
MODEL_PATH = 'nguyennghia0902/textming_proj01_electra'

tokenizer = ElectraTokenizer.from_pretrained(MODEL_NAME)

# Binary label mapping used by predict()/main(): 1 → "TRUE", 0 → "FALSE".
# NOTE(review): presumably TRUE means "the passage answers the question" —
# inferred from the UI in main(); confirm against the training setup.
id2label = {0: "FALSE", 1: "TRUE"}
label2id = {"FALSE": 0, "TRUE": 1}
loaded_model = TFElectraForSequenceClassification.from_pretrained(MODEL_PATH, id2label=id2label, label2id=label2id)
53
+
54
def predict(question, text):
    """Run the classifier on a (question, text) pair.

    Returns the string label from the model config: "TRUE" or "FALSE".
    """
    combined = pre_progress(question + ' ' + text)

    # Tokenize to TF tensors, take the argmax over the two class logits.
    encoded = tokenizer(combined, truncation=True, padding=True, return_tensors='tf')
    scores = loaded_model(**encoded).logits
    best_class = int(tf.math.argmax(scores, axis=-1)[0])

    return loaded_model.config.id2label[best_class]
62
+
63
+
64
def main():
    """Streamlit page: decide whether a passage answers a question (TRUE/FALSE)."""
    st.set_page_config(page_title="Information Retrieval", page_icon="📝")

    # giving a title to our page
    st.title("Information Retrieval")
    text = st.text_area(
        "Please enter a text:",
        placeholder="Enter your text here",
        height=200,
    )
    question = st.text_area(
        "Please enter a question:",
        placeholder="Enter your question here",
        height=200,
    )

    # Create a prediction button
    if st.button("Predict"):
        # Guard clauses: refuse to predict on blank inputs. (The original
        # reused one `stripped` variable for both unrelated checks.)
        if not text.strip():
            st.error("Please enter some text.")
            return
        if not question.strip():
            st.error("Please enter a question.")
            return
        # Join wrapped lines with a space — replacing "\n" with "" glued the
        # last word of one line to the first word of the next.
        text = text.replace("\n", " ")
        prediction = predict(question, text)
        if prediction == "TRUE":
            st.success("TRUE 😄")
        else:
            st.warning("FALSE 😟")
98
+
99
# Script entry point: render the page only when executed directly
# (Streamlit runs the file as a script).
if __name__ == "__main__":
    main()