# LLMnBiasV2 / app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login
import os
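# Authenticate with the Hugging Face Hub; HF_TOKEN must be set in the environment
# (required for gated checkpoints such as the Llama and Gemma models listed below).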
login(token=os.environ["HF_TOKEN"])
# List of selectable models
models = [
"meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
"meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
"mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
"google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
"croissantllm/CroissantLLMBase"
]
# Global model and tokenizer, populated by load_model()
model = None
tokenizer = None
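# Load the selected checkpoint in float16 with device_map="auto" (GPU when available).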
def load_model(model_name):
    global model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    # Define a padding token if the tokenizer does not have one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    return f"Model {model_name} loaded successfully on {model.device}."
def generate_text(input_text, temperature, top_p, top_k):
    global model, tokenizer
    if model is None or tokenizer is None:
        raise gr.Error("Load a model before generating text.")
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,  # sampling must be enabled for temperature/top_p/top_k to take effect
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),  # Gradio sliders may pass floats; generate() expects an int here
            output_scores=True,  # populate outputs.scores, used below for the probability chart
            return_dict_in_generate=True
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Logits for the last generated token
    last_token_logits = outputs.scores[-1][0]
    # Softmax turns the logits into a probability distribution over the vocabulary
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
    # Keep the 5 most probable tokens (renamed so we don't shadow the top_k sampling parameter)
    top_n = 5
    top_probs, top_indices = torch.topk(probabilities, top_n)
    top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
    # Data for the probability bar chart
    prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
    # Extract the attention weights. The per-step attention tensors returned by
    # generate() have mismatched shapes and cannot be concatenated directly, so we
    # recompute attention with a single forward pass over the full generated sequence,
    # then keep the last layer averaged over heads: a square (seq_len, seq_len) map.
    with torch.no_grad():
        attn_out = model(outputs.sequences, output_attentions=True)
    attentions = attn_out.attentions[-1].mean(dim=1)[0].float().cpu().numpy()
    attention_data = {
        'attention': attentions,
        'tokens': tokenizer.convert_ids_to_tokens(outputs.sequences[0])
    }
    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
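# Heat map of the last layer's attention weights, averaged over heads,
# for every token in the generated sequence.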
def plot_attention(attention_data):
    attention = attention_data['attention']
    tokens = attention_data['tokens']
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(attention, cmap='viridis')
    plt.colorbar(im)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title("Attention map")
    plt.tight_layout()
    return fig
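# Bar chart of the top next-token probabilities computed in generate_text().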
def plot_probabilities(prob_data):
    words = list(prob_data.keys())
    probs = list(prob_data.values())
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(words, probs)
    ax.set_title("Most probable next tokens")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probability")
    plt.xticks(rotation=45)
    plt.tight_layout()
    return fig
def reset():
    return "", 1.0, 1.0, 50, None, None, None
with gr.Blocks() as demo:
    gr.Markdown("# Text generator with attention visualization")
    with gr.Accordion("Model selection"):
        model_dropdown = gr.Dropdown(choices=models, label="Choose a model")
        load_button = gr.Button("Load model")
        load_output = gr.Textbox(label="Loading status")
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    input_text = gr.Textbox(label="Input text")
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated text")
    with gr.Row():
        attention_plot = gr.Plot(label="Attention visualization")
        prob_plot = gr.Plot(label="Next-token probabilities")
    reset_button = gr.Button("Reset")
    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[output_text, attention_plot, prob_plot])
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
demo.launch()