nouf-sst committed
Commit c94e335 · 1 Parent(s): f120eee

Speed up processing
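The speedup comes from hoisting model loading out of the per-call helpers into module scope, so each model is loaded once at import time instead of on every invocation. A minimal sketch of the pattern, condensed from the spaCy hunk in the diff below (the real get_clauses_list goes on to inspect part-of-speech tags, dependencies, ancestors, and children of each token):

import spacy

# Before: every call re-paid the load cost (disk read + pipeline init).
def get_clauses_list(sent):
    nlp = spacy.load('en_core_web_sm')
    doc = nlp(sent)
    ...

# After: load once at import time; all calls reuse the shared instance.
nlp = spacy.load('en_core_web_sm')

def get_clauses_list(sent):
    doc = nlp(sent)
    ...

The same change is applied to the POS tagger, the CrossEncoder similarity model, and the NLI model/tokenizer pair.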

Files changed (1)
app.py +17 -29
app.py CHANGED
@@ -11,6 +11,13 @@ from torch.nn.utils.rnn import pad_sequence
 import numpy as np
 import spacy
 
+# ***************************** Load needed models *****************************
+nlp = spacy.load('en_core_web_sm')
+pos_tokenizer = AutoTokenizer.from_pretrained("QCRI/bert-base-multilingual-cased-pos-english")
+pos_model = AutoModelForTokenClassification.from_pretrained("QCRI/bert-base-multilingual-cased-pos-english")
+sentences_similarity_model = CrossEncoder('cross-encoder/stsb-roberta-base')
+nli_model = BertForSequenceClassification.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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")
+nli_tokenizer = BertTokenizer.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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", do_lower_case=True)
 
 # ***************************** TGRL Parsing *****************************
 
@@ -188,8 +195,6 @@ def get_clause_token_span_for_verb(verb, doc, all_verbs):
 
 def get_clauses_list(sent):
 
-    nlp = spacy.load('en_core_web_sm')
-
     doc = nlp(sent)
 
     # find part of speech, dependency tag, ancestors, and children of each token
@@ -252,11 +257,7 @@ def get_punctuations(elements):
 # ########## Incorrect Actor Syntax ##########
 def find_non_NPs(sentences):
 
-    model_name = "QCRI/bert-base-multilingual-cased-pos-english"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForTokenClassification.from_pretrained(model_name)
-
-    pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
+    pipeline = TokenClassificationPipeline(model=pos_model, tokenizer=pos_tokenizer)
 
     outputs = pipeline(sentences)
 
@@ -306,11 +307,7 @@ def check_softgoal_syntax(softgoals):
 # ########## Incorrect Task Syntax ###########
 def find_non_VPs(sentences):
 
-    model_name = "QCRI/bert-base-multilingual-cased-pos-english"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForTokenClassification.from_pretrained(model_name)
-
-    pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
+    pipeline = TokenClassificationPipeline(model=pos_model, tokenizer=pos_tokenizer)
 
     outputs = pipeline(sentences)
 
@@ -336,9 +333,6 @@ def check_task_syntax(tasks):
 # ########## Similarity ###########
 def get_similar_elements(elements_per_actor):
 
-    # Load the pre-trained model
-    model = CrossEncoder('cross-encoder/stsb-roberta-base')
-
     # Prepare sentence pair array
     sentence_pairs = []
 
@@ -349,7 +343,7 @@ def get_similar_elements(elements_per_actor):
                 sentence_pairs.append([elements_per_actor[key][i], elements_per_actor[key][j]])
 
     # Predict semantic similarity
-    semantic_similarity_scores = model.predict(sentence_pairs, show_progress_bar=True)
+    semantic_similarity_scores = sentences_similarity_model.predict(sentence_pairs, show_progress_bar=True)
 
     similar_elements = []
     for index, value in enumerate(sentence_pairs):
@@ -400,16 +394,16 @@ def check_spelling(elements):
 # ##################################
 
 # ########## NLI ###########
-def do_nli(premise, hypothesis, model, tokenizer):
+def do_nli(premise, hypothesis):
 
     # Tokenization
     token_ids = []
     seg_ids = []
     mask_ids = []
 
-    premise_id = tokenizer.encode(premise, add_special_tokens = False)
-    hypothesis_id = tokenizer.encode(hypothesis, add_special_tokens = False)
-    pair_token_ids = [tokenizer.cls_token_id] + premise_id + [tokenizer.sep_token_id] + hypothesis_id + [tokenizer.sep_token_id]
+    premise_id = nli_tokenizer.encode(premise, add_special_tokens = False)
+    hypothesis_id = nli_tokenizer.encode(hypothesis, add_special_tokens = False)
+    pair_token_ids = [nli_tokenizer.cls_token_id] + premise_id + [nli_tokenizer.sep_token_id] + hypothesis_id + [nli_tokenizer.sep_token_id]
     premise_len = len(premise_id)
     hypothesis_len = len(hypothesis_id)
 
@@ -426,7 +420,7 @@ def do_nli(premise, hypothesis, model, tokenizer):
     seg_ids = pad_sequence(seg_ids, batch_first=True)
 
     with torch.no_grad():
-        output = model(token_ids,
+        output = nli_model(token_ids,
                        token_type_ids=seg_ids,
                        attention_mask=mask_ids)
 
@@ -448,9 +442,6 @@ def do_nli(premise, hypothesis, model, tokenizer):
 # Entailment
 def check_entailment(decomposed_elements):
 
-    model = BertForSequenceClassification.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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")
-    tokenizer = BertTokenizer.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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", do_lower_case=True)
-
     sentence_pairs = []
     non_matching_elements = []
 
@@ -461,7 +452,7 @@ def check_entailment(decomposed_elements):
             sentence_pairs.append([key, i])
 
     for sentence_pair in sentence_pairs:
-        result = do_nli(sentence_pair[0], sentence_pair[1], model, tokenizer)
+        result = do_nli(sentence_pair[0], sentence_pair[1])
         print(result)
         if result != "Entailment":
             non_matching_elements.append(sentence_pair)
@@ -478,9 +469,6 @@ def check_entailment(decomposed_elements):
 # Contradiction
 def check_contradiction(elements_per_actor):
 
-    model = BertForSequenceClassification.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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")
-    tokenizer = BertTokenizer.from_pretrained("nouf-sst/bert-base-MultiNLI", use_auth_token="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", do_lower_case=True)
-
    sentence_pairs = []
    contradicting_elements = []
 
@@ -493,7 +481,7 @@ def check_contradiction(elements_per_actor):
     #print(sentence_pairs)
     # Check contradiction
     for sentence_pair in sentence_pairs:
-        result = do_nli(sentence_pair[0], sentence_pair[1], model, tokenizer)
+        result = do_nli(sentence_pair[0], sentence_pair[1])
         #print(result)
         if result == "Contradiction":
             contradicting_elements.append(sentence_pair)
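With the models at module scope, do_nli drops its model/tokenizer parameters and reads nli_model and nli_tokenizer directly; per the checks above it returns a label such as "Entailment" or "Contradiction". A hypothetical call, with illustrative sentences that are not from the app:

result = do_nli("The user can export reports.", "Reports can be exported.")  # e.g. "Entailment"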