File size: 2,031 Bytes
a4e4083
48d9d00
4d2b819
48d9d00
b7f8793
48d9d00
7a5ec34
ba09697
b7f8793
48d9d00
ba09697
 
 
48d9d00
ba09697
b7f8793
48d9d00
ba09697
4d2b819
48d9d00
 
4d2b819
48d9d00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4e4083
 
48d9d00
 
 
 
a4e4083
48d9d00
a4e4083
48d9d00
a4e4083
48d9d00
 
 
 
 
 
 
 
 
 
a4e4083
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Define the base and LoRA model IDs.
# Base: a 4-bit bitsandbytes-quantized Gemma-2 9B checkpoint; LoRA: adapter weights
# fine-tuned on top of it.
base_model_id = "unsloth/gemma-2-9b-bnb-4bit"
lora_model_id = "Futuresony/future_12_10_2024"

# Load the base model on CPU with float16.
# NOTE(review): bnb-4bit checkpoints normally require a CUDA GPU + bitsandbytes;
# loading this checkpoint with device_map="cpu" may fail or silently dequantize —
# verify this runs in the target environment. float16 on CPU is also very slow.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="cpu",  # Load the model on CPU, no GPU
)

# Load the PEFT LoRA model: wraps the base model and applies the adapter weights.
# `model` is the object the rest of this file uses for generation.
model = PeftModel.from_pretrained(base_model, lora_model_id)

# Tokenizer for the model (taken from the base checkpoint; the LoRA repo
# does not need to ship its own tokenizer).
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Function to respond to the user's input.
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
    """Stream a chat reply for the Gradio ChatInterface.

    Yields the accumulated response string after each new token so the UI
    updates incrementally.

    Args:
        message: Latest user message.
        history: Prior (user, assistant) turn pairs from Gradio.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.

    Yields:
        The response text generated so far (grows with each token).
    """
    # Function-scope imports keep this fix self-contained; both modules are
    # already dependencies of this file (transformers) / the stdlib (threading).
    from threading import Thread

    from transformers import TextIteratorStreamer

    # Rebuild the full conversation in chat-template format.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    # BUG FIX: the original called model.chat_completion(...), which does not
    # exist on AutoModelForCausalLM/PeftModel (it is a huggingface_hub
    # InferenceClient API) and raised AttributeError. Local inference instead
    # uses apply_chat_template + generate with a streamer.
    # NOTE(review): Gemma chat templates may reject the "system" role — if
    # apply_chat_template raises, fold system_message into the first user turn.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt=True so only newly generated text is streamed back.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        streamer=streamer,
    )

    # generate() blocks, so run it on a worker thread and consume the
    # streamer on this (generator) side.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()

    response = ""
    for token_text in streamer:
        response += token_text
        yield response
    worker.join()

# Gradio interface setup: a chat UI around `respond`, with extra controls
# for the system prompt and the sampling parameters.
_system_prompt_box = gr.Textbox(
    value="You are a friendly Chatbot.",
    label="System message",
)
_max_tokens_slider = gr.Slider(
    minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
    label="Top-p (nucleus sampling)",
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_prompt_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)

if __name__ == "__main__":
    demo.launch()