KuangDW committed 946f7f8 (parent: dd05f29): specify local llm

Files changed:
- app.py (+30 -17)
- vecalign/plan2align.py (+31 -134)
app.py
CHANGED

@@ -9,14 +9,14 @@ from openai import OpenAI
 from vecalign.plan2align import translate_text, external_find_best_translation
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
-from huggingface_hub import login
+from huggingface_hub import login, HfApi, snapshot_download
 import spacy
 import subprocess
 import pkg_resources
 import sys

 laser_token = os.environ.get("align_enc")
-laser_path = snapshot_download(repo_id="KuangDW/laser", use_auth_token=
+laser_path = snapshot_download(repo_id="KuangDW/laser", use_auth_token=laser_token)
 os.environ["LASER"] = laser_path

 def check_and_install(package, required_version):
@@ -54,21 +54,35 @@ except OSError:
     download("zh_core_web_sm")
 subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])

-# ----------
-
-
-
+# ---------- translation function ----------
+
+# Initialize device
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+# Load models once
+print("Loading models...")
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.float16
 )

 def generate_translation(system_prompt, prompt):
-
-
-
-
-
-
-
-
+    messages=[
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": prompt}
+    ]
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+    translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
     return translation

 def check_token_length(text, max_tokens=1024):
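Note: for reference, here is the same local-generation pattern as a self-contained script. The checkpoint name and sampling settings are taken from the hunk above; add_generation_prompt=True is an assumption added here, since Llama-style chat templates generally need it so the model opens a fresh assistant turn (the committed code omits it).

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # gated; requires access approval
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.float16
    )

    def generate_translation(system_prompt, prompt):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        # Appends the assistant header so generation starts a new reply
        # (assumption: the committed code calls apply_chat_template without this flag).
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        outputs = model.generate(
            inputs, max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        return tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)

    print(generate_translation(
        "You are a helpful translator and only output the result.",
        "Translate into English: 今日はいい天気ですね。",
    ))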
@@ -188,7 +202,7 @@ def mpc_translation(text, src_language, target_language, iterations, session_id):
             best_score = score
     return current_trans, best_score

-# ---------- Gradio
+# ---------- Gradio function ----------

 def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                  good_ref_contexts_num_value, translation_methods, state):
@@ -202,7 +216,6 @@ def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
     4. MPC 翻譯
     """

-    # 初始化各輸出內容
     orig_output = ""
     plan2align_output = ""
     best_of_n_output = ""
@@ -214,7 +227,7 @@ def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
         orig_output = f"{orig}\n\nScore: {best_score:.2f}"
     if "Plan2Align" in translation_methods:
         plan2align_trans, best_score = plan2align_translate_text(
-            text, session_id, src_language, target_language,
+            text, session_id, model, tokenizer, device, src_language, target_language,
            max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
         )
         plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
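Both changed call sites in app.py now thread the globally loaded model, tokenizer, and device into Plan2Align rather than letting the library construct its own API client. plan2align_translate_text itself is not part of this diff, so the bridging wrapper below is a hypothetical sketch whose argument order simply mirrors the changed call site; only the keywords visible in translate_text's (truncated) signature are forwarded.

    from vecalign.plan2align import translate_text

    # Hypothetical wrapper (not in this commit): bridges the Gradio handler
    # to the refactored translate_text entry point.
    def plan2align_translate_text(text, session_id, model, tokenizer, device,
                                  src_language, target_language,
                                  max_iterations, threshold,
                                  good_ref_contexts_num, reward_model):
        return translate_text(
            text, session_id, model, tokenizer, device,
            src_language=src_language,
            task_language=target_language,
            max_iterations_value=max_iterations,
            # threshold / good_ref_contexts_num / reward_model keywords are
            # truncated in the diff, so they are not guessed at here.
        )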
vecalign/plan2align.py
CHANGED

@@ -28,12 +28,6 @@ lang_map = {
     "Chinese": ("zh", "zh_core_web_sm")
 }

-openai = OpenAI(
-    api_key="",
-    base_url="https://api.deepinfra.com/v1/openai",
-)
-MODEL_NAME= "google/gemma-2-9b-it" # "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
 ################################# folder / file processing #################################

 def clear_folder(folder_path):
@@ -180,7 +174,7 @@ def external_find_best_translation(evals, language, session_id):

 ################################# generating translation #################################

-def translate_with_deepinfra(source_sentence, buffer, good_sent_size, src_language, tgt_language):
+def translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_sent_size, src_language, tgt_language):
     system_prompts = [
         "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
         "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
@@ -227,14 +221,19 @@ def translate_with_deepinfra(source_sentence, buffer, good_sent_size, src_language, tgt_language):

     translations = []
     for prompt in system_prompts:
-
-
-
-
-
-
+        messages=[
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": context_prompt}
+        ]
+        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+        outputs = model.generate(
+            inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
         )
-        translation =
+        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)

         print("--------------------------------------------------------------------------------")
         print("\n rollout translation: \n")
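In isolation, the rollout draws one sampled candidate per persona-style system prompt from the shared model. A minimal sketch of that pattern (generate_candidates is an illustrative name; in the real code, context_prompt is built from the source sentence plus buffered reference translations):

    def generate_candidates(model, tokenizer, device, system_prompts, context_prompt):
        """Sample one candidate translation per system prompt."""
        candidates = []
        for system_prompt in system_prompts:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": context_prompt},
            ]
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            ).to(device)
            # Sampling with temperature/top_p yields diverse rollouts per persona.
            outputs = model.generate(
                inputs, max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True
            )
            candidates.append(
                tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
            )
        return candidates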
@@ -264,7 +263,7 @@ def process_buffer_sentences(source_sentences, buffer):
             translations.append(translation_map[src_sent][0])
     return translations

-def final_translate_with_deepinfra(source_sentence, source_segments, buffer, src_language, tgt_language):
+def final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, src_language, tgt_language):
     translations = process_buffer_sentences(source_segments, buffer)
     initial_translation = "\n".join(translations)

@@ -286,21 +285,23 @@ def final_translate_with_deepinfra(source_sentence, source_segments, buffer, src_language, tgt_language):

     print("rewrite prompt:")
     print(rewrite_prompt)
-
-
-
-
-
-
-
-
-
+    messages=[
+        {"role": "system", "content": "You are a helpful translator and only output the result."},
+        {"role": "user", "content": rewrite_prompt}
+    ]
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+    translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
     return translation


 ################################# alignment functions #################################
-
-
 def save_sentences_to_txt(sentences, filename):
     i = 0
     with open(filename, "w", encoding="utf-8") as file:
@@ -558,111 +559,13 @@ def generate_windows(source, translations):

 ################################# main function #################################

-def saving_memory(buffer, index, iteration, final_translations_record):
-    """
-    Save the buffer, and final_translations_record to the Memory folder.
-    """
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    memory_folder = os.path.join(current_dir, f"{MEMORY_FOLDER}")
-    os.makedirs(memory_folder, exist_ok=True)
-    buffer_file_path = f"{MEMORY_FOLDER}/buffer_{index}_iter_{iteration}.json"
-    metadata_file_path = f"{MEMORY_FOLDER}/metadata_{index}_iter_{iteration}.json"
-
-    buffer_to_save = {key: list(value) for key, value in buffer.items()}
-    with open(buffer_file_path, "w", encoding="utf-8") as f:
-        json.dump(buffer_to_save, f, ensure_ascii=False, indent=4)
-
-    metadata = {
-        "final_translations_record": final_translations_record
-    }
-    with open(metadata_file_path, "w", encoding="utf-8") as f:
-        json.dump(metadata, f, ensure_ascii=False, indent=4)
-
-    print(f"Buffer saved to {buffer_file_path}")
-    print(f"Metadata saved to {metadata_file_path}")
-
-
-def process_chunk():
-
-    data = pd.read_csv(csv_path)
-    for index, row in data.iterrows():
-        print("::::::::::::::::::::::: index :::::::::::::::::::::::", index, " ::::::::::::::::::::::: index :::::::::::::::::::::::", )
-        buffer = defaultdict(list)
-
-        source_sentence = row[src_lang].replace('\n', ' ')
-        source_segments = segment_sentences_by_punctuation(source_sentence, lang=src_lang)
-
-        for iteration in range(max_iterations):
-            print(f"\nStarting iteration {iteration + 1}/{max_iterations}...\n")
-
-            if iteration in stop_memory:
-                final_translations = final_translate_with_deepinfra(source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
-                print("Final Translation Method:")
-                print(final_translations)
-                final_translations_record = [final_translations]
-                saving_memory(buffer, index, iteration, final_translations_record)
-
-            if iteration == max_iterations - 1:
-                break
-            else:
-                translations = translate_with_deepinfra(source_sentence, buffer, good_ref_contexts_num+iteration, SRC_LANGUAGE, TASK_LANGUAGE)
-
-            src_windows, mt_windows_list = generate_windows(source_sentence, translations)
-
-            ####################################### Evaluate translations and update buffer #######################################
-            print("Evaluate translations and update buffer ..............")
-
-            # First, store all sources and candidate translations as lists.
-            src_context_list = list(src_windows)
-            candidates_list = []
-            for window_index in range(len(src_windows)):
-                candidates = [mt_windows[window_index] for mt_windows in mt_windows_list]
-                candidates_list.append(candidates)
-
-            # Batch evaluate all candidate translations, returning the best translation and score for each source.
-            best_candidate_results = batch_rm_find_best_translation(list(zip(src_context_list, candidates_list)), TASK_LANGUAGE)
-
-            print("\n Our best candidate results:")
-            print(best_candidate_results)
-            print(" ------------------------------------------------------------------------ \n")
-
-            print("\n===== Initial buffer state =====")
-            for src, translations in buffer.items():
-                print(f"Source '{src}': {[t[0] for t in translations]}")
-
-            # Update the buffer for each source.
-            for i, src in enumerate(src_context_list):
-                best_tuple = best_candidate_results[i]  # (translation, score)
-                if best_tuple[0] is not None:
-                    # If the source is not yet in the buffer, initialize it.
-                    if src not in buffer:
-                        buffer[src] = [best_tuple]
-                        print(f"[ADD] New Source '{src}' Add Translation: '{best_tuple[0]}', Score: {best_tuple[1]}")
-                    else:
-                        # Directly add the new translation to the buffer.
-                        buffer[src].append(best_tuple)
-                        print(f"[ADD] Source '{src}' Add Translation: '{best_tuple[0]}', Score: {best_tuple[1]}")
-
-                    # Sort by score to place the best translation (highest score) at the top.
-                    buffer[src].sort(key=lambda x: x[1], reverse=True)
-                    print(f"[UPDATE] Source '{src}' Best Translation: '{buffer[src][0][0]}'")
-
-            print("\n===== Final buffer state =====")
-            for src, translations in buffer.items():
-                print(f"Source '{src}': {[t[0] for t in translations]}")
-
-
-        print("Final Translation:")
-        print(final_translations)
-
-
 def get_lang_and_nlp(language):
     if language not in lang_map:
         raise ValueError(f"Unsupported language: {language}")
     lang_code, model_name = lang_map[language]
     return lang_code, spacy.load(model_name)

-def translate_text(text, session_id,
+def translate_text(text, session_id, model, tokenizer, device,
                    src_language="Japanese",
                    task_language="English",
                    max_iterations_value=3,
@@ -699,14 +602,12 @@ def translate_text(text, session_id,
     final_translations = None

     for iteration in range(max_iterations):
-        # print(f"\nStarting iteration {iteration + 1}/{max_iterations}...\n")
         if iteration in stop_memory:
-            final_translations = final_translate_with_deepinfra(source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
-            # saving_memory(buffer, 0, iteration, [final_translations])
+            final_translations = final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
         if iteration == max_iterations - 1:
             break
         else:
-            translations = translate_with_deepinfra(source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE)
+            translations = translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE)

         src_windows, mt_windows_list = generate_windows(source_sentence, translations)
         # print("Evaluate translations and update buffer ..............")
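The iterate-then-finalize control flow is easier to see with the translation machinery stripped away. A minimal sketch with stand-in callables (rollout and finalize are illustrative names for translate_with_deepinfra and final_translate_with_deepinfra):

    def plan2align_loop(max_iterations, stop_memory, rollout, finalize):
        """Finalize at checkpoint iterations, otherwise keep sampling rollouts."""
        final = None
        for iteration in range(max_iterations):
            if iteration in stop_memory:
                final = finalize()               # rewrite using the current buffer
            if iteration == max_iterations - 1:
                break                            # final pass: no further rollouts
            else:
                candidates = rollout(iteration)  # sample new drafts, widening context
            # ... evaluate candidates and update the buffer here ...
        return final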
@@ -741,8 +642,4 @@ def translate_text(text, session_id,

     # print("Final Translation:")
     # print(final_translations)
-    return final_translations
-
-
-if __name__ == "__main__":
-    process_chunk()
+    return final_translations
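With process_chunk and the __main__ guard removed, plan2align.py is import-only and the caller owns the LLM. A hedged sketch of driving the refactored entry point directly (keyword defaults past max_iterations_value are left alone because the diff truncates the rest of the signature; the Llama checkpoint is gated and requires access approval):

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from vecalign.plan2align import translate_text

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.float16
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    result = translate_text(
        "お世話になっております。先日の件、ご確認いただけましたでしょうか。",
        "demo-session",            # session_id
        model, tokenizer, device,  # the caller now supplies the local LLM
        src_language="Japanese",
        task_language="English",
        max_iterations_value=3,
    )
    print(result)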