mjuvilla committed
Commit 8be9040 · 1 Parent(s): 44978d8

Created classes for running the translation models from either a local model or a Hugging Face endpoint. For now, main.py only supports local models.

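The new layout puts the backends behind a common interface. A minimal sketch of how the two local backends are meant to be instantiated (model paths are illustrative, not from this commit; the HF-endpoint class exists but is not selectable from main.py yet):

```python
# Hypothetical usage; the path is a placeholder.
from src.salamandraTA7b_translator import SalamandraTA7bTranslator, SalamandraTA7bQTranslator

translator = SalamandraTA7bTranslator("models/salamandraTA-7b-instruct")     # full-precision transformers backend
# translator = SalamandraTA7bQTranslator("models/salamandraTA-7b-instruct")  # quantized GGUF backend via vLLM
```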
main.py ADDED
@@ -0,0 +1,42 @@
+ from src.aligner import Aligner
+ from src.salamandraTA7b_translator import SalamandraTA7bTranslator, SalamandraTA7bQTranslator
+ from src.salamandraTA7b_translator_HF import SalamandraTA7bTranslatorHF
+ import os
+ import time
+ import argparse
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         prog='main',
+         description='Translate a file')
+     parser.add_argument("-s", '--source_lang', type=str, required=True)
+     parser.add_argument("-t", '--target_lang', type=str, required=True)
+     parser.add_argument("-f", '--file_path', type=str, required=True)
+
+     parser.add_argument("-m", '--model_path', type=str, required=True)
+     parser.add_argument("-tt", '--translator_type', type=str,
+                         choices=["normal", "quantized"], default="none",
+                         help="normal=regular model; quantized=quantized model")
+
+     parser.add_argument('--fastalign_config_folder', type=str, default="fast_align_config")
+     parser.add_argument('--temp_folder', type=str, default="tmp")
+
+     args = parser.parse_args()
+
+     os.makedirs(args.temp_folder, exist_ok=True)
+
+     if args.translator_type == "normal":
+         translator = SalamandraTA7bTranslator(args.model_path)
+     elif args.translator_type == "quantized":
+         translator = SalamandraTA7bQTranslator(args.model_path)
+     else:
+         raise NotImplementedError(f"Option {args.translator_type} is not implemented.")
+
+     aligner = Aligner(args.fastalign_config_folder, args.source_lang, args.target_lang, args.temp_folder)
+
+     start_time = time.time()
+     # the local translators return the output path directly; only the HF endpoint
+     # translator yields (status, path) tuples, and main.py does not support it yet
+     translated_file_name = translator.translate_document(args.file_path, args.source_lang,
+                                                          args.target_lang)
+     print(f"Finished document in {time.time() - start_time} seconds")
requirements.txt CHANGED
@@ -10,4 +10,5 @@ transformers~=4.57.1
  torch~=2.8.0
  huggingface-hub~=0.36.0
  vllm~=0.11.0
- iso-639~=0.4.5
+ iso-639~=0.4.5
+ accelerate~=1.11.0
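The accelerate pin is presumably what backs the `device_map="auto"` call added in SalamandraTA7bTranslator below, since transformers delegates device placement to accelerate. A quick sanity check, under that assumption:

```python
# Hypothetical check that the dependency behind device_map="auto" is importable.
import accelerate
print(accelerate.__version__)  # expected to satisfy ~=1.11.0 per requirements.txt
```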
src/salamandraTA7b_translator.py CHANGED
@@ -1,24 +1,158 @@
- from gradio_client import Client
  from iso639 import languages
+ from datetime import datetime
+ from tqdm import tqdm
+ from abc import ABC, abstractmethod
+ import os
+ import shutil
+ from src.utils import file_to_moses, moses_to_file


- class SalamandraTA7bTranslator:
-     def __init__(self, hf_token):
-         self.client = Client("BSC-LT/SalamandraTA-7B-Demo", hf_token=hf_token)
-
-     def translate(self, text, source_lang, target_lang):
-         if not text:
-             return ""
-
-         # we assume the language is specified by its code, so we need to convert it to a name
+ def generate_batches(lines, size_batches):
+     return (lines[i:i + size_batches] for i in range(0, len(lines), size_batches))
+
+
+ def lines_to_moses(lines, out_file_path):
+     with open(out_file_path, "w") as out_file:
+         out_file.writelines(line + "\n" for line in lines)
+
+
+ class SalamandraTA7bTranslatorAbstract(ABC):
+     @abstractmethod
+     def __init__(self, model_path):
+         pass
+
+     @abstractmethod
+     def translate(self, lines, source_lang, target_lang):
+         pass
+
+     def translate_document(self, input_file, source_lang, target_lang,
+                            temp_folder: str = "tmp", tikal_folder: str = "okapi-apps_gtk2-linux-x86_64_1.47.0"):
+         input_filename = input_file.split("/")[-1]
+         os.makedirs(temp_folder, exist_ok=True)
+
+         # copy the original file to the temp folder to avoid common issues with tikal
+         temp_input_file = os.path.join(temp_folder, input_filename)
+         shutil.copy(input_file, temp_input_file)
+
+         original_xliff_file = os.path.join(temp_folder, input_filename + ".xlf")
+         plain_text_file = file_to_moses(temp_input_file, source_lang, target_lang, tikal_folder,
+                                         original_xliff_file)
+
+         lines = open(plain_text_file, "r", encoding="utf-8").read().splitlines()
+
+         translated_lines = self.translate(lines, source_lang, target_lang)
+
+         # create a moses file with the translated lines
+         translated_moses_file = os.path.join(original_xliff_file + f".{target_lang}")
+         lines_to_moses(translated_lines, translated_moses_file)
+
+         # recreate the document with the translations
+         translated_file_path = moses_to_file(translated_moses_file, source_lang, target_lang, tikal_folder,
+                                              original_xliff_file)
+
+         print(f"Saved file in {translated_file_path}")
+         return translated_file_path
+
+
+ class SalamandraTA7bTranslator(SalamandraTA7bTranslatorAbstract):
+
+     def __init__(self, model_path):
+         from transformers import AutoTokenizer, AutoModelForCausalLM
+         import torch
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             device_map="auto",
+             dtype=torch.bfloat16
+         ).eval()
+
+     def translate(self, lines, source_lang, target_lang):
          lang1 = languages.get(alpha2=source_lang).name
          lang2 = languages.get(alpha2=target_lang).name
-         result = self.client.predict(
-             task="Translation",
-             source=lang1,
-             target=lang2,
-             input_text=text,
-             mt_text=None,
-             api_name="/generate_output"
-         )
-         return result[0]
+
+         prompt_template = lambda x: f"Translate the following text from {lang1} into {lang2}.\n{lang1}: {x} \n{lang2}:"
+
+         date_string = datetime.today().strftime('%Y-%m-%d')
+
+         # create prompts for each sentence and record the length of each prompt (before generation)
+
+         total_translated = []
+
+         batches = generate_batches(lines, 100)
+
+         with tqdm(total=len(lines), desc='Translating...') as pbar:
+             for batch in batches:
+                 prompts = []
+                 input_lengths = []
+                 for sentence in batch:
+                     text = prompt_template(sentence)
+                     message = [{"role": "user", "content": text}]
+                     prompt = self.tokenizer.apply_chat_template(
+                         message,
+                         tokenize=False,
+                         add_generation_prompt=True,
+                         date_string=date_string
+                     )
+                     prompts.append(prompt)
+                     # record the prompt length so we can later slice the prompt tokens off the generated output
+                     input_length = len(self.tokenizer.encode(prompt, add_special_tokens=False))
+                     input_lengths.append(input_length)
+
+                 # batch-encode the prompts with padding
+                 inputs = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt", padding=True)
+                 input_ids = inputs["input_ids"].to(self.model.device)
+                 attention_mask = inputs["attention_mask"].to(self.model.device)
+
+                 # generate translations in batch
+                 outputs = self.model.generate(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     max_new_tokens=100,
+                     early_stopping=True,
+                     num_beams=1
+                 )
+
+                 # decode each translation, slicing off the input prompt
+                 for i, output in enumerate(outputs):
+                     translation = self.tokenizer.decode(output[input_lengths[i]:], skip_special_tokens=True)
+                     total_translated.append(translation)
+                 pbar.update(len(batch))
+
+         return total_translated
+
+
+ class SalamandraTA7bQTranslator(SalamandraTA7bTranslatorAbstract):
+     def __init__(self, model_path):
+         from huggingface_hub import snapshot_download
+         from vllm import LLM
+         model_dir = snapshot_download(repo_id="BSC-LT/salamandraTA-7B-instruct-GGUF", revision="main")
+         model_name = "salamandrata_7b_inst_q4.gguf"
+
+         self.llm = LLM(model=model_dir + '/' + model_name, tokenizer=model_dir)
+
+     def translate(self, lines, source_lang, target_lang):
+         from vllm import SamplingParams
+         lang1 = languages.get(alpha2=source_lang).name
+         lang2 = languages.get(alpha2=target_lang).name
+
+         batches = generate_batches(lines, 100)
+
+         total_translated = []
+
+         prompt_template = lambda x: f"Translate the following text from {lang1} into {lang2}.\n{lang1}: {x} \n{lang2}:"
+
+         with tqdm(total=len(lines), desc='Translating...') as pbar:
+             for batch in batches:
+                 messages = [[{"role": "user", "content": prompt_template(item)}] for item in batch]
+
+                 outputs = self.llm.chat(messages,
+                                         sampling_params=SamplingParams(
+                                             temperature=0.1,
+                                             stop_token_ids=[5],
+                                             max_tokens=200)
+                                         )
+                 translations = [item.outputs[0].text for item in outputs]
+                 pbar.update(len(translations))
+                 total_translated += translations
+
+         return total_translated
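Taken together, translate_document drives the whole pipeline (document → moses plain text → batched translation → rebuilt document), while translate only maps a list of lines to translated lines in batches of 100. A minimal sketch of using the class directly, assuming local weights are available (the path and inputs are illustrative):

```python
from src.salamandraTA7b_translator import SalamandraTA7bTranslator

translator = SalamandraTA7bTranslator("models/salamandraTA-7b-instruct")  # hypothetical local path
print(translator.translate(["Hello, world!"], "en", "ca"))                # list of lines in, list of lines out
translator.translate_document("docs/report.docx", "en", "ca")             # full document round-trip via tikal
```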
src/salamandraTA7b_translator_HF.py CHANGED
@@ -8,6 +8,7 @@ from subprocess import Popen, PIPE
  import re

  from src.aligner import Aligner
+ from src.utils import file_to_moses, moses_to_file

  import glob
  import spacy
@@ -60,8 +61,8 @@ class SalamandraTA7bTranslatorHF:
          shutil.copy(input_file, temp_input_file)

          original_xliff_file = os.path.join(temp_folder, input_filename + ".xlf")
-         plain_text_file = doc_to_plain_text(temp_input_file, source_lang, target_lang, tikal_folder,
-                                             original_xliff_file)
+         plain_text_file = file_to_moses(temp_input_file, source_lang, target_lang, tikal_folder,
+                                         original_xliff_file)

          # get paragraphs with runs
          paragraphs_with_runs = [get_runs_from_paragraph(line.strip(), idx) for idx, line in
@@ -137,28 +138,8 @@ class SalamandraTA7bTranslatorHF:
          translated_moses_file = os.path.join(original_xliff_file + f".{target_lang}")
          runs_to_plain_text(translated_paragraphs_with_style, translated_moses_file)

-         # put the translations into the xlf
-         tikal_moses_to_xliff_command = [os.path.join(tikal_folder, "tikal.sh"), "-lm", original_xliff_file, "-sl",
-                                         source_lang, "-tl", target_lang, "-from", translated_moses_file, "-totrg",
-                                         "-noalttrans", "-to", original_xliff_file]
-         Popen(tikal_moses_to_xliff_command).wait()
-
-         # any tags that are still <g> have not been paired between the original and translated texts by tikal, so
-         # we remove them. This may happen if a word in the original language has been split into more than one word
-         # with other words in between, or because of an error in fastalign
-         text = open(original_xliff_file).read()
-         result = re.sub(r'<g id="\d+">(.*?)</g>', r'\1', text)
-         open(original_xliff_file, "w").write(result)
-
-         # merge into a docx again
-         tikal_merge_doc_command = [os.path.join(tikal_folder, "tikal.sh"), "-m", original_xliff_file]
-         final_process = Popen(tikal_merge_doc_command, stdout=PIPE, stderr=PIPE)
-         stdout, stderr = final_process.communicate()
-         final_process.wait()
-
-         # get the path to the output file
-         output = stdout.decode('utf-8')
-         translated_file_path = re.search(r'(?<=Output:\s)(.*)', output)[0]
+         translated_file_path = moses_to_file(translated_moses_file, source_lang, target_lang, tikal_folder,
+                                              original_xliff_file)

          print(f"Saved file in {translated_file_path}")
          yield "", translated_file_path
@@ -182,34 +163,6 @@ def get_leading_invisible(text):
      return text[:i]


- def doc_to_plain_text(input_file: str, source_lang: str, target_lang: str, tikal_folder: str,
-                       original_xliff_file_path: str) -> str:
-     """
-     Given a document, this function generates an xliff file and then a plain text file with the text contents
-     while keeping style and formatting using tags like <g id=1> </g>
-
-     Parameters:
-         input_file: Path to the document to process
-         source_lang: Source language of the document
-         target_lang: Target language of the document
-         tikal_folder: Folder where tikal.sh is located
-         original_xliff_file_path: Path to the xliff file to generate, which will be used later
-
-     Returns:
-         string: Path to the plain text file
-     """
-
-     tikal_xliff_command = [os.path.join(tikal_folder, "tikal.sh"), "-x", input_file, "-nocopy", "-sl", source_lang,
-                            "-tl", target_lang]
-     Popen(tikal_xliff_command).wait()
-
-     tikal_moses_command = [os.path.join(tikal_folder, "tikal.sh"), "-xm", original_xliff_file_path, "-sl", source_lang,
-                            "-tl", target_lang]
-     Popen(tikal_moses_command).wait()
-
-     return os.path.join(original_xliff_file_path + f".{source_lang}")
-
-
  def get_runs_from_paragraph(paragraph: str, paragraph_index: int) -> list[dict[str, str | tuple[str, ...]]]:
      """
      Given some text that may or may not contain some chunks tagged with something like <g id=1> </g>, extract each
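src/utils.py itself is not shown in this commit view, but the deleted doc_to_plain_text above and the inlined tikal block removed from the HF translator pin down what file_to_moses and moses_to_file must do. A sketch consistent with that removed code (signatures inferred from the call sites; not the actual file):

```python
# Hypothetical reconstruction of src/utils.py, based on the code removed in this commit.
import os
import re
from subprocess import Popen, PIPE


def file_to_moses(input_file, source_lang, target_lang, tikal_folder, original_xliff_file):
    # document -> xliff -> moses plain text, same tikal calls as the removed doc_to_plain_text
    Popen([os.path.join(tikal_folder, "tikal.sh"), "-x", input_file, "-nocopy",
           "-sl", source_lang, "-tl", target_lang]).wait()
    Popen([os.path.join(tikal_folder, "tikal.sh"), "-xm", original_xliff_file,
           "-sl", source_lang, "-tl", target_lang]).wait()
    return original_xliff_file + f".{source_lang}"


def moses_to_file(translated_moses_file, source_lang, target_lang, tikal_folder, original_xliff_file):
    # translations -> xliff -> merged document, same tikal calls as the block removed above
    Popen([os.path.join(tikal_folder, "tikal.sh"), "-lm", original_xliff_file, "-sl", source_lang,
           "-tl", target_lang, "-from", translated_moses_file, "-totrg", "-noalttrans",
           "-to", original_xliff_file]).wait()

    # drop <g> tags that tikal could not pair between source and target
    text = open(original_xliff_file).read()
    open(original_xliff_file, "w").write(re.sub(r'<g id="\d+">(.*?)</g>', r'\1', text))

    # merge back into the original format and parse the output path from tikal's stdout
    stdout, _ = Popen([os.path.join(tikal_folder, "tikal.sh"), "-m", original_xliff_file],
                      stdout=PIPE, stderr=PIPE).communicate()
    return re.search(r'(?<=Output:\s)(.*)', stdout.decode('utf-8'))[0]
```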