# NOTE: the following banner is scraped Hugging Face Space UI text
# ("Spaces: Running on Zero"), kept here as a comment so the file parses.
# Standard-library and third-party imports. NOTE: `import spaces` is kept
# immediately after `torch` as in the original — on ZeroGPU Spaces the
# import order of these two modules can matter.
import torch
import spaces
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from goai_helpers.utils import MooreConverter, mark_numbers, unmark_numbers
from huggingface_hub import login
import os

# Maximum sequence length used for tokenization/generation below.
max_length = 512

# Authenticate against the Hugging Face Hub; the token is injected into the
# Space's environment as HF_SPACE_TOKEN (may be None in local dev).
auth_token = os.getenv('HF_SPACE_TOKEN')
login(token=auth_token)
def split_text_intelligently(text, max_chunk_length=80):
    """Split `text` into chunks of roughly `max_chunk_length` characters,
    breaking only at sentence boundaries (., !, ?, :).

    A single sentence longer than `max_chunk_length` is kept whole, so
    chunks may exceed the limit. Returned chunks are stripped of
    surrounding whitespace.
    """
    # re.split with a capturing group alternates text and delimiter:
    # [seg0, delim0, seg1, delim1, ..., tail]
    parts = re.split(r'([.!?:])', text)

    # Re-attach each delimiter to the segment that precedes it.
    sentences = []
    for idx in range(0, len(parts), 2):
        punct = parts[idx + 1] if idx + 1 < len(parts) else ""
        sentences.append(parts[idx] + punct)

    chunks = []
    buffer = ""
    for sentence in sentences:
        # Start a new chunk when adding this sentence would overflow a
        # non-empty buffer; otherwise keep accumulating.
        if buffer and len(buffer) + len(sentence) > max_chunk_length:
            chunks.append(buffer.strip())
            buffer = sentence
        else:
            buffer += sentence
    if buffer:
        chunks.append(buffer.strip())
    return chunks
def goai_traduction(text, src_lang, tgt_lang, max_chunk_length=80):
    """Translate `text` from `src_lang` to `tgt_lang`.

    Inputs longer than `max_chunk_length` characters are split into
    sentence-aligned chunks (see `split_text_intelligently`), translated
    one by one, and re-joined with single spaces; shorter inputs go
    through `translate_chunk` in a single pass.
    """
    # Short enough to translate directly.
    if len(text) <= max_chunk_length:
        return translate_chunk(text, src_lang, tgt_lang)

    pieces = split_text_intelligently(text, max_chunk_length)
    return " ".join(
        translate_chunk(piece, src_lang, tgt_lang) for piece in pieces
    )
# Cache of model_id -> (tokenizer, model). The NLLB-3.3B checkpoints are
# multi-GB; reloading them from disk for every chunk (as the original code
# did) dominates the runtime of any multi-chunk translation.
_MODEL_CACHE = {}


def _load_translation_model(model_id, device):
    """Return a cached (tokenizer, model) pair for `model_id`, with the
    model moved to `device`."""
    if model_id not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
        _MODEL_CACHE[model_id] = (tokenizer, model)
    tokenizer, model = _MODEL_CACHE[model_id]
    return tokenizer, model.to(device)


def translate_chunk(text, src_lang, tgt_lang):
    """Translate a single chunk of text with the direction-specific NLLB model.

    Numbers are protected with `mark_numbers`/`unmark_numbers` around the
    model call, then converted to Moore number words in the output when
    translating toward Moore.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the checkpoint matching the translation direction.
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/nllb-200-3.3B-mos2fr"
    else:
        model_id = "ArissBandoss/nllb-200-3.3B-fr2mos"

    # Shield digits from the model so they survive generation verbatim.
    text = mark_numbers(text)
    print(text)

    tokenizer, model = _load_translation_model(model_id, device)

    # Configure source language and tokenize.
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       max_length=512).to(device)

    # Force the decoder to start in the target language.
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_new_tokens=512,
        early_stopping=True
    )

    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    translation = unmark_numbers(translation)

    # Convert standalone digit runs to Moore number words. A single-pass
    # re.sub with a callback is used instead of the original per-number
    # str.replace loop: replace("1", ...) would also corrupt the "1"
    # embedded in "12".
    converter = MooreConverter()

    def _to_moore(match):
        word = converter.number_to_moore(int(match.group()), True)
        # Keep the original digits if conversion fails.
        return word if word else match.group()

    translation = re.sub(r'\b\d+\b', _to_moore, translation)
    return translation
def real_time_traduction(input_text, src_lang, tgt_lang):
    """UI-facing pass-through: delegate to `goai_traduction` with defaults."""
    return goai_traduction(input_text, src_lang, tgt_lang)