File size: 3,565 Bytes
60b53a6
 
6696db2
 
60b53a6
 
6d96117
9787d82
6696db2
ea35578
2759f98
19de71a
 
 
 
6696db2
 
 
 
 
 
 
 
 
60b53a6
 
 
19de71a
 
60b53a6
19de71a
60b53a6
19de71a
 
 
 
 
6696db2
19de71a
 
 
 
6696db2
19de71a
 
6696db2
19de71a
 
6696db2
19de71a
 
1acabf9
19de71a
 
9787d82
19de71a
60b53a6
6696db2
19de71a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60b53a6
19de71a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import matplotlib.pyplot as plt
import numpy as np
import os

# Log in to the Hugging Face Hub with the token read from the HF_TOKEN
# environment variable. This runs at import time; the subscript lookup raises
# KeyError when HF_TOKEN is not set.
login(token=os.environ["HF_TOKEN"])

# Model identifiers offered in the UI dropdown built by main(); the first
# entry is used as the dropdown's initial value.
MODEL_LIST = [
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase"
]

# Dictionary caching the models and tokenizers that have already been loaded.
# Keys are model names; each value is the (model, tokenizer) tuple stored by
# load_model().
loaded_models = {}

# Load the model
def load_model(model_name):
    """Return the ``(model, tokenizer)`` pair for ``model_name``.

    The pair is looked up in the module-level ``loaded_models`` cache first.
    On a miss, the slow tokenizer and the float16 causal-LM weights are
    loaded with ``device_map="auto"``, stored in the cache, and returned.

    Args:
        model_name: Identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Cache hit: nothing to load.
    if model_name in loaded_models:
        return loaded_models[model_name]

    tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    net = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    pair = (net, tok)
    loaded_models[model_name] = pair
    return pair

# Text generation and most probable last tokens
def generate_text(model_name, input_text, temperature, top_p, top_k):
    """Generate a continuation of ``input_text`` and list the likeliest last tokens.

    Args:
        model_name: Model identifier passed to ``load_model``.
        input_text: Prompt to continue.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding, because sampling requires a strictly positive value.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Number of highest-scoring tokens kept when sampling; 0
            disables the top-k filter.

    Returns:
        A ``(generated_text, probable_words)`` tuple: the decoded full
        sequence (prompt included) and a list with the five tokens that had
        the highest probability at the final generation step.
    """
    model, tokenizer = load_model(model_name)

    # The model is placed with device_map="auto", so follow the model's own
    # device instead of assuming a CUDA device is present.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # return_dict_in_generate/output_scores are required for `.scores` and
    # `.sequences` to exist on the result. Attentions are not requested
    # because nothing below uses them.
    generation_kwargs = {
        "max_new_tokens": 50,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    if temperature > 0:
        # UI sliders may deliver ints or floats; the logits warpers validate
        # types strictly, so normalise them here.
        generation_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        generation_kwargs["do_sample"] = False

    output = model.generate(**inputs, **generation_kwargs)

    # Decode the full generated sequence (prompt + new tokens).
    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    # Most probable tokens at the last generation step. Scores are cast to
    # float32 so the softmax is not computed in half precision.
    last_token_logits = output.scores[-1][0].float()
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
    top_tokens = torch.topk(probabilities, k=5)
    # Convert to plain ints before decoding rather than passing 0-d tensors.
    probable_words = [tokenizer.decode([token_id]) for token_id in top_tokens.indices.tolist()]

    return generated_text, probable_words

# Gradio user interface
def reset_interface():
    """Return the blank values used to clear the interface.

    The reset button is wired to three output components (input text,
    generated text, probable words), so exactly three empty strings are
    returned, one per component. Returning a different number of values
    would not match the event's outputs.

    Returns:
        A tuple of three empty strings.
    """
    return "", "", ""

def main():
    """Build the Gradio Blocks interface, wire its buttons, and launch it.

    Layout, top to bottom: model dropdown, prompt box, sampling sliders,
    the generate/reset buttons, then the two result text boxes.
    """
    with gr.Blocks() as demo:
        # Model selection
        with gr.Accordion("Choix du modèle", open=True):
            model_dropdown = gr.Dropdown(
                choices=MODEL_LIST,
                label="Modèles disponibles",
                value=MODEL_LIST[0],
            )

        # Prompt entry
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Texte d'entrée",
                placeholder="Saisissez votre texte ici...",
            )

        # Sampling parameters
        with gr.Accordion("Paramètres", open=True):
            temperature_slider = gr.Slider(
                minimum=0, maximum=1, value=0.7, step=0.01, label="Température"
            )
            top_p_slider = gr.Slider(
                minimum=0, maximum=1, value=0.9, step=0.01, label="Top_p"
            )
            top_k_slider = gr.Slider(
                minimum=0, maximum=100, value=50, step=1, label="Top_k"
            )

        # Action buttons
        with gr.Row():
            run_btn = gr.Button("Lancer la génération")
            clear_btn = gr.Button("Réinitialiser")

        # Result fields
        text_result = gr.Textbox(
            label="Texte généré",
            placeholder="Le texte généré s'affichera ici...",
        )
        words_result = gr.Textbox(
            label="Mots les plus probables",
            placeholder="Les mots les plus probables apparaîtront ici...",
        )

        # Start generation
        run_btn.click(
            generate_text,
            inputs=[
                model_dropdown,
                prompt_box,
                temperature_slider,
                top_p_slider,
                top_k_slider,
            ],
            outputs=[text_result, words_result],
        )
        # Reset the fields
        clear_btn.click(
            reset_interface,
            outputs=[prompt_box, text_result, words_result],
        )

    demo.launch()

# Build and launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    main()