import torch
import spaces
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from huggingface_hub import login
import os
max_length = 512
auth_token = os.getenv('HF_SPACE_TOKEN')
login(token=auth_token)

def get_tokenizer(src_lang, tgt_lang):
    """Initialize and return the appropriate tokenizer for the language pair."""
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
    return AutoTokenizer.from_pretrained(model_id, token=auth_token)

def split_text_by_tokens(text, src_lang, tgt_lang, max_tokens_per_chunk=200):
    """
    Split the text into chunks, respecting sentence boundaries and counting tokens.
    """
    tokenizer = get_tokenizer(src_lang, tgt_lang)
    tokenizer.src_lang = src_lang

    # Sentence-level split; the capturing group keeps each punctuation mark
    sentences = re.split(r'([.!?])', text)

    chunks = []
    current_chunk = ""
    current_tokens = 0

    for i in range(0, len(sentences), 2):
        # Re-attach each sentence to its trailing punctuation
        if i + 1 < len(sentences):
            sentence = sentences[i] + sentences[i + 1]
        else:
            sentence = sentences[i]

        # Count the tokens in this sentence
        sentence_tokens = len(tokenizer.encode(sentence))

        # If adding this sentence would exceed the token limit, start a new chunk
        if current_tokens + sentence_tokens > max_tokens_per_chunk and current_chunk:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
            current_tokens = sentence_tokens
        else:
            current_chunk += sentence
            current_tokens += sentence_tokens

    # Append the last chunk if any text remains
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
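
# Illustration (hypothetical input, not from the original file): with a small
# limit such as max_tokens_per_chunk=8, a text like "Un. Deux! Trois? Quatre."
# would come back as sentence-aligned chunks, e.g. roughly
# ["Un. Deux!", "Trois? Quatre."], with exact boundaries depending on how many
# tokens the NLLB tokenizer assigns to each sentence.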

@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang, max_tokens_per_chunk=200):
    # Check, by counting tokens, whether the text needs to be split
    tokenizer = get_tokenizer(src_lang, tgt_lang)
    tokenizer.src_lang = src_lang
    text_tokens = len(tokenizer.encode(text))

    if text_tokens > max_tokens_per_chunk:
        chunks = split_text_by_tokens(text, src_lang, tgt_lang, max_tokens_per_chunk)
        translations = []
        for chunk in chunks:
            translated_chunk = translate_chunk(chunk, src_lang, tgt_lang)
            translations.append(translated_chunk)
        return " ".join(translations)
    else:
        return translate_chunk(text, src_lang, tgt_lang)

def translate_chunk(text, src_lang, tgt_lang):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)

    # Tokenizer configuration: tell NLLB which source language the input is in
    tokenizer.src_lang = src_lang

    # Tokenization
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)

    # ID of the target-language token, forced as the first generated token
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # Generation parameters tuned to limit repetition
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_new_tokens=512,
        num_beams=5,
        no_repeat_ngram_size=3,
        repetition_penalty=1.5,
        length_penalty=1.0,
        early_stopping=True
    )

    # Decoding
    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return translation

def real_time_traduction(input_text, src_lang, tgt_lang):
    return goai_traduction(input_text, src_lang, tgt_lang, max_tokens_per_chunk=200)
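
# A minimal usage sketch (an assumption, not part of the original file, which
# presumably exposes these functions through the Space's UI): it requires the
# HF_SPACE_TOKEN environment variable to be set, and the sample sentence is
# purely illustrative.
if __name__ == "__main__":
    sample = "Bonjour, comment allez-vous ?"
    print(goai_traduction(sample, src_lang="fra_Latn", tgt_lang="mos_Latn"))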