Spaces:
Runtime error
Speed up processing
app.py
CHANGED
@@ -11,6 +11,13 @@ from torch.nn.utils.rnn import pad_sequence
 import numpy as np
 import spacy
 
+# ***************************** Load needed models *****************************
+nlp = spacy.load('en_core_web_sm')
+pos_tokenizer = AutoTokenizer.from_pretrained("QCRI/bert-base-multilingual-cased-pos-english")
+pos_model = AutoModelForTokenClassification.from_pretrained("QCRI/bert-base-multilingual-cased-pos-english")
+sentences_similarity_model = CrossEncoder('cross-encoder/stsb-roberta-base')
+nli_model = BertForSequenceClassification.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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")
+nli_tokenizer = BertTokenizer.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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", do_lower_case=True)
 
 # ***************************** TGRL Parsing *****************************
 
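This hunk moves every model load to module scope, so the expensive from_pretrained() and spacy.load() calls run once when the Space boots instead of on every request. A minimal sketch of the pattern, with a hypothetical analyze() helper that is not part of app.py:

    import spacy
    from sentence_transformers import CrossEncoder

    nlp = spacy.load('en_core_web_sm')                            # loaded once per process
    similarity_model = CrossEncoder('cross-encoder/stsb-roberta-base')

    def analyze(sentence_pairs):
        # Request-path code only reuses the module-level models; no reloads here.
        scores = similarity_model.predict(sentence_pairs)
        docs = [nlp(first) for first, second in sentence_pairs]
        return scores, docs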
@@ -188,8 +195,6 @@ def get_clause_token_span_for_verb(verb, doc, all_verbs):
 
 def get_clauses_list(sent):
 
-    nlp = spacy.load('en_core_web_sm')
-
     doc = nlp(sent)
 
     # find part of speech, dependency tag, ancestors, and children of each token
@@ -252,11 +257,7 @@ def get_punctuations(elements):
 # ########## Incorrect Actor Syntax ##########
 def find_non_NPs(sentences):
 
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForTokenClassification.from_pretrained(model_name)
-
-    pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
+    pipeline = TokenClassificationPipeline(model=pos_model, tokenizer=pos_tokenizer)
 
     outputs = pipeline(sentences)
 
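For reference, a TokenClassificationPipeline returns one list of predictions per input sentence, and with this POS model each prediction carries its tag in the 'entity' field. A hedged sketch of inspecting that output; the example sentences and the printed check are illustrative, not code from app.py:

    # Reuses the module-level pos_model / pos_tokenizer loaded above.
    from transformers import TokenClassificationPipeline

    pos_pipeline = TokenClassificationPipeline(model=pos_model, tokenizer=pos_tokenizer)
    outputs = pos_pipeline(["The system administrator", "Encrypt the data"])
    for sentence_tokens in outputs:
        # Each entry looks roughly like:
        # {'entity': 'DT', 'score': 0.99, 'index': 1, 'word': 'The', 'start': 0, 'end': 3}
        print([token['entity'] for token in sentence_tokens])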
@@ -306,11 +307,7 @@ def check_softgoal_syntax(softgoals):
 # ########## Incorrect Task Syntax ###########
 def find_non_VPs(sentences):
 
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForTokenClassification.from_pretrained(model_name)
-
-    pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
+    pipeline = TokenClassificationPipeline(model=pos_model, tokenizer=pos_tokenizer)
 
     outputs = pipeline(sentences)
 
@@ -336,9 +333,6 @@ def check_task_syntax(tasks):
 # ########## Similarity ###########
 def get_similar_elements(elements_per_actor):
 
-    # Load the pre-trained model
-    model = CrossEncoder('cross-encoder/stsb-roberta-base')
-
     # Prepare sentence pair array
     sentence_pairs = []
 
@@ -349,7 +343,7 @@ def get_similar_elements(elements_per_actor):
             sentence_pairs.append([elements_per_actor[key][i], elements_per_actor[key][j]])
 
     # Predict semantic similarity
-    semantic_similarity_scores = model.predict(sentence_pairs, show_progress_bar=True)
+    semantic_similarity_scores = sentences_similarity_model.predict(sentence_pairs, show_progress_bar=True)
 
     similar_elements = []
     for index, value in enumerate(sentence_pairs):
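CrossEncoder.predict scores all sentence pairs in one batch and returns one value per pair (roughly 0 to 1 for this STS-B model), so similar elements can be selected with a threshold. A short sketch; the 0.7 cut-off is an assumption, not a value taken from app.py:

    from sentence_transformers import CrossEncoder

    similarity_model = CrossEncoder('cross-encoder/stsb-roberta-base')
    pairs = [["The user logs in", "The user signs in"],
             ["Encrypt the data", "Generate a report"]]
    scores = similarity_model.predict(pairs)                  # one float per pair
    similar = [pair for pair, score in zip(pairs, scores) if score >= 0.7]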
@@ -400,16 +394,16 @@ def check_spelling(elements):
 # ##################################
 
 # ########## NLI ###########
-def do_nli(premise, hypothesis, model, tokenizer):
+def do_nli(premise, hypothesis):
 
     # Tokenization
     token_ids = []
     seg_ids = []
     mask_ids = []
 
-    premise_id = tokenizer.encode(premise, add_special_tokens = False)
-    hypothesis_id = tokenizer.encode(hypothesis, add_special_tokens = False)
-    pair_token_ids = [tokenizer.cls_token_id] + premise_id + [tokenizer.sep_token_id] + hypothesis_id + [tokenizer.sep_token_id]
+    premise_id = nli_tokenizer.encode(premise, add_special_tokens = False)
+    hypothesis_id = nli_tokenizer.encode(hypothesis, add_special_tokens = False)
+    pair_token_ids = [nli_tokenizer.cls_token_id] + premise_id + [nli_tokenizer.sep_token_id] + hypothesis_id + [nli_tokenizer.sep_token_id]
     premise_len = len(premise_id)
     hypothesis_len = len(hypothesis_id)
 
@@ -426,7 +420,7 @@ def do_nli(premise, hypothesis, model, tokenizer):
     seg_ids = pad_sequence(seg_ids, batch_first=True)
 
     with torch.no_grad():
-        output = model(token_ids,
+        output = nli_model(token_ids,
                        token_type_ids=seg_ids,
                        attention_mask=mask_ids)
 
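do_nli assembles the [CLS] premise [SEP] hypothesis [SEP] input by hand; the tokenizer's pair-encoding call produces the same tensors and is a common shortcut. A hedged sketch that reuses the module-level nli_model and nli_tokenizer from the first hunk; the label order below is an assumption about this MultiNLI head, not something confirmed by app.py:

    import torch

    premise = "The system shall encrypt stored data"          # example inputs
    hypothesis = "Stored data is encrypted"
    enc = nli_tokenizer(premise, hypothesis, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = nli_model(**enc).logits                      # shape (1, num_labels)
    labels = ["Contradiction", "Neutral", "Entailment"]       # assumed order; check the model config
    print(labels[logits.argmax(dim=-1).item()])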
@@ -448,9 +442,6 @@ def do_nli(premise, hypothesis, model, tokenizer):
 # Entailment
 def check_entailment(decomposed_elements):
 
-    model = BertForSequenceClassification.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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")
-    tokenizer = BertTokenizer.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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", do_lower_case=True)
-
     sentence_pairs = []
     non_matching_elements = []
 
@@ -461,7 +452,7 @@ def check_entailment(decomposed_elements):
         sentence_pairs.append([key, i])
 
     for sentence_pair in sentence_pairs:
-        result = do_nli(sentence_pair[0], sentence_pair[1], model, tokenizer)
+        result = do_nli(sentence_pair[0], sentence_pair[1])
         print(result)
         if result != "Entailment":
             non_matching_elements.append(sentence_pair)
@@ -478,9 +469,6 @@ def check_entailment(decomposed_elements):
 # Contradiction
 def check_contradiction(elements_per_actor):
 
-    model = BertForSequenceClassification.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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")
-    tokenizer = BertTokenizer.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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", do_lower_case=True)
-
     sentence_pairs = []
     contradicting_elements = []
 
@@ -493,7 +481,7 @@ def check_contradiction(elements_per_actor):
     #print(sentence_pairs)
     # Check contradiction
     for sentence_pair in sentence_pairs:
-        result = do_nli(sentence_pair[0], sentence_pair[1], model, tokenizer)
+        result = do_nli(sentence_pair[0], sentence_pair[1])
         #print(result)
         if result == "Contradiction":
             contradicting_elements.append(sentence_pair)