mjuvilla committed
Commit 8030df1 · 1 Parent(s): 0fc4acd

added salamandraTA translation, update requirements

Files changed (2)
  1. main.py +48 -9
  2. requirements.txt +6 -1
main.py CHANGED
@@ -17,6 +17,46 @@ from subprocess import Popen, PIPE
 from itertools import groupby
 import fileinput
 
+from datetime import datetime
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+from iso639 import languages
+import tqdm
+
+
+class Translator():
+    def __init__(self, model_path, source_lang, target_lang):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype=torch.bfloat16
+        )
+
+        self.prompt_f = lambda x: (f"Translate the following text from {source_lang} into "
+                                   f"{target_lang}.\n{source_lang}: {x} \n{target_lang}:")
+
+    def translate(self, text):
+        message = [{"role": "user", "content": self.prompt_f(text)}]
+        date_string = datetime.today().strftime('%Y-%m-%d')
+
+        prompt = self.tokenizer.apply_chat_template(
+            message,
+            tokenize=False,
+            add_generation_prompt=True,
+            date_string=date_string
+        )
+
+        inputs = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
+        input_length = inputs.shape[1]
+        outputs = self.model.generate(input_ids=inputs.to(self.model.device),
+                                      max_new_tokens=400,
+                                      early_stopping=True,
+                                      num_beams=5)
+
+        return self.tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
+
 
 # Class to align original and translated sentences
 # based on https://github.com/mtuoc/MTUOC-server/blob/main/GetWordAlignments_fast_align.py
@@ -235,12 +275,6 @@ def generate_alignments(original_paragraphs_with_runs, translated_paragraphs, al
     return translated_sentences_with_style
 
 
-# TODO
-def translate_paragraph(paragraph_text):
-    translated_paragraph = ""
-    return translated_paragraphs
-
-
 # group contiguous elements with the same boolean values
 def group_by_style(values, detokenizer):
     groups = []
@@ -316,12 +350,17 @@ if __name__ == "__main__":
 
     detokenizer = TreebankWordDetokenizer()
 
+    translator = Translator("BSC-LT/salamandraTA-7b-instruct", languages.get(alpha2=source_lang).name,
+                            languages.get(alpha2=target_lang).name)
+
    # translate each paragraph
     translated_paragraphs = []
-    for paragraph in paragraphs_with_runs:
+    for paragraph in tqdm.tqdm(paragraphs_with_runs, desc="Translating paragraphs..."):
         paragraph_text = detokenizer.detokenize([run["text"] for run in paragraph])
-        translated_paragraphs.append(translate_paragraph(paragraph_text))
-
+        translated_paragraphs.append(translator.translate(paragraph_text))
+
+    print(translated_paragraphs)
+
     out_doc = Document()
 
     processed_original_paragraphs_with_runs = [preprocess_runs(runs) for runs in paragraphs_with_runs]
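
For reference, a minimal standalone sketch of how the new Translator class is driven, mirroring the __main__ changes above. The model id and the languages.get(alpha2=...) lookup come straight from the diff; the language codes and the sample sentence are illustrative, and loading the 7B checkpoint in bfloat16 assumes a GPU (or enough memory) is available.

    # Hypothetical standalone usage; in the script, source_lang/target_lang come from the CLI.
    from iso639 import languages

    source_lang, target_lang = "en", "ca"  # ISO 639-1 codes (illustrative)

    translator = Translator("BSC-LT/salamandraTA-7b-instruct",
                            languages.get(alpha2=source_lang).name,   # "English"
                            languages.get(alpha2=target_lang).name)   # "Catalan"

    # One paragraph at a time, as in the loop over paragraphs_with_runs
    print(translator.translate("The alignment step runs after all paragraphs are translated."))
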
requirements.txt CHANGED
@@ -1,2 +1,7 @@
 nltk~=3.9.1
-python-docx~=1.1.2
+python-docx~=1.1.2
+torch~=2.6.0
+transformers~=4.51.2
+iso-639~=0.4.5
+protobuf~=6.30.2
+sentencepiece~=0.2.0
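
The new pins cover the translation stack: torch and transformers load the model, iso-639 maps the two-letter language codes to names for the prompt, and sentencepiece/protobuf are presumably needed by the model's tokenizer. A quick sanity check, assuming the pinned versions above are installed:

    # Illustrative environment check: the pinned stack imports and, when a GPU is
    # visible, bfloat16 support for the device_map="auto" / torch.bfloat16 setup.
    import torch
    import transformers
    import sentencepiece  # noqa: F401  (loaded indirectly by the tokenizer)

    print("transformers", transformers.__version__)
    print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("bf16 supported:", torch.cuda.is_bf16_supported())
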