TiberiuCristianLeon committed on
Commit
d22cb09
·
verified ·
1 Parent(s): b8db721

Update src/translate/Translate.py

Browse files
Files changed (1) hide show
  1. src/translate/Translate.py +30 -33
src/translate/Translate.py CHANGED
@@ -3,9 +3,9 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
  import src.exception.Exception.Exception as ExceptionCustom
5
 
6
-
7
  METHOD = "TRANSLATE"
8
 
 
9
  tokenizerROMENG = AutoTokenizer.from_pretrained("BlackKakapo/opus-mt-ro-en")
10
  modelROMENG = AutoModelForSeq2SeqLM.from_pretrained("BlackKakapo/opus-mt-ro-en")
11
 
@@ -16,40 +16,37 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  modelROMENG.to(device)
17
  modelENGROM.to(device)
18
 
 
 
 
 
19
 
20
- def paraphraseTranslateMethod(requestValue : str, model: str):
21
-
22
- exception = ""
23
- result_value = ""
24
-
25
- exception = ExceptionCustom.checkForException(requestValue, METHOD)
26
- if exception != "":
27
- return "", exception
28
-
29
- tokenized_sent_list = sent_tokenize(requestValue)
30
 
31
- for SENTENCE in tokenized_sent_list:
32
  if model == 'roen':
33
  input_ids = tokenizerROMENG(SENTENCE, return_tensors='pt').to(device)
34
- output = modelROMENG.generate(
35
- input_ids=input_ids1.input_ids,
36
- do_sample=True,
37
- max_length=512,
38
- top_k=90,
39
- top_p=0.97,
40
- early_stopping=False
41
- )
42
- result = tokenizerROMENG.batch_decode(output1, skip_special_tokens=True)[0]
43
  else:
44
- input_ids = tokenizerENGROM(SENTENCE, return_tensors='pt').to(device)
45
-
46
- output = modelENGROM.generate(
47
- input_ids=input_ids.input_ids,
48
- do_sample=True,
49
- max_length=512,
50
- top_k=90,
51
- top_p=0.97,
52
- early_stopping=False
53
- )
54
- result = tokenizerENGROM.batch_decode(output, skip_special_tokens=True)[0]
55
- return result.strip(), model
 
 
3
  import torch
4
  import src.exception.Exception.Exception as ExceptionCustom
5
 
 
6
  METHOD = "TRANSLATE"
7
 
8
+ # Load models and tokenizers
9
  tokenizerROMENG = AutoTokenizer.from_pretrained("BlackKakapo/opus-mt-ro-en")
10
  modelROMENG = AutoModelForSeq2SeqLM.from_pretrained("BlackKakapo/opus-mt-ro-en")
11
 
 
16
  modelROMENG.to(device)
17
  modelENGROM.to(device)
18
 
19
def paraphraseTranslateMethod(requestValue: str, model: str):
    """Translate ``requestValue`` sentence by sentence using sampled generation.

    The request is first validated with ``ExceptionCustom.checkForException``;
    a truthy result is returned immediately as the error. Otherwise the text
    is split with ``sent_tokenize`` and each sentence is run through the
    selected tokenizer/model pair, then the decoded pieces are joined with
    single spaces.

    Args:
        requestValue: Text to translate.
        model: ``'roen'`` selects the ``tokenizerROMENG``/``modelROMENG`` pair
            (loaded from ``BlackKakapo/opus-mt-ro-en``); any other value
            selects ``tokenizerENGROM``/``modelENGROM`` (presumably the
            English->Romanian pair, loaded outside this view - confirm).

    Returns:
        ``("", error)`` when validation reports a problem, otherwise
        ``(translated_text, model)``.
    """
    error = ExceptionCustom.checkForException(requestValue, METHOD)
    if error:
        return "", error

    # Resolve the tokenizer/model pair once; the choice does not vary per sentence.
    if model == 'roen':
        tokenizer, seq2seq = tokenizerROMENG, modelROMENG
    else:
        tokenizer, seq2seq = tokenizerENGROM, modelENGROM

    translated = []
    for sentence in sent_tokenize(requestValue):
        encoded = tokenizer(sentence, return_tensors='pt').to(device)
        # do_sample=True with top_k/top_p means the output is sampled, so
        # repeated calls on the same input may differ.
        generated = seq2seq.generate(
            input_ids=encoded.input_ids,
            do_sample=True,
            max_length=512,
            top_k=90,
            top_p=0.97,
            early_stopping=False,
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        translated.append(decoded[0])

    return " ".join(translated).strip(), model