import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login

# Login using HF token from secrets
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError("Missing HF_TOKEN in secrets.")
login(token=hf_token)

# Base and LoRA model paths
base_model_id = "unsloth/gemma-2-9b-bnb-4bit"
lora_model_id = "Futuresony/future_12_10_2024"

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)
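# Note: this unsloth checkpoint is pre-quantized to 4-bit with bitsandbytes,
# so loading it needs the bitsandbytes package and typically a CUDA GPU;
# torch_dtype only applies to the layers that are not quantized.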

# Load LoRA weights
model = PeftModel.from_pretrained(base_model, lora_model_id)
model.eval()

# Chat function: rebuilds the full conversation as a plain-text prompt on every turn
def generate_response(message, history, system_message, max_tokens, temperature, top_p):
    prompt = system_message + "\n\n"
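    # history arrives as (user, assistant) pairs (Gradio's tuple-style format);
    # on Gradio versions that default to "messages"-style dicts, pass
    # type="tuples" to gr.ChatInterface or adapt this loop accordingly.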
    for user_input, bot_response in history:
        prompt += f"User: {user_input}\nAssistant: {bot_response}\n"
    prompt += f"User: {message}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
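    # do_sample=True is what makes temperature/top_p take effect; eos is reused
    # as pad_token_id to avoid the "no pad token set" warning from generate().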
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # generate() returns the prompt plus the continuation; decode only the newly
    # generated tokens rather than splitting on "Assistant:", which breaks if the
    # reply itself contains that string. Then cut off any follow-on "User:" turn.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.split("User:")[0].strip()

# Gradio chat UI; additional_inputs map, in order, to the parameters of
# generate_response after (message, history)
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System Message"),
        gr.Slider(50, 1024, value=256, step=1, label="Max Tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    title="LoRA Chat Assistant (Gemma-2)",
    description="Chat with your fine-tuned Gemma-2 LoRA model"
)
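# On Hugging Face Spaces, launch() needs no extra arguments; for local testing,
# demo.launch(share=True) exposes a temporary public URL.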

if __name__ == "__main__":
    demo.launch()