# Importing required libraries
import warnings

warnings.filterwarnings("ignore")

import os
import json
import subprocess
import sys

from llama_cpp import Llama, llama_model_decoder_start_token
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple

from logger import logging
from exception import CustomExceptionHandling
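
# `logger` and `exception` are small helper modules that ship with this Space.
# A minimal stand-in (an assumption, only for running this file outside the Space)
# could look like:
#
#   # logger.py
#   import logging
#   logging.basicConfig(level=logging.INFO)
#
#   # exception.py
#   class CustomExceptionHandling(Exception):
#       def __init__(self, error, error_detail=None):
#           super().__init__(str(error))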

# Download the GGUF model file
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
os.makedirs("models", exist_ok=True)

# mtsdurica/madlad400-3b-mt-Q8_0-GGUF
hf_hub_download(
    repo_id="mtsdurica/madlad400-3b-mt-Q8_0-GGUF",
    filename="madlad400-3b-mt-q8_0.gguf",
    local_dir="./models",
)
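
# Note (sketch, not required here): if the model repo were gated or private, the
# token read above could be forwarded to the download call, e.g.
#   hf_hub_download(..., token=huggingface_token)
# The download above works without a token for public repos.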

# Define the prompt markers for Gemma 3
gemma_3_prompt_markers = {
    Roles.system: PromptMarkers("", "\n"),  # System prompt should be included within user message
    Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
    Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
    Roles.tool: PromptMarkers("", ""),  # If you need tool support
}

# Create the formatter
gemma_3_formatter = MessagesFormatter(
    pre_prompt="",  # No pre-prompt
    prompt_markers=gemma_3_prompt_markers,
    include_sys_prompt_in_first_user_message=True,  # Include system prompt in first user message
    default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
    strip_prompt=False,  # Don't strip whitespace from the prompt
    bos_token="<bos>",  # Beginning-of-sequence token for Gemma 3
    eos_token="<eos>",  # End-of-sequence token for Gemma 3
)
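
# For reference, a single exchange rendered with the markers above looks roughly like
# (the system message is folded into the first user turn):
#
#   <bos><start_of_turn>user
#   {system message}
#   {user message}<end_of_turn>
#   <start_of_turn>model
#   {assistant reply}<end_of_turn>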

# Set the title and description
title = "Gemma Llama.cpp"
description = """Gemma 3 is a family of lightweight, multimodal open models that offers advanced capabilities like large context windows and multilingual support, enabling diverse applications on various devices."""

llm = None
llm_model = None

import ctypes
import multiprocessing

import llama_cpp

def test():
    """Low-level llama.cpp sanity check for the MADLAD-400 GGUF model.

    Encodes a fixed English prompt with the encoder, then greedily decodes a
    German translation starting from the decoder start token, printing the
    generated pieces to stdout.
    """
    llama_cpp.llama_backend_init(numa=False)

    N_THREADS = multiprocessing.cpu_count()
    MODEL_PATH = "models/madlad400-3b-mt-q8_0.gguf"

    prompt = b"translate English to German: The house is wonderful."

    # Load the model, vocabulary, and context
    lparams = llama_cpp.llama_model_default_params()
    model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode("utf-8"), lparams)
    vocab = llama_cpp.llama_model_get_vocab(model)

    cparams = llama_cpp.llama_context_default_params()
    cparams.no_perf = False
    ctx = llama_cpp.llama_init_from_model(model, cparams)

    # Greedy sampler chain
    sparams = llama_cpp.llama_sampler_chain_default_params()
    smpl = llama_cpp.llama_sampler_chain_init(sparams)
    llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_greedy())

    n_past = 0

    # Tokenize the prompt
    embd_inp = (llama_cpp.llama_token * (len(prompt) + 1))()
    n_of_tok = llama_cpp.llama_tokenize(
        vocab,
        prompt,
        len(prompt),
        embd_inp,
        len(embd_inp),
        True,
        True,
    )
    embd_inp = embd_inp[:n_of_tok]

    n_ctx = llama_cpp.llama_n_ctx(ctx)

    n_predict = 20
    n_predict = min(n_predict, n_ctx - len(embd_inp))

    input_consumed = 0
    input_noecho = False

    remaining_tokens = n_predict

    embd = []
    last_n_size = 64
    last_n_tokens_data = [0] * last_n_size
    n_batch = 24
    last_n_repeat = 64
    repeat_penalty = 1
    frequency_penalty = 0.0
    presence_penalty = 0.0

    batch = llama_cpp.llama_batch_init(n_batch, 0, 1)

    # prepare batch for encoding containing the prompt
    batch.n_tokens = len(embd_inp)
    for i in range(batch.n_tokens):
        batch.token[i] = embd_inp[i]
        batch.pos[i] = i
        batch.n_seq_id[i] = 1
        batch.seq_id[i][0] = 0
        batch.logits[i] = False

    llama_cpp.llama_encode(ctx, batch)

    # now overwrite embd_inp so batch for decoding will initially contain only
    # a single token with id acquired from llama_model_decoder_start_token(model)
    embd_inp = [llama_cpp.llama_model_decoder_start_token(model)]

    while remaining_tokens > 0:
        if len(embd) > 0:
            batch.n_tokens = len(embd)
            for i in range(batch.n_tokens):
                batch.token[i] = embd[i]
                batch.pos[i] = n_past + i
                batch.n_seq_id[i] = 1
                batch.seq_id[i][0] = 0
                batch.logits[i] = i == batch.n_tokens - 1

            llama_cpp.llama_decode(ctx, batch)

        n_past += len(embd)
        embd = []

        if len(embd_inp) <= input_consumed:
            # Sample the next token greedily
            id = llama_cpp.llama_sampler_sample(smpl, ctx, -1)
            last_n_tokens_data = last_n_tokens_data[1:] + [id]
            embd.append(id)
            input_noecho = False
            remaining_tokens -= 1
        else:
            # Feed any not-yet-consumed decoder input tokens
            while len(embd_inp) > input_consumed:
                embd.append(embd_inp[input_consumed])
                last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
                input_consumed += 1
                if len(embd) >= n_batch:
                    break

        if not input_noecho:
            for id in embd:
                size = 32
                buffer = (ctypes.c_char * size)()
                n = llama_cpp.llama_token_to_piece(
                    vocab, llama_cpp.llama_token(id), buffer, size, 0, True
                )
                assert n <= size
                print(
                    buffer[:n].decode("utf-8"),
                    end="",
                    flush=True,
                )

        if len(embd) > 0 and embd[-1] in [
            llama_cpp.llama_token_eos(vocab),
            llama_cpp.llama_token_eot(vocab),
        ]:
            break

    print()
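
# Usage sketch: once the GGUF file has been downloaded, test() can be called on its
# own (e.g. `python -c "import app; app.test()"`, assuming this file is app.py); it
# prints the greedy German translation of the sample prompt to stdout.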

def trans(text):
    """Translate `text` to Japanese with the globally loaded MADLAD-400 model.

    MADLAD-400 is an encoder-decoder (T5-style) model, so the tagged input is
    run through the encoder first and generation starts from the decoder
    start token rather than the BOS token.
    """
    # test()  # optional low-level sanity check; prints an English->German sample

    # Prefix the text with the target-language tag and convert it to bytes
    input_text = f"<2ja>{text}".encode("utf-8")

    # Tokenize and run the encoder on the tagged input
    tokens = llm.tokenize(input_text)
    print("Tokens:", tokens)
    llm.encode(tokens)

    # Decoding starts from the decoder start token
    initial_tokens = [llm.decoder_start_token()]
    print("Initial Tokens:", initial_tokens)

    # Generate until the end-of-sequence token
    buf = ""
    for token in llm.generate(
        initial_tokens, top_k=0, top_p=0.95, temp=0.0, repeat_penalty=1.0
    ):
        buf += llm.detokenize([token]).decode("utf-8", errors="ignore")
        if token == llm.token_eos():
            break
    return buf
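
# Usage sketch (assumes `llm` has already been initialised by respond() below):
#   print(trans("The house is wonderful."))  # expected to print a Japanese translation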

def respond(
    message: str,
    history: List[Tuple[str, str]],
    model: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
):
    """
    Respond to a message using the selected GGUF model via llama.cpp.

    Args:
        - message (str): The message to respond to.
        - history (List[Tuple[str, str]]): The chat history.
        - model (str): The model to use.
        - system_message (str): The system message to use.
        - max_tokens (int): The maximum number of tokens to generate.
        - temperature (float): The temperature of the model.
        - top_p (float): The top-p of the model.
        - top_k (int): The top-k of the model.
        - repeat_penalty (float): The repetition penalty of the model.

    Yields:
        str: The accumulated response to the message.
    """
    try:
        # Load the global variables
        global llm
        global llm_model

        # Load the model (only once, or when the selected model changes)
        if llm is None or llm_model != model:
            llm = Llama(
                model_path=f"models/{model}",
                flash_attn=False,
                n_gpu_layers=0,
                n_batch=8,
                n_ctx=2048,
                n_threads=8,
                n_threads_batch=8,
            )
            llm_model = model

        # Short-circuit: translate the message with MADLAD-400 and skip the
        # chat-agent pipeline below. Yield (not return) so the generator
        # actually emits the translation to the UI.
        yield trans(message)
        return

        provider = LlamaCppPythonProvider(llm)

        # Create the agent
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            # predefined_messages_formatter_type=GEMMA_2,
            custom_messages_formatter=gemma_3_formatter,
            debug_output=True,
        )

        # Set the settings like temperature, top-k, top-p, max tokens, etc.
        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        messages = BasicChatHistory()

        # Add the chat history
        for msn in history:
            user = {"role": Roles.user, "content": msn[0]}
            assistant = {"role": Roles.assistant, "content": msn[1]}
            messages.add_message(user)
            messages.add_message(assistant)

        # Get the response stream
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        # Log the success
        logging.info("Response stream generated successfully")

        # Generate the response
        outputs = ""
        for output in stream:
            outputs += output
            yield outputs

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e
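
# Usage sketch: respond() is a generator, so partial outputs are consumed incrementally.
# The argument order mirrors `additional_inputs` below; the values here are illustrative.
#
#   for partial in respond(
#       "The weather is lovely today.",
#       [],
#       "madlad400-3b-mt-q8_0.gguf",
#       "You are a helpful assistant.",
#       1024, 0.7, 0.95, 40, 1.1,
#   ):
#       print(partial)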

# Create a chat interface
demo = gr.ChatInterface(
    respond,
    examples=[
        ["What is the capital of France?"],
        ["Tell me something about artificial intelligence."],
        ["What is gravity?"],
    ],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Dropdown(
            choices=[
                "madlad400-3b-mt-q8_0.gguf",
            ],
            value="madlad400-3b-mt-q8_0.gguf",
            label="Model",
            info="Select the AI model to use for chat",
        ),
        gr.Textbox(
            value="You are a helpful assistant.",
            label="System Prompt",
            info="Define the AI assistant's personality and behavior",
            lines=2,
        ),
        gr.Slider(
            minimum=512,
            maximum=2048,
            value=1024,
            step=1,
            label="Max Tokens",
            info="Maximum length of response (higher = longer replies)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Creativity level (higher = more creative, lower = more focused)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling threshold",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=40,
            step=1,
            label="Top-k",
            info="Limit vocabulary choices to top K tokens",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition Penalty",
            info="Penalize repeated words (higher = less repetition)",
        ),
    ],
    theme="Ocean",
    submit_btn="Send",
    stop_btn="Stop",
    title=title,
    description=description,
    chatbot=gr.Chatbot(scale=1, show_copy_button=True),
    flagging_mode="never",
)
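
# Local-testing sketch (assumption; not needed on Spaces): an explicit host/port can
# be passed instead, e.g. demo.launch(server_name="0.0.0.0", server_port=7860).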

# Launch the chat interface
if __name__ == "__main__":
    demo.launch(debug=False)
    # Low-level sanity check; runs only after the Gradio server has shut down
    test()