File size: 2,318 Bytes
a6a997a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces

# Hugging Face model id for the chat model served by this app.
model_name = "Sakalti/SakalFusion-7B-Alpha"

# Load the causal LM once at module import. bfloat16 halves memory vs fp32;
# device_map="auto" lets accelerate place the weights on the available GPU(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
# Matching tokenizer; also supplies the chat template used in generate().
tokenizer = AutoTokenizer.from_pretrained(model_name)

@spaces.GPU(duration=100)  # fix: the ZeroGPU decorator is spaces.GPU, not spaces.gpu
def generate(prompt, history, top_p, top_k, max_new_tokens, repetition_penalty, temperature):
    """Generate one assistant reply for *prompt* and append the turn to *history*.

    Args:
        prompt: The user's message for this turn.
        history: Chat history as a list of [user, assistant] pairs.
        top_p, top_k, temperature: Sampling parameters forwarded to generate().
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Penalty factor applied to repeated tokens.

    Returns:
        (response, updated_history) where updated_history is history plus
        the new [prompt, response] pair.

    NOTE(review): *history* is not included in *messages*, so the model only
    sees the current prompt — prior turns do not influence the reply. Confirm
    whether multi-turn context is intended.
    """
    messages = [
        {"role": "system", "content": "あγͺγŸγ―γƒ•γƒ¬γƒ³γƒ‰γƒͺγƒΌγͺγƒγƒ£γƒƒγƒˆγƒœγƒƒγƒˆγ§γ™γ€‚"},
        {"role": "user", "content": prompt}
    ]
    # Render the conversation through the model's chat template and leave the
    # assistant turn open so the model continues from there.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        # fix: without do_sample=True, generate() decodes greedily and silently
        # ignores top_p/top_k/temperature — the UI sliders had no effect.
        do_sample=True,
        # fix: gr.Slider delivers floats; these arguments must be ints.
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        temperature=temperature
    )
    # Strip the prompt tokens so only the newly generated completion remains.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response, history + [[prompt, response]]

# Build the chat UI: a chatbot pane, a message box, a clear button, and a row
# of sampling-parameter sliders wired into generate().
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    with gr.Row():
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
        # fix: step=1 so the integer-valued parameters emit whole numbers.
        top_k = gr.Slider(0, 100, value=50, step=1, label="Top K")
        max_new_tokens = gr.Slider(1, 2048, value=864, step=1, label="Max New Tokens")
        repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="Repetition Penalty")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")

    def respond(message, chat_history, top_p, top_k, max_new_tokens, repetition_penalty, temperature):
        """Submit handler: generate a reply, clear the textbox, refresh the chat."""
        bot_message, chat_history = generate(message, chat_history, top_p, top_k, max_new_tokens, repetition_penalty, temperature)
        # Outputs map to [msg, chatbot, chatbot]: empty the textbox, show history.
        return "", chat_history, chat_history

    msg.submit(respond, [msg, chatbot, top_p, top_k, max_new_tokens, repetition_penalty, temperature], [msg, chatbot, chatbot])
    # fix: msg is a Textbox, so its cleared value must be "" — the original
    # returned [] for both outputs, which is not a valid Textbox value.
    clear.click(lambda: ([], ""), None, [chatbot, msg])

demo.launch(share=True)