# Page residue captured together with the file (not part of the program):
# ArissBandoss's picture
# Update goai_helpers/goai_traduction.py
# 28ec2c5 verified
# raw
# history blame
# 3.23 kB
import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig
import os
import unicodedata
from huggingface_hub import login
# Module-level length setting. It is not referenced by the code visible in
# this file; kept because other modules may import it.
max_length = 512

# Hugging Face access token, read from the environment (None when unset).
auth_token = os.getenv('HF_SPACE_TOKEN')

# Authenticate only when a token is actually configured: with token=None,
# login() falls back to an interactive prompt, which cannot be answered in a
# non-interactive deployment.
if auth_token:
    login(token=auth_token)
@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` and log diagnostics.

    The tokenizer and model are loaded with ``from_pretrained`` on every call.
    Generation is then run once for each combination of ``length_penalty`` in
    (1.0, 1.5, 2.0) and ``num_beams`` in (5, 10); every intermediate result is
    printed, and the translation produced by the last combination
    (``length_penalty=2.0``, ``num_beams=10``) is returned.

    Args:
        text: Source text to translate.
        src_lang: Source language code assigned to ``tokenizer.src_lang``
            (for example ``"fra_Latn"``).
        tgt_lang: Target language code; its token id is forced as the first
            generated token (for example ``"mos_Latn"``).

    Returns:
        The decoded translation (special tokens stripped) from the last
        generation configuration tried.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the mos_Latn -> fra_Latn direction has a dedicated checkpoint;
    # fra_Latn -> mos_Latn and every other pair use the same fr-to-mos model.
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"

    print(f"Chargement du modèle: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)

    print(f"Texte brut ({len(text)} caractères / {len(text.split())} mots):")
    print(text)

    print(f"Configuration du modèle:")
    print(f"- tokenizer.model_max_length: {tokenizer.model_max_length}")
    print(f"- Position embeddings shape: {model.model.encoder.embed_positions.weights.shape}")
    print(f"- decoder.embed_positions shape: {model.model.decoder.embed_positions.weights.shape}")

    # Tokenizer configuration
    tokenizer.src_lang = src_lang

    # Tokenization (no truncation: the full input is passed to the model)
    inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
    input_ids = inputs["input_ids"][0]

    print("Tokens d'entrée:")
    print(f"- Nombre de tokens: {input_ids.shape[0]}")
    print(f"- Premiers tokens: {input_ids[:10].tolist()}")
    print(f"- Derniers tokens: {input_ids[-10:].tolist()}")

    # Token id of the target language
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    print(f"Token ID de la langue cible ({tgt_lang}): {tgt_lang_id}")

    for length_penalty in [1.0, 1.5, 2.0]:
        for num_beams in [5, 10]:
            print(f"\nTest avec length_penalty={length_penalty}, num_beams={num_beams}")

            # The former `bad_words_ids=bad_words_ids` argument referenced a
            # name that is never defined and raised NameError on every call;
            # it is omitted so generation runs with no banned token sequences.
            outputs = model.generate(
                **inputs,
                forced_bos_token_id=tgt_lang_id,
                max_new_tokens=2048,
                early_stopping=False,
                num_beams=num_beams,
                no_repeat_ngram_size=0,
                length_penalty=length_penalty
            )

            translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

            print(f"Traduction ({len(translation)} caractères / {len(translation.split())} mots):")
            print(translation)

            output_ids = outputs[0]
            print(f"- Nombre de tokens générés: {output_ids.shape[0]}")
            print(f"- Premiers tokens générés: {output_ids[:10].tolist()}")
            print(f"- Derniers tokens générés: {output_ids[-10:].tolist()}")

    # NOTE(review): the original indentation was lost; returning after the
    # sweep (i.e. the last configuration's output) is assumed — confirm.
    return translation
def real_time_traduction(input_text, src_lang, tgt_lang):
    """Translate ``input_text`` by delegating to ``goai_traduction``.

    Args:
        input_text: Text to translate.
        src_lang: Source language code, forwarded unchanged.
        tgt_lang: Target language code, forwarded unchanged.

    Returns:
        Whatever ``goai_traduction`` returns for the same arguments.
    """
    translated = goai_traduction(input_text, src_lang, tgt_lang)
    return translated