File size: 2,006 Bytes
a4e4083
48d9d00
4d2b819
48d9d00
b7f8793
a294ce4
 
c3a8689
b7f8793
a294ce4
ba09697
 
a294ce4
 
ba09697
b7f8793
48d9d00
ba09697
4d2b819
a294ce4
48d9d00
4d2b819
a294ce4
48d9d00
 
a294ce4
 
 
 
 
48d9d00
 
a294ce4
 
 
 
 
a4e4083
 
a294ce4
 
 
 
a4e4083
a294ce4
a4e4083
a294ce4
a4e4083
a294ce4
48d9d00
 
a294ce4
a4e4083
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Model identifiers: the base checkpoint and the LoRA adapter applied on top of it.
# NOTE(review): the base model was flagged by the original author as a placeholder;
# swap in the real full-precision, CPU-compatible checkpoint when available.
base_model_id = "unsloth/gemma-2b"
lora_model_id = "Futuresony/CCM-AI"

# Base weights: full float32 precision (chosen for CPU compatibility), placed on CPU.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="cpu",
    torch_dtype=torch.float32,
)

# Wrap the base model with the LoRA adapter weights from the adapter repo.
model = PeftModel.from_pretrained(base_model, lora_model_id)

# Tokenizer (and hence chat template) is taken from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Chat function
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio chat interface.

    Args:
        message: The latest user message.
        history: Prior ``(user, assistant)`` turns supplied by ``gr.ChatInterface``;
            empty entries on either side are skipped.
        system_message: Text sent as the system turn; omitted when empty.
        max_tokens: Maximum number of newly generated tokens (coerced to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The decoded reply text, containing only the newly generated tokens
        (not the prompt). A single value is yielded; this is not token-level
        streaming.
    """
    messages = []
    if system_message:
        # NOTE(review): some chat templates (e.g. Gemma's) reject the "system"
        # role outright -- confirm the template in use accepts it.
        messages.append({"role": "system", "content": system_message})
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant-turn header, so the model
    # answers as the assistant rather than continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to("cpu")
    output_ids = model.generate(
        input_ids,
        # Single, unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # generate() returns prompt + continuation; decode only the continuation so
    # the reply does not echo the whole conversation back to the user.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    yield tokenizer.decode(new_tokens, skip_special_tokens=True)

def _build_demo():
    """Assemble the chat UI around ``respond``.

    The controls in ``additional_inputs`` are passed to ``respond`` after the
    message and history, so their order must match its trailing parameters:
    system_message, max_tokens, temperature, top_p.
    """
    controls = [
        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ]
    return gr.ChatInterface(fn=respond, additional_inputs=controls)


# Gradio UI
demo = _build_demo()

if __name__ == "__main__":
    demo.launch()