import warnings warnings.filterwarnings("ignore") import os import sys from typing import List, Tuple from llama_cpp import Llama from llama_cpp_agent import LlamaCppAgent from llama_cpp_agent.providers import LlamaCppPythonProvider from llama_cpp_agent.chat_history import BasicChatHistory from llama_cpp_agent.chat_history.messages import Roles from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers from huggingface_hub import hf_hub_download import gradio as gr # Local imports (assure-toi que ces fichiers sont dans le même dossier) from logger import logging from exception import CustomExceptionHandling # Download gguf model files if not os.path.exists("./models"): os.makedirs("./models") MODEL_REPO_ID = "bartowski/google_gemma-3-1b-it-GGUF" MODEL_FILENAME_Q4 = "google_gemma-3-1b-it-Q4_K_M.gguf" if not os.path.exists(f"./models/{MODEL_FILENAME_Q4}"): logging.info(f"Téléchargement du modèle {MODEL_FILENAME_Q4} depuis {MODEL_REPO_ID}...") hf_hub_download( repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME_Q4, local_dir="./models", ) logging.info("Téléchargement terminé.") else: logging.info(f"Modèle {MODEL_FILENAME_Q4} déjà présent localement.") # Define the prompt markers for Gemma 3 gemma_3_prompt_markers = { Roles.system: PromptMarkers("", "\n"), Roles.user: PromptMarkers("user\n", "\n"), Roles.assistant: PromptMarkers("model\n", "\n"), Roles.tool: PromptMarkers("", ""), } gemma_3_formatter = MessagesFormatter( pre_prompt="", prompt_markers=gemma_3_prompt_markers, include_sys_prompt_in_first_user_message=True, default_stop_sequences=["", ""], strip_prompt=False, bos_token="", eos_token="", ) # Global variables to cache the model llm = None current_model_name = None def answer( message: str, historical_information: List[Tuple[str, str]], model_filename: str, system_message: str, max_tokens: int, temperature: float, top_p: float, top_k: int, repeat_penalty: float, ): global llm global current_model_name try: model_path = f"./models/{model_filename}" if not os.path.exists(model_path): yield f"Erreur : Fichier modèle non trouvé à {model_path}. Vérifiez le chemin." return if llm is None or current_model_name != model_filename: logging.info(f"Chargement du modèle : {model_path}") # Ajuste les n_threads en fonction de ton CPU cpu_count = os.cpu_count() threads_to_use = max(1, cpu_count // 2 if cpu_count else 4) llm = Llama( model_path=model_path, flash_attn=False, n_gpu_layers=0, n_batch=512, n_ctx=2048, n_threads=threads_to_use, n_threads_batch=threads_to_use, verbose=False ) current_model_name = model_filename logging.info(f"Modèle {current_model_name} chargé avec {threads_to_use} threads.") provider = LlamaCppPythonProvider(llm) agent = LlamaCppAgent( provider, system_prompt=system_message, custom_messages_formatter=gemma_3_formatter, debug_output=False, ) settings = provider.get_provider_default_settings() settings.temperature = temperature settings.top_k = top_k settings.top_p = top_p settings.max_tokens = max_tokens settings.repeat_penalty = repeat_penalty settings.stream = True chat_history_for_agent = BasicChatHistory() for user_msg, assistant_msg in historical_information: if user_msg: chat_history_for_agent.add_message({"role": Roles.user, "content": user_msg}) if assistant_msg: chat_history_for_agent.add_message({"role": Roles.assistant, "content": assistant_msg}) logging.info(f"Envoi du message à l'agent: {message}") stream = agent.get_chat_response( message, llm_sampling_settings=settings, chat_history=chat_history_for_agent, returns_streaming_generator=True, print_output=False, ) response_so_far = "" for token in stream: response_so_far += token yield response_so_far logging.info("Réponse générée.") except Exception as e: logging.error(f"Erreur lors de la génération de la réponse: {e}") # Si tu utilises CustomExceptionHandling # raise CustomExceptionHandling(e, sys) from e yield f"Une erreur est survenue: {str(e)}" available_models = [MODEL_FILENAME_Q4] # --- Définition du Thème --- # Tu peux décommenter et tester différents thèmes # current_theme = gr.themes.Glass() # current_theme = gr.themes.Monochrome() # current_theme = gr.themes.Seafoam() # current_theme = "gradio/dracula_revamped" # current_theme = "NoCrypt/Miku" current_theme = gr.themes.Soft( primary_hue=gr.themes.colors.indigo, # Couleur principale (boutons, sliders actifs) secondary_hue=gr.themes.colors.pink, # Couleur secondaire neutral_hue=gr.themes.colors.slate, # Couleur neutre (texte, bordures) font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"] # Police ).set( # Tu peux surcharger des éléments spécifiques du thème ici si besoin # Exemple: body_background_fill="linear-gradient(to right, #DCE35B, #45B649)" ) app_title = "OpenGemma3 Chat" app_description = """Discutez avec **Gemma 3 1B-IT**, un modèle de langage avancé de Google, exécuté localement grâce à `llama.cpp`. Explorez ses capacités en ajustant les paramètres de génération ci-dessous.""" demo = gr.ChatInterface( answer, chatbot=gr.Chatbot( label="Conversation", # Label du composant chatbot height=600, scale=1, show_copy_button=True, resizable=True, # Pour les avatars, crée un dossier 'avatars' et place des images dedans # avatar_images=("./avatars/user_avatar.png", "./avatars/bot_avatar.png") bubble_full_width=False # Pour que les bulles ne prennent pas toute la largeur ), additional_inputs=[ gr.Dropdown( choices=available_models, value=available_models[0], label="Modèle GGUF", info="Sélectionnez le modèle GGUF à utiliser.", ), gr.Textbox(value="You are a helpful and friendly AI assistant named Gemma. You are concise and provide accurate information.", label="System message", lines=3, info="Définissez la personnalité et le rôle de l'assistant."), gr.Slider(minimum=128, maximum=3072, value=1024, step=128, label="Max Tokens", info="Nombre maximum de tokens à générer pour la réponse."), gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature", info="Contrôle la créativité (plus haut = plus créatif)."), gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)", info="Considère les tokens dont la probabilité cumulative atteint top-p."), gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k", info="Considère les k tokens les plus probables."), gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="Pénalise la répétition de tokens (plus haut = moins de répétition)."), ], title=app_title, description=app_description, examples=[ ["Explique le concept de trou noir de manière simple."], ["Quelle est la recette des crêpes ?"], ["Raconte-moi une histoire courte et amusante."] ], submit_btn="Envoyer", stop_btn="Arrêter", theme=current_theme, ) if __name__ == "__main__": logging.info("Lancement de l'interface Gradio...") demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)