File size: 2,528 Bytes
78d1101
 
 
eda40b7
78d1101
 
1c7cbff
78d1101
 
 
 
 
 
 
 
 
 
 
 
 
58d122b
78d1101
 
 
 
 
 
9d96957
3e462e2
 
af19dcb
3e462e2
91812f2
af19dcb
 
 
b8808c9
af19dcb
 
 
 
b8808c9
af19dcb
 
 
 
 
91812f2
 
78d1101
 
 
 
 
 
 
 
 
20a6b7b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import os
from huggingface_hub import login

max_length = 512
auth_token = os.getenv('HF_SPACE_TOKEN')
login(token=auth_token)


@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate *text* between French and Mooré with fine-tuned NLLB checkpoints.

    Args:
        text: Input text to translate.
        src_lang: NLLB source language code (e.g. "fra_Latn" or "mos_Latn").
        tgt_lang: NLLB target language code.

    Returns:
        The translated string.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the checkpoint for the requested direction. Any direction other
    # than mos->fr falls back to the fr->mos model (original behavior).
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/mos2fr-5B-800"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"

    if model_id == "ArissBandoss/mos2fr-5B-800":
        # The mos->fr checkpoint is a PEFT adapter: load its base model,
        # attach the adapter, and generate directly instead of via pipeline.
        # (Previously the 600M model was also loaded here and then discarded.)
        peft_config = PeftConfig.from_pretrained(model_id)
        base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
        model = PeftModel.from_pretrained(base_model, model_id)
        # BUG FIX: inputs are moved to `device` below but the model was not,
        # which raised a device-mismatch error whenever CUDA was available.
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)

        # truncation=True makes max_length actually cap the input; without it
        # the max_length argument to the tokenizer is ignored.
        inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        generation_kwargs = {}
        if src_lang and tgt_lang:
            # BUG FIX: convert_tokens_to_ids is a method; the original code
            # subscripted it (`[tgt_lang]`), raising TypeError at runtime.
            generation_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tgt_lang)

        outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    # Non-PEFT checkpoint: use the standard translation pipeline.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)

    trans_pipe = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        max_length=max_length,
        device=device,
    )
    return trans_pipe(text)[0]["translation_text"]


def real_time_traduction(input_text, src_lang, tgt_lang):
    """Real-time entry point: delegate the call straight to goai_traduction."""
    return goai_traduction(input_text, src_lang=src_lang, tgt_lang=tgt_lang)