# Spaces: Running on Zero
import torch | |
import spaces | |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig | |
from peft import PeftModel, PeftConfig | |
import os | |
import unicodedata | |
from huggingface_hub import login | |
# Module-level settings.
# max_length is not referenced in the visible code; presumably used elsewhere
# in the app — TODO confirm before removing.
max_length = 512
# Hugging Face access token read from the HF_SPACE_TOKEN environment variable;
# None when the variable is not set.
auth_token = os.getenv('HF_SPACE_TOKEN')
# Only log in when a token is configured: login(token=None) falls back to an
# interactive prompt, which cannot be answered in a headless Space.
if auth_token:
    login(token=auth_token)
# Checkpoints used for each translation direction.
_FR_TO_MOS_MODEL_ID = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
_MOS_TO_FR_MODEL_ID = "ArissBandoss/3b-new-400"

# Single-slot cache: {(model_id, device_str): (tokenizer, model)}.
_MODEL_CACHE = {}


def _select_model_id(src_lang, tgt_lang):
    """Return the checkpoint id to use for translating src_lang -> tgt_lang.

    Only mos_Latn -> fra_Latn uses the dedicated model. Every other pair,
    including unsupported ones, falls back to the fr -> mos checkpoint, as the
    original if/elif/else did.
    """
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        return _MOS_TO_FR_MODEL_ID
    return _FR_TO_MOS_MODEL_ID


def _load_model(model_id, device):
    """Return (tokenizer, model) for model_id on device, loading it at most once.

    Only the most recently used checkpoint is kept in memory. Repeated calls in
    the same direction skip the reload, and two models are never resident at
    the same time.
    """
    key = (model_id, str(device))
    cached = _MODEL_CACHE.get(key)
    if cached is None:
        # Evict the previous checkpoint *before* loading the new one, to keep
        # peak memory at a single model.
        _MODEL_CACHE.clear()
        print(f"Chargement du modèle: {model_id}")
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
        cached = _MODEL_CACHE[key] = (tokenizer, model)
    return cached


def goai_traduction(text, src_lang, tgt_lang, *, num_beams=10, length_penalty=2.0,
                    max_new_tokens=2048):
    """Translate text from src_lang to tgt_lang with a fine-tuned NLLB checkpoint.

    Args:
        text: Source text. Empty or whitespace-only input returns "" without
            loading a model.
        src_lang: NLLB source language code, e.g. "fra_Latn" or "mos_Latn".
        tgt_lang: NLLB target language code. It must be a token in the
            tokenizer's vocabulary.
        num_beams: Beam width. The default matches the configuration whose
            output the previous sweep-based implementation returned.
        length_penalty: Beam-search length penalty. The default has the same
            origin as num_beams.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded translation, with special tokens removed.

    Raises:
        ValueError: If tgt_lang is not a known token of the tokenizer.
    """
    if not text or not text.strip():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = _load_model(_select_model_id(src_lang, tgt_lang), device)

    # NLLB tokenizers use src_lang to pick the source-language special token.
    # The tokenizer is shared through the cache, so set it on every call.
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
    print(f"Tokens d'entrée: {inputs['input_ids'].shape[-1]}")

    # Forcing an unknown language would start decoding from the unk token.
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    if tgt_lang_id is None or tgt_lang_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown target language code: {tgt_lang!r}")

    # EOS is deliberately NOT banned: banning it forces every beam to run to
    # max_new_tokens, padding the translation with hallucinated text.
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,  # first decoded token = target language
        max_new_tokens=max_new_tokens,
        early_stopping=False,
        num_beams=num_beams,
        no_repeat_ngram_size=0,
        length_penalty=length_penalty,
    )
    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(
        f"Traduction ({len(translation)} caractères / {len(translation.split())} mots), "
        f"{outputs.shape[-1]} tokens générés"
    )
    return translation
def real_time_traduction(input_text, src_lang, tgt_lang):
    """Delegate to goai_traduction and return its translation unchanged.

    Args:
        input_text: Text to translate.
        src_lang: Source language code, forwarded as-is.
        tgt_lang: Target language code, forwarded as-is.
    """
    translated = goai_traduction(input_text, src_lang=src_lang, tgt_lang=tgt_lang)
    return translated