Mhassanen commited on
Commit
7853cf0
·
verified ·
1 Parent(s): 2fb5dff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -106
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import fitz
5
  import os
6
 
 
7
  model = AutoModelForSequenceClassification.from_pretrained("REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier")
8
  tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
9
 
@@ -17,6 +18,7 @@ def extract_text_from_pdf(file_path):
17
 
18
  def predict_class(text):
19
  try:
 
20
  max_length = 4096
21
  truncated_text = text[:max_length]
22
 
@@ -30,105 +32,11 @@ def predict_class(text):
30
  st.error(f"Error during prediction: {e}")
31
  return None
32
 
 
33
  uploaded_files_dir = "uploaded_files"
34
  os.makedirs(uploaded_files_dir, exist_ok=True)
35
 
36
- class_colors = {
37
- 0: "#1f77b4", # Level 1
38
- 1: "#ff7f0e", # Level 2
39
- 2: "#2ca02c", # Level 3
40
- 3: "#d62728" # Level 4
41
- }
42
-
43
- class_info = {
44
- 0: "Highly cited",
45
- 1: "Average citations",
46
- 2: "More citations",
47
- 3: "Low citations"
48
- }
49
-
50
- st.title("Paper Citation Classifier")
51
-
52
- option = st.radio("Select input type:", ("Text", "PDF"))
53
-
54
- if option == "Text":
55
- abstract_input = st.text_area("Enter Abstract:")
56
- full_text_input = st.text_area("Enter Full Text:")
57
- affiliations_input = st.text_area("Enter Affiliations:")
58
-
59
- combined_text = f"{abstract_input} [SEP] {full_text_input} [SEP] {affiliations_input}"
60
-
61
- if st.button("Predict"):
62
- with st.spinner("Predicting..."):
63
- predicted_class = predict_class(combined_text)
64
- if predicted_class is not None:
65
- class_labels = ["Level 1", "Level 2", "Level 3", "Level 4"]
66
- st.text("Predicted Class:")
67
- for i, label in enumerate(class_labels):
68
- if i == predicted_class:
69
- st.markdown(
70
- f'<div style="background-color: {class_colors[predicted_class]}; padding: 10px; border-radius: 5px; color: white; font-weight: bold;">{label}</div>',
71
- unsafe_allow_html=True
72
- )
73
- st.text(class_info[predicted_class])
74
- else:
75
- st.text(label)
76
-
77
- elif option == "PDF":
78
- uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
79
-
80
- if uploaded_file is not None:
81
- with st.spinner("Processing PDF..."):
82
- file_path = os.path.join(uploaded_files_dir, uploaded_file.name)
83
- with open(file_path, "wb") as f:
84
- f.write(uploaded_file.getbuffer())
85
- st.success("File uploaded successfully.")
86
- st.text(f"File Path: {file_path}")
87
-
88
- file_text = extract_text_from_pdf(file_path)
89
- st.text("Extracted Text:")
90
- st.text(file_text)
91
-
92
- if st.button("Predict from PDF Text"):
93
- with st.spinner("Predicting..."):
94
- predicted_class = predict_class(file_text)
95
- if predicted_class is not None:
96
- import streamlit as st
97
- import torch
98
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
99
- import fitz
100
- import os
101
-
102
-
103
- model = AutoModelForSequenceClassification.from_pretrained("REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier")
104
- tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
105
-
106
- def extract_text_from_pdf(file_path):
107
- text = ''
108
- with fitz.open(file_path) as pdf_document:
109
- for page_number in range(pdf_document.page_count):
110
- page = pdf_document.load_page(page_number)
111
- text += page.get_text()
112
- return text
113
-
114
- def predict_class(text):
115
- try:
116
- max_length = 4096
117
- truncated_text = text[:max_length]
118
-
119
- inputs = tokenizer(truncated_text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
120
- with torch.no_grad():
121
- outputs = model(**inputs)
122
- logits = outputs.logits
123
- predicted_class = torch.argmax(logits, dim=1).item()
124
- return predicted_class
125
- except Exception as e:
126
- st.error(f"Error during prediction: {e}")
127
- return None
128
-
129
- uploaded_files_dir = "uploaded_files"
130
- os.makedirs(uploaded_files_dir, exist_ok=True)
131
-
132
  class_colors = {
133
  0: "#1f77b4", # Level 1
134
  1: "#ff7f0e", # Level 2
@@ -141,12 +49,15 @@ st.title("Paper Citation Classifier")
141
  option = st.radio("Select input type:", ("Text", "PDF"))
142
 
143
  if option == "Text":
 
144
  abstract_input = st.text_area("Enter Abstract:")
145
  full_text_input = st.text_area("Enter Full Text:")
146
  affiliations_input = st.text_area("Enter Affiliations:")
147
 
 
148
  categories = st.multiselect("Select categories:", ["Category 1", "Category 2", "Category 3", "Category 4"])
149
 
 
150
  combined_text = f"{abstract_input} [SEP] {full_text_input} [SEP] {affiliations_input} [SEP] {' [SEP] '.join(categories)}"
151
 
152
  if st.button("Predict"):
@@ -179,6 +90,7 @@ elif option == "PDF":
179
  st.text("Extracted Text:")
180
  st.text(file_text)
181
 
 
182
  if st.button("Predict from PDF Text"):
183
  with st.spinner("Predicting..."):
184
  predicted_class = predict_class(file_text)
@@ -193,13 +105,7 @@ elif option == "PDF":
193
  )
194
  else:
195
  st.text(label)
196
- st.text("Predicted Class:")
197
- for i, label in enumerate(class_labels):
198
- if i == predicted_class:
199
- st.markdown(
200
- f'<div style="background-color: {class_colors[predicted_class]}; padding: 10px; border-radius: 5px; color: white; font-weight: bold;">{label}</div>',
201
- unsafe_allow_html=True
202
- )
203
- st.text(class_info[predicted_class])
204
- else:
205
- st.text(label)
 
4
  import fitz
5
  import os
6
 
7
+ # Load the model and tokenizer
8
  model = AutoModelForSequenceClassification.from_pretrained("REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier")
9
  tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
10
 
 
18
 
19
  def predict_class(text):
20
  try:
21
+ # Truncate text to maximum length of 4096 tokens
22
  max_length = 4096
23
  truncated_text = text[:max_length]
24
 
 
32
  st.error(f"Error during prediction: {e}")
33
  return None
34
 
35
+ # Create a directory to store uploaded files
36
  uploaded_files_dir = "uploaded_files"
37
  os.makedirs(uploaded_files_dir, exist_ok=True)
38
 
39
+ # Define colors for different classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  class_colors = {
41
  0: "#1f77b4", # Level 1
42
  1: "#ff7f0e", # Level 2
 
49
  option = st.radio("Select input type:", ("Text", "PDF"))
50
 
51
  if option == "Text":
52
+ # Input text boxes for abstract, full text, and affiliations
53
  abstract_input = st.text_area("Enter Abstract:")
54
  full_text_input = st.text_area("Enter Full Text:")
55
  affiliations_input = st.text_area("Enter Affiliations:")
56
 
57
+ # Select categories using pills
58
  categories = st.multiselect("Select categories:", ["Category 1", "Category 2", "Category 3", "Category 4"])
59
 
60
+ # Combine selected categories with [SEP]
61
  combined_text = f"{abstract_input} [SEP] {full_text_input} [SEP] {affiliations_input} [SEP] {' [SEP] '.join(categories)}"
62
 
63
  if st.button("Predict"):
 
90
  st.text("Extracted Text:")
91
  st.text(file_text)
92
 
93
+ # Provide an option to predict from PDF text
94
  if st.button("Predict from PDF Text"):
95
  with st.spinner("Predicting..."):
96
  predicted_class = predict_class(file_text)
 
105
  )
106
  else:
107
  st.text(label)
108
+
109
+
110
+
111
+