openGemma3 / app.py
BryanBradfo's picture
another version
8cf3f1c
import warnings
warnings.filterwarnings("ignore")
import os
import sys
from typing import List, Tuple
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
from huggingface_hub import hf_hub_download
import gradio as gr
# Local imports (assure-toi que ces fichiers sont dans le même dossier)
from logger import logging
from exception import CustomExceptionHandling
# Download gguf model files
if not os.path.exists("./models"):
os.makedirs("./models")
MODEL_REPO_ID = "bartowski/google_gemma-3-1b-it-GGUF"
MODEL_FILENAME_Q4 = "google_gemma-3-1b-it-Q4_K_M.gguf"
if not os.path.exists(f"./models/{MODEL_FILENAME_Q4}"):
logging.info(f"Téléchargement du modèle {MODEL_FILENAME_Q4} depuis {MODEL_REPO_ID}...")
hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=MODEL_FILENAME_Q4,
local_dir="./models",
)
logging.info("Téléchargement terminé.")
else:
logging.info(f"Modèle {MODEL_FILENAME_Q4} déjà présent localement.")
# Define the prompt markers for Gemma 3
gemma_3_prompt_markers = {
Roles.system: PromptMarkers("", "\n"),
Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
Roles.tool: PromptMarkers("", ""),
}
gemma_3_formatter = MessagesFormatter(
pre_prompt="",
prompt_markers=gemma_3_prompt_markers,
include_sys_prompt_in_first_user_message=True,
default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
strip_prompt=False,
bos_token="<bos>",
eos_token="<eos>",
)
# Global variables to cache the model
llm = None
current_model_name = None
def answer(
message: str,
historical_information: List[Tuple[str, str]],
model_filename: str,
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
repeat_penalty: float,
):
global llm
global current_model_name
try:
model_path = f"./models/{model_filename}"
if not os.path.exists(model_path):
yield f"Erreur : Fichier modèle non trouvé à {model_path}. Vérifiez le chemin."
return
if llm is None or current_model_name != model_filename:
logging.info(f"Chargement du modèle : {model_path}")
# Ajuste les n_threads en fonction de ton CPU
cpu_count = os.cpu_count()
threads_to_use = max(1, cpu_count // 2 if cpu_count else 4)
llm = Llama(
model_path=model_path,
flash_attn=False,
n_gpu_layers=0,
n_batch=512,
n_ctx=2048,
n_threads=threads_to_use,
n_threads_batch=threads_to_use,
verbose=False
)
current_model_name = model_filename
logging.info(f"Modèle {current_model_name} chargé avec {threads_to_use} threads.")
provider = LlamaCppPythonProvider(llm)
agent = LlamaCppAgent(
provider,
system_prompt=system_message,
custom_messages_formatter=gemma_3_formatter,
debug_output=False,
)
settings = provider.get_provider_default_settings()
settings.temperature = temperature
settings.top_k = top_k
settings.top_p = top_p
settings.max_tokens = max_tokens
settings.repeat_penalty = repeat_penalty
settings.stream = True
chat_history_for_agent = BasicChatHistory()
for user_msg, assistant_msg in historical_information:
if user_msg:
chat_history_for_agent.add_message({"role": Roles.user, "content": user_msg})
if assistant_msg:
chat_history_for_agent.add_message({"role": Roles.assistant, "content": assistant_msg})
logging.info(f"Envoi du message à l'agent: {message}")
stream = agent.get_chat_response(
message,
llm_sampling_settings=settings,
chat_history=chat_history_for_agent,
returns_streaming_generator=True,
print_output=False,
)
response_so_far = ""
for token in stream:
response_so_far += token
yield response_so_far
logging.info("Réponse générée.")
except Exception as e:
logging.error(f"Erreur lors de la génération de la réponse: {e}")
# Si tu utilises CustomExceptionHandling
# raise CustomExceptionHandling(e, sys) from e
yield f"Une erreur est survenue: {str(e)}"
available_models = [MODEL_FILENAME_Q4]
# --- Définition du Thème ---
# Tu peux décommenter et tester différents thèmes
# current_theme = gr.themes.Glass()
# current_theme = gr.themes.Monochrome()
# current_theme = gr.themes.Seafoam()
# current_theme = "gradio/dracula_revamped"
# current_theme = "NoCrypt/Miku"
current_theme = gr.themes.Soft(
primary_hue=gr.themes.colors.indigo, # Couleur principale (boutons, sliders actifs)
secondary_hue=gr.themes.colors.pink, # Couleur secondaire
neutral_hue=gr.themes.colors.slate, # Couleur neutre (texte, bordures)
font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"] # Police
).set(
# Tu peux surcharger des éléments spécifiques du thème ici si besoin
# Exemple: body_background_fill="linear-gradient(to right, #DCE35B, #45B649)"
)
app_title = "OpenGemma3 Chat"
app_description = """Discutez avec **Gemma 3 1B-IT**, un modèle de langage avancé de Google, exécuté localement grâce à `llama.cpp`.
Explorez ses capacités en ajustant les paramètres de génération ci-dessous."""
demo = gr.ChatInterface(
answer,
chatbot=gr.Chatbot(
label="Conversation", # Label du composant chatbot
height=600,
scale=1,
show_copy_button=True,
resizable=True,
# Pour les avatars, crée un dossier 'avatars' et place des images dedans
# avatar_images=("./avatars/user_avatar.png", "./avatars/bot_avatar.png")
bubble_full_width=False # Pour que les bulles ne prennent pas toute la largeur
),
additional_inputs=[
gr.Dropdown(
choices=available_models,
value=available_models[0],
label="Modèle GGUF",
info="Sélectionnez le modèle GGUF à utiliser.",
),
gr.Textbox(value="You are a helpful and friendly AI assistant named Gemma. You are concise and provide accurate information.", label="System message", lines=3, info="Définissez la personnalité et le rôle de l'assistant."),
gr.Slider(minimum=128, maximum=3072, value=1024, step=128, label="Max Tokens", info="Nombre maximum de tokens à générer pour la réponse."),
gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature", info="Contrôle la créativité (plus haut = plus créatif)."),
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)", info="Considère les tokens dont la probabilité cumulative atteint top-p."),
gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k", info="Considère les k tokens les plus probables."),
gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="Pénalise la répétition de tokens (plus haut = moins de répétition)."),
],
title=app_title,
description=app_description,
examples=[
["Explique le concept de trou noir de manière simple."],
["Quelle est la recette des crêpes ?"],
["Raconte-moi une histoire courte et amusante."]
],
submit_btn="Envoyer",
stop_btn="Arrêter",
theme=current_theme,
)
if __name__ == "__main__":
logging.info("Lancement de l'interface Gradio...")
demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)