TiberiuCristianLeon committed on
Commit
95b5309
·
verified ·
1 Parent(s): e7d7478

Update src/translate/Translate.py

Browse files
Files changed (1) hide show
  1. src/translate/Translate.py +8 -12
src/translate/Translate.py CHANGED
@@ -7,17 +7,6 @@ from transformers import pipeline
7
 
8
  METHOD = "TRANSLATE"
9
 
10
- # Load models and tokenizers
11
- tokenizerROMENG = AutoTokenizer.from_pretrained("BlackKakapo/opus-mt-ro-en")
12
- modelROMENG = AutoModelForSeq2SeqLM.from_pretrained("BlackKakapo/opus-mt-ro-en")
13
-
14
- tokenizerENGROM = AutoTokenizer.from_pretrained("BlackKakapo/opus-mt-en-ro")
15
- modelENGROM = AutoModelForSeq2SeqLM.from_pretrained("BlackKakapo/opus-mt-en-ro")
16
-
17
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- modelROMENG.to(device)
19
- modelENGROM.to(device)
20
-
21
  def paraphraseTranslateMethod(requestValue: str, model: str):
22
  exception = ExceptionCustom.checkForException(requestValue, METHOD)
23
  if exception:
@@ -25,9 +14,13 @@ def paraphraseTranslateMethod(requestValue: str, model: str):
25
 
26
  tokenized_sent_list = sent_tokenize(requestValue)
27
  result_value = []
28
-
 
29
  for SENTENCE in tokenized_sent_list:
30
  if model == 'roen':
 
 
 
31
  input_ids = tokenizerROMENG(SENTENCE, return_tensors='pt').to(device)
32
  output = modelROMENG.generate(
33
  input_ids=input_ids.input_ids,
@@ -39,6 +32,9 @@ def paraphraseTranslateMethod(requestValue: str, model: str):
39
  )
40
  result = tokenizerROMENG.batch_decode(output, skip_special_tokens=True)[0]
41
  else:
 
 
 
42
  input_ids = tokenizerENGROM(SENTENCE, return_tensors='pt').to(device)
43
  output = modelENGROM.generate(
44
  input_ids=input_ids.input_ids,
 
7
 
8
  METHOD = "TRANSLATE"
9
 
 
 
 
 
 
 
 
 
 
 
 
10
  def paraphraseTranslateMethod(requestValue: str, model: str):
11
  exception = ExceptionCustom.checkForException(requestValue, METHOD)
12
  if exception:
 
14
 
15
  tokenized_sent_list = sent_tokenize(requestValue)
16
  result_value = []
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
  for SENTENCE in tokenized_sent_list:
20
  if model == 'roen':
21
+ tokenizerROMENG = AutoTokenizer.from_pretrained("BlackKakapo/opus-mt-ro-en")
22
+ modelROMENG = AutoModelForSeq2SeqLM.from_pretrained("BlackKakapo/opus-mt-ro-en")
23
+ modelROMENG.to(device)
24
  input_ids = tokenizerROMENG(SENTENCE, return_tensors='pt').to(device)
25
  output = modelROMENG.generate(
26
  input_ids=input_ids.input_ids,
 
32
  )
33
  result = tokenizerROMENG.batch_decode(output, skip_special_tokens=True)[0]
34
  else:
35
+ tokenizerENGROM = AutoTokenizer.from_pretrained("BlackKakapo/opus-mt-en-ro")
36
+ modelENGROM = AutoModelForSeq2SeqLM.from_pretrained("BlackKakapo/opus-mt-en-ro")
37
+ modelENGROM.to(device)
38
  input_ids = tokenizerENGROM(SENTENCE, return_tensors='pt').to(device)
39
  output = modelENGROM.generate(
40
  input_ids=input_ids.input_ids,