Spaces:
Paused
Paused
File size: 5,084 Bytes
60b53a6 72dbc28 9787d82 72dbc28 2759f98 60b53a6 1acabf9 4f04d17 1acabf9 60b53a6 1acabf9 60b53a6 1acabf9 60b53a6 1acabf9 60b53a6 0db8079 1acabf9 0db8079 9787d82 1acabf9 9787d82 1acabf9 9787d82 1acabf9 9787d82 60b53a6 5efe227 9787d82 1acabf9 9787d82 60b53a6 5efe227 60b53a6 9787d82 5efe227 9787d82 5efe227 0db8079 60b53a6 9787d82 60b53a6 9787d82 60b53a6 9787d82 60b53a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login
import os
# Authenticate against the Hugging Face Hub only when a token is provided, so the
# app still starts (and can load non-gated models) when HF_TOKEN is not set.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Selectable models (Hugging Face Hub repository ids).
# The Llama 2 entries use the "-hf" repositories: the plain "meta-llama/Llama-2-*"
# repos hold Meta's original checkpoint format, which the transformers Auto*
# classes used in load_model() cannot load.
models = [
    "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]

# Module-level state: the currently loaded model and its tokenizer.
# Written by load_model(), read by generate_text(); None until a model is loaded.
model = None
tokenizer = None
def load_model(model_name):
    """Load *model_name* from the Hugging Face Hub into the module-level globals.

    Args:
        model_name: Hub repository id selected in the dropdown; may be None or
            empty when nothing has been selected.

    Returns:
        A (French) status message displayed in the UI.
    """
    global model, tokenizer
    if not model_name:
        return "Veuillez choisir un modèle."

    import gc

    # Release the previous model first so two large models never have to fit in
    # memory at the same time. Both globals are cleared together so they never
    # refer to different models (e.g. if the new load below raises).
    model = None
    tokenizer = None
    gc.collect()
    on_gpu = torch.cuda.is_available()
    if on_gpu:
        torch.cuda.empty_cache()

    # float16 is used on GPU only; many CPU kernels do not support half precision.
    dtype = torch.float16 if on_gpu else torch.float32

    # Load into locals and publish only once both loads succeeded.
    new_tokenizer = AutoTokenizer.from_pretrained(model_name)
    new_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=dtype,
        # generate_text() plots attention weights; SDPA / flash-attention kernels
        # do not return them, the eager implementation does.
        attn_implementation="eager",
    )

    # Define a padding token if the tokenizer has none.
    if new_tokenizer.pad_token is None:
        new_tokenizer.pad_token = new_tokenizer.eos_token
        # Use the tokenizer's single integer id: model.config.eos_token_id may be
        # a list for some models, which is not a valid pad_token_id.
        new_model.config.pad_token_id = new_tokenizer.pad_token_id

    tokenizer, model = new_tokenizer, new_model
    device_label = "GPU" if on_gpu else "CPU"
    return f"Modèle {model_name} chargé avec succès sur {device_label}."
def _top_token_probabilities(step_scores, tok, k=5):
    """Return {label: probability} for the *k* most probable tokens of one step.

    Args:
        step_scores: 1-D score tensor over the vocabulary for a single step.
        tok: tokenizer used to turn token ids into display labels.
        k: number of candidates to keep.
    """
    probabilities = torch.softmax(step_scores.float(), dim=-1)
    top_probs, top_ids = torch.topk(probabilities, k)
    prob_data = {}
    for token_id, prob in zip(top_ids.tolist(), top_probs.tolist()):
        label = tok.decode([token_id])
        # Distinct ids can decode to the same string; keep one bar per token.
        if label in prob_data:
            label = f"{label} [{token_id}]"
        prob_data[label] = prob
    return prob_data


def _stitch_last_layer_attention(step_attentions):
    """Assemble generate()'s per-step last-layer attentions into one square matrix.

    generate() returns one entry per decoding step, each a tuple with one tensor
    per layer laid out as (batch, heads, queries, keys). Step 0 covers all prompt
    positions. Each later step adds a single query over a key set that is one
    longer. The steps therefore have different shapes and cannot be concatenated
    directly. This places each step's head-averaged rows (batch item 0) into a
    zero-initialised (n, n) matrix.

    Returns:
        A numpy array of shape (n, n), or None when attentions are unavailable.
    """
    if not step_attentions or any(
        layers is None or layers[-1] is None for layers in step_attentions
    ):
        return None
    # .cpu(): with device_map="auto" layers may live on different devices.
    blocks = [layers[-1][0].float().mean(dim=0).cpu() for layers in step_attentions]
    size = sum(block.shape[0] for block in blocks)
    matrix = torch.zeros(size, size)
    row = 0
    for block in blocks:
        n_queries, n_keys = block.shape
        matrix[row:row + n_queries, :n_keys] = block
        row += n_queries
    return matrix.numpy()


