Mhassanen commited on
Commit
2fb5dff
·
verified ·
1 Parent(s): 59bda42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -9
app.py CHANGED
@@ -4,7 +4,6 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import fitz
5
  import os
6
 
7
- # Load the model and tokenizer
8
  model = AutoModelForSequenceClassification.from_pretrained("REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier")
9
  tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
10
 
@@ -18,7 +17,6 @@ def extract_text_from_pdf(file_path):
18
 
19
  def predict_class(text):
20
  try:
21
- # Truncate text to maximum length of 4096 tokens
22
  max_length = 4096
23
  truncated_text = text[:max_length]
24
 
@@ -32,11 +30,9 @@ def predict_class(text):
32
  st.error(f"Error during prediction: {e}")
33
  return None
34
 
35
- # Create a directory to store uploaded files
36
  uploaded_files_dir = "uploaded_files"
37
  os.makedirs(uploaded_files_dir, exist_ok=True)
38
 
39
- # Define colors for different classes
40
  class_colors = {
41
  0: "#1f77b4", # Level 1
42
  1: "#ff7f0e", # Level 2
@@ -44,7 +40,6 @@ class_colors = {
44
  3: "#d62728" # Level 4
45
  }
46
 
47
- # Define information for each level
48
  class_info = {
49
  0: "Highly cited",
50
  1: "Average citations",
@@ -57,12 +52,10 @@ st.title("Paper Citation Classifier")
57
  option = st.radio("Select input type:", ("Text", "PDF"))
58
 
59
  if option == "Text":
60
- # Input text boxes for abstract, full text, and affiliations
61
  abstract_input = st.text_area("Enter Abstract:")
62
  full_text_input = st.text_area("Enter Full Text:")
63
  affiliations_input = st.text_area("Enter Affiliations:")
64
 
65
- # Concatenate inputs with [SEP]
66
  combined_text = f"{abstract_input} [SEP] {full_text_input} [SEP] {affiliations_input}"
67
 
68
  if st.button("Predict"):
@@ -96,12 +89,110 @@ elif option == "PDF":
96
  st.text("Extracted Text:")
97
  st.text(file_text)
98
 
99
- # Provide an option to predict from PDF text
100
  if st.button("Predict from PDF Text"):
101
  with st.spinner("Predicting..."):
102
  predicted_class = predict_class(file_text)
103
  if predicted_class is not None:
104
- class_labels = ["Level 1", "Level 2", "Level 3", "Level 4"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  st.text("Predicted Class:")
106
  for i, label in enumerate(class_labels):
107
  if i == predicted_class:
 
4
  import fitz
5
  import os
6
 
 
7
  model = AutoModelForSequenceClassification.from_pretrained("REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier")
8
  tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
9
 
 
17
 
18
  def predict_class(text):
19
  try:
 
20
  max_length = 4096
21
  truncated_text = text[:max_length]
22
 
 
30
  st.error(f"Error during prediction: {e}")
31
  return None
32
 
 
33
  uploaded_files_dir = "uploaded_files"
34
  os.makedirs(uploaded_files_dir, exist_ok=True)
35
 
 
36
  class_colors = {
37
  0: "#1f77b4", # Level 1
38
  1: "#ff7f0e", # Level 2
 
40
  3: "#d62728" # Level 4
41
  }
42
 
 
43
  class_info = {
44
  0: "Highly cited",
45
  1: "Average citations",
 
52
  option = st.radio("Select input type:", ("Text", "PDF"))
53
 
54
  if option == "Text":
 
55
  abstract_input = st.text_area("Enter Abstract:")
56
  full_text_input = st.text_area("Enter Full Text:")
57
  affiliations_input = st.text_area("Enter Affiliations:")
58
 
 
59
  combined_text = f"{abstract_input} [SEP] {full_text_input} [SEP] {affiliations_input}"
60
 
61
  if st.button("Predict"):
 
89
  st.text("Extracted Text:")
90
  st.text(file_text)
91
 
 
92
  if st.button("Predict from PDF Text"):
93
  with st.spinner("Predicting..."):
94
  predicted_class = predict_class(file_text)
95
  if predicted_class is not None:
96
+ import streamlit as st
97
+ import torch
98
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
99
+ import fitz
100
+ import os
101
+
102
+
103
+ model = AutoModelForSequenceClassification.from_pretrained("REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier")
104
+ tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
105
+
106
+ def extract_text_from_pdf(file_path):
107
+ text = ''
108
+ with fitz.open(file_path) as pdf_document:
109
+ for page_number in range(pdf_document.page_count):
110
+ page = pdf_document.load_page(page_number)
111
+ text += page.get_text()
112
+ return text
113
+
114
+ def predict_class(text):
115
+ try:
116
+ max_length = 4096
117
+ truncated_text = text[:max_length]
118
+
119
+ inputs = tokenizer(truncated_text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
120
+ with torch.no_grad():
121
+ outputs = model(**inputs)
122
+ logits = outputs.logits
123
+ predicted_class = torch.argmax(logits, dim=1).item()
124
+ return predicted_class
125
+ except Exception as e:
126
+ st.error(f"Error during prediction: {e}")
127
+ return None
128
+
129
+ uploaded_files_dir = "uploaded_files"
130
+ os.makedirs(uploaded_files_dir, exist_ok=True)
131
+
132
+ class_colors = {
133
+ 0: "#1f77b4", # Level 1
134
+ 1: "#ff7f0e", # Level 2
135
+ 2: "#2ca02c", # Level 3
136
+ 3: "#d62728" # Level 4
137
+ }
138
+
139
+ st.title("Paper Citation Classifier")
140
+
141
+ option = st.radio("Select input type:", ("Text", "PDF"))
142
+
143
+ if option == "Text":
144
+ abstract_input = st.text_area("Enter Abstract:")
145
+ full_text_input = st.text_area("Enter Full Text:")
146
+ affiliations_input = st.text_area("Enter Affiliations:")
147
+
148
+ categories = st.multiselect("Select categories:", ["Category 1", "Category 2", "Category 3", "Category 4"])
149
+
150
+ combined_text = f"{abstract_input} [SEP] {full_text_input} [SEP] {affiliations_input} [SEP] {' [SEP] '.join(categories)}"
151
+
152
+ if st.button("Predict"):
153
+ with st.spinner("Predicting..."):
154
+ predicted_class = predict_class(combined_text)
155
+ if predicted_class is not None:
156
+ class_labels = ["Level 1", "Level 2", "Level 3", "Level 4"]
157
+ st.text("Predicted Class:")
158
+ for i, label in enumerate(class_labels):
159
+ if i == predicted_class:
160
+ st.markdown(
161
+ f'<div style="background-color: {class_colors[predicted_class]}; padding: 10px; border-radius: 5px; color: white; font-weight: bold;">{label}</div>',
162
+ unsafe_allow_html=True
163
+ )
164
+ else:
165
+ st.text(label)
166
+
167
+ elif option == "PDF":
168
+ uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
169
+
170
+ if uploaded_file is not None:
171
+ with st.spinner("Processing PDF..."):
172
+ file_path = os.path.join(uploaded_files_dir, uploaded_file.name)
173
+ with open(file_path, "wb") as f:
174
+ f.write(uploaded_file.getbuffer())
175
+ st.success("File uploaded successfully.")
176
+ st.text(f"File Path: {file_path}")
177
+
178
+ file_text = extract_text_from_pdf(file_path)
179
+ st.text("Extracted Text:")
180
+ st.text(file_text)
181
+
182
+ if st.button("Predict from PDF Text"):
183
+ with st.spinner("Predicting..."):
184
+ predicted_class = predict_class(file_text)
185
+ if predicted_class is not None:
186
+ class_labels = ["Level 1 (Highly Cited Paper)", "Level 2 (Average Cited Paper)", "Level 3 (More Cited Paper)", "Level 4 (Low Cited Paper)"]
187
+ st.text("Predicted Class:")
188
+ for i, label in enumerate(class_labels):
189
+ if i == predicted_class:
190
+ st.markdown(
191
+ f'<div style="background-color: {class_colors[predicted_class]}; padding: 10px; border-radius: 5px; color: white; font-weight: bold;">{label}</div>',
192
+ unsafe_allow_html=True
193
+ )
194
+ else:
195
+ st.text(label)
196
  st.text("Predicted Class:")
197
  for i, label in enumerate(class_labels):
198
  if i == predicted_class: