Spaces:
Paused
Paused
File size: 4,951 Bytes
60b53a6 72dbc28 9787d82 72dbc28 2759f98 60b53a6 2759f98 60b53a6 b17e963 60b53a6 2759f98 0db8079 60b53a6 0db8079 60b53a6 0db8079 9787d82 0db8079 60b53a6 9787d82 60b53a6 9787d82 60b53a6 9787d82 60b53a6 9787d82 0db8079 60b53a6 9787d82 60b53a6 9787d82 60b53a6 9787d82 60b53a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login
import os
login(token=os.environ["HF_TOKEN"])
# Liste des modèles
models = [
"meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
"meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
"mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
"google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
"croissantllm/CroissantLLMBase"
]
# Variables globales pour stocker le modèle et le tokenizer
model = None
tokenizer = None
def load_model(model_name):
global model, tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
return f"Modèle {model_name} chargé avec succès sur CPU."
@spaces.GPU(duration=300)
def generate_text(input_text, temperature, top_p, top_k):
global model, tokenizer
inputs = tokenizer(input_text, return_tensors="pt")
input_ids = inputs["input_ids"]
with torch.no_grad():
outputs = model.generate(
input_ids,
max_new_tokens=50,
temperature=temperature,
top_p=top_p,
top_k=top_k,
output_attentions=True,
return_dict_in_generate=True
)
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
# Obtenir les logits pour le dernier token généré
last_token_logits = model(outputs.sequences[:, -1:]).logits[:, -1, :]
# Appliquer softmax pour obtenir les probabilités
probabilities = torch.nn.functional.softmax(last_token_logits[0], dim=-1)
# Obtenir les top 5 tokens les plus probables
top_k = 5
top_probs, top_indices = torch.topk(probabilities, top_k)
top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
# Préparer les données pour le graphique des probabilités
prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
# Extraire les attentions
attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
# Préparer les données pour la carte d'attention
tokens = tokenizer.convert_ids_to_tokens(outputs.sequences[0])
attention_data = {
'attention': attentions.tolist(),
'tokens': tokens
}
return generated_text, attention_data, prob_data
def plot_attention(attention_data):
attention = np.array(attention_data['attention'])
tokens = attention_data['tokens']
plt.figure(figsize=(10, 10))
plt.imshow(attention, cmap='viridis')
plt.colorbar()
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.yticks(range(len(tokens)), tokens)
plt.title("Carte d'attention")
return plt
def plot_probabilities(prob_data):
words = list(prob_data.keys())
probs = list(prob_data.values())
plt.figure(figsize=(10, 5))
plt.bar(words, probs)
plt.title("Probabilités des tokens suivants les plus probables")
plt.xlabel("Tokens")
plt.ylabel("Probabilité")
plt.xticks(rotation=45)
return plt
def reset():
return "", 1.0, 1.0, 50, None, None, None
with gr.Blocks() as demo:
gr.Markdown("# Générateur de texte avec visualisation d'attention")
with gr.Accordion("Sélection du modèle"):
model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
load_button = gr.Button("Charger le modèle")
load_output = gr.Textbox(label="Statut du chargement")
with gr.Row():
temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
input_text = gr.Textbox(label="Texte d'entrée")
generate_button = gr.Button("Générer")
output_text = gr.Textbox(label="Texte généré")
with gr.Row():
attention_plot = gr.Plot(label="Visualisation de l'attention")
prob_plot = gr.Plot(label="Probabilités des tokens suivants")
reset_button = gr.Button("Réinitialiser")
load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
generate_button.click(generate_text,
inputs=[input_text, temperature, top_p, top_k],
outputs=[output_text, attention_plot, prob_plot])
reset_button.click(reset,
outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
attention_plot.change(plot_attention, inputs=[attention_plot], outputs=[attention_plot])
prob_plot.change(plot_probabilities, inputs=[prob_plot], outputs=[prob_plot])
demo.launch()
|