Vera-Chat / app.py
Dorian2B's picture
Update app.py
e26878b verified
raw
history blame
2.42 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from threading import Lock
# Chargement du modèle
model_name = "Dorian2B/Vera-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
model.eval()
# Verrou pour éviter les conflits de threads
generate_lock = Lock()
def format_prompt(history, new_message):
"""Formate l'historique et le nouveau message pour le modèle."""
prompt = ""
for user_msg, bot_msg in history:
prompt += f"<|user|>{user_msg}</s>\n<|assistant|>{bot_msg}</s>\n"
prompt += f"<|user|>{new_message}</s>\n<|assistant|>"
return prompt
def generate_stream(history, new_message):
"""Génère une réponse en streaming avec contexte."""
prompt = format_prompt(history, new_message)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Génération en streaming
with generate_lock:
with torch.no_grad():
for chunk in model.generate(
**inputs,
max_new_tokens=1024,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
eos_token_id=tokenizer.eos_token_id,
streamer=None, # (Remplacez par un vrai streamer si disponible)
):
decoded = tokenizer.decode(chunk[0], skip_special_tokens=True)
if decoded.startswith(prompt): # Supprime le prompt
decoded = decoded[len(prompt):]
yield decoded.strip()
def chat_interface(message, history):
"""Fonction pour Gradio ChatInterface."""
full_response = ""
for chunk in generate_stream(history, message):
full_response += chunk
yield full_response
# Interface Gradio
demo = gr.ChatInterface(
fn=chat_interface,
title="💬 Vera-Instruct Chat (avec Contexte & Streaming)",
description="Discutez avec le modèle **Dorian2B/Vera-Instruct**.<br>Le modèle conserve le contexte de la conversation.",
examples=["Bonjour ! Comment vas-tu ?", "Explique-moi l'IA générative."],
theme="soft",
retry_btn=None,
undo_btn=None,
)
if __name__ == "__main__":
demo.queue().launch(debug=True)