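# OpenGemma3 Chat: a Gradio chat interface for Gemma 3 1B-IT (GGUF), served locally
# with llama-cpp-python and llama-cpp-agent. On first run it downloads the quantized
# model from the Hugging Face Hub, then streams responses through gr.ChatInterface.
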
import warnings

warnings.filterwarnings("ignore")

import os
import sys
from typing import List, Tuple

from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
from huggingface_hub import hf_hub_download
import gradio as gr

# Local imports (these files must live in the same directory as this script)
from logger import logging
from exception import CustomExceptionHandling

# Download the GGUF model file if it is not already present locally
if not os.path.exists("./models"):
    os.makedirs("./models")

MODEL_REPO_ID = "bartowski/google_gemma-3-1b-it-GGUF"
MODEL_FILENAME_Q4 = "google_gemma-3-1b-it-Q4_K_M.gguf"

if not os.path.exists(f"./models/{MODEL_FILENAME_Q4}"):
    logging.info(f"Downloading model {MODEL_FILENAME_Q4} from {MODEL_REPO_ID}...")
    hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME_Q4,
        local_dir="./models",
    )
    logging.info("Download complete.")
else:
    logging.info(f"Model {MODEL_FILENAME_Q4} already present locally.")
# Define the prompt markers for Gemma 3
gemma_3_prompt_markers = {
    Roles.system: PromptMarkers("", "\n"),
    Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
    Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
    Roles.tool: PromptMarkers("", ""),
}

gemma_3_formatter = MessagesFormatter(
    pre_prompt="",
    prompt_markers=gemma_3_prompt_markers,
    include_sys_prompt_in_first_user_message=True,
    default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
    strip_prompt=False,
    bos_token="<bos>",
    eos_token="<eos>",
)
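
# Illustrative sketch (an assumption, not captured output) of the prompt layout these
# markers produce; with include_sys_prompt_in_first_user_message=True the system
# prompt is folded into the first user turn:
#
#   <bos><start_of_turn>user
#   {system message}
#   {user message}<end_of_turn>
#   <start_of_turn>model
#   {model reply}<end_of_turn>
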
# Global variables to cache the model between requests
llm = None
current_model_name = None

def answer(
    message: str,
    historical_information: List[Tuple[str, str]],
    model_filename: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
):
    global llm
    global current_model_name

    try:
        model_path = f"./models/{model_filename}"
        if not os.path.exists(model_path):
            yield f"Error: model file not found at {model_path}. Check the path."
            return

        # (Re)load the model only when nothing is cached or a different file was selected
        if llm is None or current_model_name != model_filename:
            logging.info(f"Loading model: {model_path}")
            # Adjust n_threads to the available CPU cores
            cpu_count = os.cpu_count()
            threads_to_use = max(1, cpu_count // 2 if cpu_count else 4)
            llm = Llama(
                model_path=model_path,
                flash_attn=False,
                n_gpu_layers=0,
                n_batch=512,
                n_ctx=2048,
                n_threads=threads_to_use,
                n_threads_batch=threads_to_use,
                verbose=False,
            )
            current_model_name = model_filename
            logging.info(f"Model {current_model_name} loaded with {threads_to_use} threads.")
        provider = LlamaCppPythonProvider(llm)
        agent = LlamaCppAgent(
            provider,
            system_prompt=system_message,
            custom_messages_formatter=gemma_3_formatter,
            debug_output=False,
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        chat_history_for_agent = BasicChatHistory()
        for user_msg, assistant_msg in historical_information:
            if user_msg:
                chat_history_for_agent.add_message({"role": Roles.user, "content": user_msg})
            if assistant_msg:
                chat_history_for_agent.add_message({"role": Roles.assistant, "content": assistant_msg})
logging.info(f"Envoi du message à l'agent: {message}") | |
stream = agent.get_chat_response( | |
message, | |
llm_sampling_settings=settings, | |
chat_history=chat_history_for_agent, | |
returns_streaming_generator=True, | |
print_output=False, | |
) | |
response_so_far = "" | |
for token in stream: | |
response_so_far += token | |
yield response_so_far | |
logging.info("Réponse générée.") | |
except Exception as e: | |
logging.error(f"Erreur lors de la génération de la réponse: {e}") | |
# Si tu utilises CustomExceptionHandling | |
# raise CustomExceptionHandling(e, sys) from e | |
yield f"Une erreur est survenue: {str(e)}" | |
available_models = [MODEL_FILENAME_Q4]

# --- Theme definition ---
# You can uncomment and try other themes:
# current_theme = gr.themes.Glass()
# current_theme = gr.themes.Monochrome()
# current_theme = gr.themes.Seafoam()
# current_theme = "gradio/dracula_revamped"
# current_theme = "NoCrypt/Miku"
current_theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.indigo,  # Primary color (buttons, active sliders)
    secondary_hue=gr.themes.colors.pink,  # Secondary color
    neutral_hue=gr.themes.colors.slate,  # Neutral color (text, borders)
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],  # Font
).set(
    # Override specific theme elements here if needed, e.g.:
    # body_background_fill="linear-gradient(to right, #DCE35B, #45B649)"
)

app_title = "OpenGemma3 Chat"
app_description = """Chat with **Gemma 3 1B-IT**, an advanced language model from Google, running locally with `llama.cpp`.
Explore its capabilities by adjusting the generation parameters below."""

demo = gr.ChatInterface(
    answer,
    chatbot=gr.Chatbot(
        label="Conversation",  # Label of the chatbot component
        height=600,
        scale=1,
        show_copy_button=True,
        resizable=True,
        # For avatars, create an 'avatars' folder and place images in it:
        # avatar_images=("./avatars/user_avatar.png", "./avatars/bot_avatar.png"),
        bubble_full_width=False,  # Keep bubbles from spanning the full width
    ),
    additional_inputs=[
        gr.Dropdown(
            choices=available_models,
            value=available_models[0],
            label="GGUF model",
            info="Select the GGUF model to use.",
        ),
        gr.Textbox(
            value="You are a helpful and friendly AI assistant named Gemma. You are concise and provide accurate information.",
            label="System message",
            lines=3,
            info="Define the assistant's personality and role.",
        ),
        gr.Slider(minimum=128, maximum=3072, value=1024, step=128, label="Max Tokens", info="Maximum number of tokens to generate for the response."),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature", info="Controls creativity (higher = more creative)."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)", info="Considers the tokens whose cumulative probability reaches top-p."),
        gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k", info="Considers the k most likely tokens."),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="Penalizes token repetition (higher = less repetition)."),
    ],
    title=app_title,
    description=app_description,
    examples=[
        ["Explain the concept of a black hole in simple terms."],
        ["What is the recipe for crêpes?"],
        ["Tell me a short and funny story."],
    ],
    submit_btn="Send",
    stop_btn="Stop",
    theme=current_theme,
)

if __name__ == "__main__":
    logging.info("Launching the Gradio interface...")
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
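
# Local usage sketch (assumes the dependencies are installed, e.g. gradio,
# llama-cpp-python, llama-cpp-agent, huggingface_hub, and that this file is saved
# as app.py, the usual Spaces convention):
#
#   python app.py
#   # then open http://localhost:7860 in a browser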