added phrase highlights
app.py
CHANGED
@@ -10,12 +10,16 @@ from input_format import *
 from score import *
 
 # load document scoring model
+torch.cuda.is_available = lambda : False
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 pretrained_model = 'allenai/specter'
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
 doc_model = AutoModel.from_pretrained(pretrained_model)
+doc_model.to(device)
 
 # load sentence model
 sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
+sent_model.to(device)
 
 def get_similar_paper(
     abstract_text_input,
@@ -25,8 +29,6 @@ def get_similar_paper(
 ):
     input_sentences = sent_tokenize(abstract_text_input)
 
-    pickle.dump(input_sentences, open('tmp_input_sents.pkl', 'wb'))
-
     # TODO handle pdf file input
     if pdf_file_input is not None:
         name = None
@@ -42,7 +44,7 @@ def get_similar_paper(
         tokenizer,
         abstract_text_input,
         papers,
-        batch=
+        batch=50
     )
 
     tmp = {
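The two leading `+` lines pin the app to CPU: monkeypatching `torch.cuda.is_available` before the device is chosen makes every later "cuda if available" check resolve to CPU, and the `.to(device)` calls then place both models there. A minimal sketch of the same pattern, with `nn.Linear` standing in for `doc_model` / `sent_model`:

import torch
import torch.nn as nn

# report 'no GPU' from here on; every 'cuda if available' check now picks CPU
torch.cuda.is_available = lambda: False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(4, 2).to(device)    # stand-in for doc_model / sent_model
x = torch.randn(1, 4, device=device)
print(device, model(x).shape)         # cpu torch.Size([1, 2])

The override only affects the current process; it is a common way to force CPU inference when CUDA is unwanted or, as on CPU-only Spaces hardware, unavailable.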
score.py
CHANGED
@@ -1,5 +1,6 @@
 from sentence_transformers import util
 from nltk.tokenize import sent_tokenize
+from nltk import word_tokenize, pos_tag
 import torch
 import numpy as np
 
@@ -33,19 +34,52 @@ def get_words(sent):
     sent_start_id = [] # keep track of the word index where the new sentence starts
     counter = 0
     for x in sent:
-        w = x.split()
+        #w = x.split()
+        w = word_tokenize(x)
         nw = len(w)
         counter += nw
         words.append(w)
         sent_start_id.append(counter)
-    words = [x.split() for x in sent]
+    words = [word_tokenize(x) for x in sent]
     all_words = [item for sublist in words for item in sublist]
     sent_start_id.pop()
     sent_start_id = [0] + sent_start_id
     assert(len(sent_start_id) == len(sent))
     return words, all_words, sent_start_id
 
-def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
+def get_match_phrase(w1, w2):
+    # list of words for query and candidate as input
+    # return the word list and binary mask of matching phrases
+    # POS tags that should be considered for matching phrase
+    include = [
+        'JJ',
+        'JJR',
+        'JJS',
+        'MD',
+        'NN',
+        'NNS',
+        'NNP',
+        'NNPS',
+        'RB',
+        'RBR',
+        'RBS',
+        'SYM',
+        'VB',
+        'VBD',
+        'VBG',
+        'VBN',
+        'FW'
+    ]
+    mask1 = np.zeros(len(w1))
+    mask2 = np.zeros(len(w2))
+    pos1 = pos_tag(w1)
+    pos2 = pos_tag(w2)
+    for i, (w, p) in enumerate(pos2):
+        if w.lower() in w1 and p in include:
+            mask2[i] = 1
+    return mask2
+
+def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
     num_query_sent = sent_ids.shape[0]
     num_words = len(all_words)
 
@@ -55,22 +89,29 @@ def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
 
     # for each query sentence, mark the highlight information
     for i in range(num_query_sent):
+        query_words = word_tokenize(query_sents[i])
         is_selected_sent = np.zeros(num_words)
         is_selected_phrase = np.zeros(num_words)
-        word_scores = np.zeros(num_words)
+        word_scores = np.zeros(num_words)
 
-        #
+        # for each selected sentences from the candidate, compile information
         for sid, sscore in zip(sent_ids[i], sent_scores[i]):
             #print(len(sent_start_id), sid, sid+1)
             if sid+1 < len(sent_start_id):
                 sent_range = (sent_start_id[sid], sent_start_id[sid+1])
                 is_selected_sent[sent_range[0]:sent_range[1]] = 1
                 word_scores[sent_range[0]:sent_range[1]] = sscore
+                is_selected_phrase[sent_range[0]:sent_range[1]] = \
+                    get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
             else:
-                is_selected_sent[sent_start_id[sid]:] = 1
-                word_scores[sent_start_id[sid]:] = sscore
+                is_selected_sent[sent_start_id[sid]:] = 1
+                word_scores[sent_start_id[sid]:] = sscore
+                is_selected_phrase[sent_start_id[sid]:] = \
+                    get_match_phrase(query_words, all_words[sent_start_id[sid]:])
+
+        # update selected phrase scores (-1 meaning a different color in gradio)
+        word_scores[is_selected_sent+is_selected_phrase==2] = -1
 
-        # TODO get phrase selection information
         output[i] = {
             'is_selected_sent': is_selected_sent,
             'is_selected_phrase': is_selected_phrase,
@@ -79,16 +120,18 @@ def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
 
     return output
 
-def get_highlight_info(model, text1, text2, K=
+def get_highlight_info(model, text1, text2, K=None):
     sent1 = sent_tokenize(text1) # query
     sent2 = sent_tokenize(text2) # candidate
+    if K is None: # if K is not set, select based on the length of the candidate
+        K = int(len(sent2) / 3)
     score_mat = compute_sentencewise_scores(model, sent1, sent2)
 
     sent_ids, sent_scores = get_top_k(score_mat, K=K)
     #print(sent_ids, sent_scores)
-    words1, all_words1, sent_start_id1 = get_words(sent1)
+    words2, all_words2, sent_start_id2 = get_words(sent2)
     #print(all_words1, sent_start_id1)
-    info = mark_words(
+    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
 
     return sent_ids, sent_scores, info
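The new `get_match_phrase` is what drives the phrase highlights: candidate words are POS-tagged with NLTK's `pos_tag`, and a candidate word is marked when its lowercased form also appears among the query tokens and its tag is in the content-word whitelist (adjectives, nouns, adverbs, verbs, modals, symbols, foreign words); function words such as determiners never match. Note that only `mask2`, the candidate-side mask, is returned; `mask1` and `pos1` are computed but currently unused. A toy check, assuming `score.py` is importable and the NLTK `punkt` and perceptron-tagger data have been downloaded:

from nltk import word_tokenize
from score import get_match_phrase

query = word_tokenize("we propose a neural ranking model")
candidate = word_tokenize("A ranking model for neural retrieval.")

# shared content words ('ranking', 'model', 'neural') should come back as 1;
# 'A' also occurs in the query but its DT tag is not whitelisted, so it
# stays 0, as do unshared words and punctuation
mask = get_match_phrase(query, candidate)
print(list(zip(candidate, mask)))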
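With `K=None` now the default, `get_highlight_info` selects roughly a third of the candidate's sentences per query sentence, computes the sentence-to-sentence score matrix, and passes the query sentences into the extended `mark_words` so phrase matches can be marked inside each selected candidate sentence. A hypothetical end-to-end call, assuming the same GTR model app.py loads and that `compute_sentencewise_scores` / `get_top_k` in score.py resolve as before:

from sentence_transformers import SentenceTransformer
from score import get_highlight_info

sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')

query = "We match query sentences to candidate abstracts for reviewer assignment."
candidate = ("This paper studies sentence matching. "
             "A transformer model scores each sentence pair. "
             "Results improve reviewer-paper assignment.")

# K is omitted, so K = int(3 / 3) = 1 selected sentence per query sentence
sent_ids, sent_scores, info = get_highlight_info(sent_model, query, candidate)

# info[i]['is_selected_sent'] flags every word of the top-K candidate
# sentences for query sentence i; info[i]['is_selected_phrase'] flags the
# POS-matched words inside them, whose word scores are overwritten with -1
# so gradio renders them in a distinct highlight color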