File size: 3,232 Bytes
78d1101 49dc84f eda40b7 78d1101 4302c2b 78d1101 1c7cbff 78d1101 488bb16 78d1101 261a5aa 78d1101 24c3ce1 78d1101 7f3d8a9 fb3e214 dfb286c fb3e214 8cf4d3b fb3e214 7f3d8a9 261a5aa 7f3d8a9 d4b611c fb3e214 7f3d8a9 261a5aa fb3e214 7f3d8a9 fb3e214 28ec2c5 fb3e214 28ec2c5 fb3e214 261a5aa 20a6b7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig
import os
import unicodedata
from huggingface_hub import login
# Upper bound for sequence length; not referenced anywhere in this file's
# visible code — presumably kept for configuration elsewhere (TODO confirm).
max_length = 512
# Hugging Face access token, read from the Space's secret env var; used below
# to download the (possibly gated) model checkpoints.
auth_token = os.getenv('HF_SPACE_TOKEN')
login(token=auth_token)
@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate ``text`` between French and Mooré with a fine-tuned NLLB model.

    Args:
        text: Raw input text to translate.
        src_lang: Source language code in NLLB format (e.g. ``"fra_Latn"``).
        tgt_lang: Target language code in NLLB format (e.g. ``"mos_Latn"``).

    Returns:
        The translated text as a single string.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Choose the checkpoint by translation direction; the fr->mos model is
    # also the fallback for any other language pair.
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"

    print(f"Chargement du modèle: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)

    # NLLB tokenizers require the source language to be set before encoding.
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)

    # Force the decoder to begin with the target-language token so the model
    # generates in the requested language.
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # Single generation pass. The previous version looped over several
    # (length_penalty, num_beams) combinations for debugging, discarded all
    # but one result, and passed an undefined `bad_words_ids` variable which
    # raised NameError on every call — both issues are fixed here.
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_new_tokens=2048,
        num_beams=5,
        length_penalty=1.0,
        early_stopping=False,
        no_repeat_ngram_size=0,
    )

    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(f"Traduction ({len(translation)} caractères / {len(translation.split())} mots):")
    print(translation)
    return translation
def real_time_traduction(input_text, src_lang, tgt_lang):
    """Delegate a live-typing translation request straight to goai_traduction."""
    result = goai_traduction(input_text, src_lang, tgt_lang)
    return result