added salamandraTA translation, update requirements

- main.py (+48 -9)
- requirements.txt (+6 -1)
main.py CHANGED

@@ -17,6 +17,46 @@ from subprocess import Popen, PIPE
 from itertools import groupby
 import fileinput
 
+from datetime import datetime
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+from iso639 import languages
+import tqdm
+
+
+class Translator():
+    def __init__(self, model_path, source_lang, target_lang):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype=torch.bfloat16
+        )
+
+        self.prompt_f = lambda x: (f"Translate the following text from {source_lang} into "
+                                   f"{target_lang}.\n{source_lang}: {x} \n{target_lang}:")
+
+    def translate(self, text):
+        message = [{"role": "user", "content": self.prompt_f(text)}]
+        date_string = datetime.today().strftime('%Y-%m-%d')
+
+        prompt = self.tokenizer.apply_chat_template(
+            message,
+            tokenize=False,
+            add_generation_prompt=True,
+            date_string=date_string
+        )
+
+        inputs = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
+        input_length = inputs.shape[1]
+        outputs = self.model.generate(input_ids=inputs.to(self.model.device),
+                                      max_new_tokens=400,
+                                      early_stopping=True,
+                                      num_beams=5)
+
+        return self.tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
+
 
 # Class to align original and translated sentences
 # based on https://github.com/mtuoc/MTUOC-server/blob/main/GetWordAlignments_fast_align.py
@@ -235,12 +275,6 @@ def generate_alignments(original_paragraphs_with_runs, translated_paragraphs, al
     return translated_sentences_with_style
 
 
-# TODO
-def translate_paragraph(paragraph_text):
-    translated_paragraph = ""
-    return translated_paragraphs
-
-
 # group contiguous elements with the same boolean values
 def group_by_style(values, detokenizer):
     groups = []
@@ -316,12 +350,17 @@ if __name__ == "__main__":
 
     detokenizer = TreebankWordDetokenizer()
 
+    translator = Translator("BSC-LT/salamandraTA-7b-instruct", languages.get(alpha2=source_lang).name,
+                            languages.get(alpha2=target_lang).name)
+
     # translate each paragraph
     translated_paragraphs = []
-    for paragraph in paragraphs_with_runs:
+    for paragraph in tqdm.tqdm(paragraphs_with_runs, desc="Translating paragraphs..."):
         paragraph_text = detokenizer.detokenize([run["text"] for run in paragraph])
-        translated_paragraphs.append(
-
+        translated_paragraphs.append(translator.translate(paragraph_text))
+
+    print(translated_paragraphs)
+
     out_doc = Document()
 
     processed_original_paragraphs_with_runs = [preprocess_runs(runs) for runs in paragraphs_with_runs]
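For reference, a minimal sketch of how the new Translator class ends up being used in __main__. The language pair and the sample sentence are illustrative only; the script derives source_lang and target_lang from its own inputs, and the class is assumed to be in scope as defined above.

from iso639 import languages

# hypothetical example values; main.py resolves these from its own configuration
source_lang, target_lang = "en", "ca"

# same construction as in __main__: model id plus the full language names
translator = Translator("BSC-LT/salamandraTA-7b-instruct",
                        languages.get(alpha2=source_lang).name,
                        languages.get(alpha2=target_lang).name)

# each detokenized paragraph is translated independently, as in the main loop
print(translator.translate("The quick brown fox jumps over the lazy dog."))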
requirements.txt CHANGED

@@ -1,2 +1,7 @@
 nltk~=3.9.1
-python-docx~=1.1.2
+python-docx~=1.1.2
+torch~=2.6.0
+transformers~=4.51.2
+iso-639~=0.4.5
+protobuf~=6.30.2
+sentencepiece~=0.2.0
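The added packages cover the model runtime (torch, transformers, and presumably sentencepiece/protobuf for the model's tokenizer) plus iso-639, which turns the two-letter language codes into the full names used in the prompt. A quick, illustrative check of the lookup main.py relies on (the codes below are examples):

from iso639 import languages

# main.py builds the prompt from the full language names, e.g. "English"
print(languages.get(alpha2="en").name)
print(languages.get(alpha2="ca").name)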