def generate_text(input_text, temperature, top_p, top_k):
    """Sample a continuation of *input_text* and build the two diagnostic plots.

    Args:
        input_text: prompt typed by the user.
        temperature, top_p, top_k: sampling parameters from the UI sliders.

    Returns:
        (generated_text, attention_figure, probabilities_figure). The attention
        figure is None when the model returned no attention weights. Both
        figures are None when no model is loaded.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, None

    inputs = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            # Without do_sample=True generate() decodes greedily and ignores
            # temperature / top_p / top_k.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            # Slider values may arrive as floats; top_k must be an int.
            top_k=int(top_k),
            output_attentions=True,
            # Needed for outputs.scores to be populated (None otherwise).
            output_scores=True,
            return_dict_in_generate=True,
        )
    sequence = outputs.sequences[0]
    generated_text = tokenizer.decode(sequence, skip_special_tokens=True)

    # outputs.scores[-1] holds the processed (temperature/top-k/top-p adjusted)
    # scores of the distribution the last generated token was sampled from.
    prob_data = _top_token_probabilities(outputs.scores[-1][0], tokenizer, k=5)

    attention = _stitch_last_layer_attention(outputs.attentions)
    if attention is None:
        attention_fig = None
    else:
        # The final generated token is never fed back through the model, so the
        # matrix covers the first len(sequence) - 1 positions only.
        attention_data = {
            'attention': attention,
            'tokens': tokenizer.convert_ids_to_tokens(sequence[:attention.shape[0]].tolist()),
        }
        attention_fig = plot_attention(attention_data)

    return generated_text, attention_fig, plot_probabilities(prob_data)
def plot_attention(attention_data):
    """Draw an attention heatmap.

    Args:
        attention_data: dict with 'attention' (2-D array, rows = query
            positions, columns = key positions) and 'tokens' (labels used for
            both axes).

    Returns:
        The matplotlib Figure.
    """
    attention = attention_data['attention']
    tokens = attention_data['tokens']
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(attention, cmap='viridis')
    # Figure/Axes methods instead of plt.colorbar / plt.tight_layout, which act on
    # pyplot's global "current figure" and can hit the wrong figure when Gradio
    # handles requests concurrently.
    fig.colorbar(im, ax=ax)
    positions = range(len(tokens))
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title("Carte d'attention")
    fig.tight_layout()
    # Unregister the figure from pyplot. The returned Figure can still be
    # rendered, and leaving it open would accumulate one figure per request.
    plt.close(fig)
    return fig
def plot_probabilities(prob_data):
    """Draw a bar chart of candidate next tokens and their probabilities.

    Args:
        prob_data: mapping of token label -> probability; insertion order is the
            bar order.

    Returns:
        The matplotlib Figure.
    """
    words = list(prob_data.keys())
    probs = list(prob_data.values())
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(words, probs)
    ax.set_title("Probabilités des tokens suivants les plus probables")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probabilité")
    # Axes-level call instead of plt.xticks, which targets the global current axes.
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    # Unregister the figure from pyplot so repeated calls do not accumulate open
    # figures; the returned Figure remains renderable.
    plt.close(fig)
    return fig
def reset():
    """Return the values that put every UI component back to its initial state.

    Order matches the reset button's outputs: input text, temperature, top-p,
    top-k, generated text, attention plot, probability plot.
    """
    empty_prompt = ("",)
    default_sampling = (1.0, 1.0, 50)
    cleared_outputs = (None,) * 3
    return empty_prompt + default_sampling + cleared_outputs
# Gradio user interface. Inside each `with` context, components appear in the
# order they are created, so the statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Générateur de texte avec visualisation d'attention")
    # Model selection: choose a Hub repository id from `models`; the button calls
    # load_model() and its status message goes to load_output.
    with gr.Accordion("Sélection du modèle"):
        model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")
    # Sampling parameters passed to generate_text(); reset() restores them to
    # 1.0 / 1.0 / 50, the same values used as initial values here.
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    input_text = gr.Textbox(label="Texte d'entrée")
    generate_button = gr.Button("Générer")
    output_text = gr.Textbox(label="Texte généré")
    # Side-by-side plots filled by generate_text().
    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        prob_plot = gr.Plot(label="Probabilités des tokens suivants")
    reset_button = gr.Button("Réinitialiser")
    # Event wiring: the order of each `outputs` list must match the order of the
    # values returned by the corresponding handler.
    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[output_text, attention_plot, prob_plot])
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
# Start the web server (runs at import time, as in the original script).
demo.launch()
|