File size: 4,707 Bytes
4df759f
3dc4061
ab0f3d8
0cea479
4df759f
 
 
 
 
 
 
 
 
e7f15bf
 
80af65b
3dc4061
64fadcb
1f03a85
 
 
d75506d
 
 
 
 
fae0e14
46b7e93
 
 
d75506d
 
eb1851a
ffcf3c7
d75506d
eb1851a
03c59e6
8f4fc52
eb1851a
8f4fc52
eb1851a
0cea479
42eab30
 
 
12d816c
42eab30
1f03a85
 
 
42eab30
 
 
 
 
 
 
 
 
664a2c2
618ecb4
664a2c2
 
d992640
46b7e93
 
 
 
 
 
 
 
 
 
eb1851a
72efb02
 
eb1851a
d1ca06d
72efb02
 
 
 
 
 
eb1851a
72efb02
 
8f4fc52
 
72efb02
 
8aaf099
66e8238
 
72efb02
 
 
 
 
 
 
eb1851a
72efb02
 
b01335d
72efb02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Hugging Face Hub repository id of the model this app serves.
model_name = "berkeley-nest/Starling-LM-7B-alpha"

# Markdown rendered at the top of the Gradio page.
title = "# 👋🏻Welcome to Tonic's 💫🌠Starling 7B"
description = (
    f"You can use [💫🌠Starling 7B](https://huggingface.co/{model_name}) "
    "or duplicate it for local use or on Hugging Face! "
    "[Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
)

import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import gradio as gr
import json
import os
import shutil
import requests
import accelerate
import bitsandbytes
import gc

# The CUDA caching allocator reads this variable once, when it initialises on
# the first CUDA allocation. It therefore has to be set before the model is
# loaded onto the GPU; setting it afterwards has no effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'

device = "cuda" if torch.cuda.is_available() else "cpu"

# Special-token ids handed to ``model.generate``. These must be plain ints;
# a stray trailing comma previously turned ``bos_token_id`` into ``(1,)``.
bos_token_id = 1
eos_token_id = 32000
pad_token_id = 32001

# Module-level sampling defaults.
# NOTE(review): nothing visible in this file passes these to generation, and
# the UI section later rebinds the same names to Slider components.
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# bfloat16 weights; device_map="auto" lets accelerate place them on the
# available device(s).
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()  # inference only

class StarlingBot:
    """Wrapper that turns one UI request into a Starling-LM generation call.

    ``predict`` relies on module-level names defined elsewhere in this file:
    ``tokenizer``, ``model``, ``device``, ``bos_token_id``, ``eos_token_id``
    and ``pad_token_id``.
    """

    def __init__(self, assistant_message="I am Starling-7B by Tonic-AI, I am ready to do anything to help my user."):
        # Stored on the instance; ``predict`` uses its own ``assistant_message``
        # argument and does not read this attribute.
        self.assistant_message = assistant_message

    def predict(self, user_message, assistant_message, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        """Generate a reply to ``user_message``.

        Args:
            user_message: Text of the user's turn.
            assistant_message: Optional text placed in an assistant turn before
                the user's turn. Falsy values omit that turn entirely.
            mode: ``"Assistant"`` selects the "GPT4 Correct" template; any other
                value (the UI offers "Coder") selects the "Code" template.
            do_sample: Sample when truthy, otherwise decode greedily.
            temperature: Sampling temperature (only used when sampling).
            max_new_tokens: Maximum number of tokens to generate.
            top_p: Nucleus-sampling mass (only used when sampling).
            repetition_penalty: Penalty applied to already-seen tokens.

        Returns:
            The generated continuation only (the prompt is not echoed),
            decoded without special tokens and stripped of surrounding
            whitespace.
        """
        role = "GPT4 Correct" if mode == "Assistant" else "Code"

        # Starling separates turns with the <|end_of_turn|> token rather than
        # whitespace. NOTE(review): this is presumably the same token as
        # ``eos_token_id`` (32000) - confirm against the tokenizer.
        turns = []
        if assistant_message:
            turns.append(f"{role} Assistant: {assistant_message}<|end_of_turn|>")
        turns.append(f"{role} User: {user_message}<|end_of_turn|>")
        turns.append(f"{role} Assistant:")
        conversation = "".join(turns)

        gen_kwargs = dict(
            use_cache=True,
            early_stopping=False,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            do_sample=bool(do_sample),
            max_new_tokens=int(max_new_tokens),
            repetition_penalty=repetition_penalty,
        )
        if do_sample:
            # temperature/top_p only affect sampling; passing them to greedy
            # decoding merely produces warnings.
            gen_kwargs.update(temperature=temperature, top_p=top_p)

        # Pre-bind so the cleanup in ``finally`` works even if encoding fails
        # (previously ``del input_ids`` raised UnboundLocalError in that case).
        input_ids = None
        output_ids = None
        try:
            input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
            input_ids = input_ids.to(device)
            output_ids = model.generate(input_ids=input_ids, **gen_kwargs)
            # generate() returns prompt + continuation; keep only the new tokens.
            prompt_length = input_ids.shape[-1]
            return tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True).strip()
        finally:
            # Drop tensor references before asking the allocator to release
            # cached memory.
            del input_ids, output_ids
            gc.collect()
            torch.cuda.empty_cache()

# Sample inputs, one row per example. Each row lists values in the same order
# as the parameters of ``gradio_starling`` and the UI ``inputs`` list:
# user_message, assistant_message, mode, do_sample, temperature,
# max_new_tokens, top_p, repetition_penalty. The original rows omitted
# ``mode`` and ``do_sample`` and so did not line up with those inputs.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",  # user_message
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",  # assistant_message
        "Assistant",  # mode
        True,  # do_sample
        0.9,  # temperature
        450,  # max_new_tokens
        0.90,  # top_p
        1.9,  # repetition_penalty
    ]
]

# One shared bot instance, reused by every request handled by the callback.
starling_bot = StarlingBot()


def gradio_starling(user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
    """Gradio click handler: forward the UI values to ``starling_bot.predict``.

    Returns whatever ``predict`` returns, unchanged.
    """
    return starling_bot.predict(
        user_message,
        assistant_message,
        mode,
        do_sample,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

# Gradio UI: message inputs, mode/sampling controls, generation settings,
# and a single output box wired to ``gradio_starling``.
with gr.Blocks(theme="ParityError/Anime") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        assistant_message = gr.Textbox(label="Optional💫🌠Starling Assistant Message", lines=2)
        user_message = gr.Textbox(label="Your Message", lines=3)
    with gr.Row():
        mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
        # Passed to the callback as ``do_sample``; also toggles the settings
        # accordion below.
        do_sample = gr.Checkbox(label="Advanced", value=True)
    # ``open`` expects a bool. The previous lambda was always truthy and was
    # never re-evaluated. The accordion now starts open to match the checkbox
    # default, and the ``.change`` handler below keeps the two in sync.
    with gr.Accordion("Advanced Settings", open=True) as advanced_settings:
        with gr.Row():
            temperature = gr.Slider(label="Temperature", value=0.4, minimum=0.05, maximum=1.0, step=0.05)
            max_new_tokens = gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=800, step=1)
            # Nucleus sampling needs 0 < top_p <= 1. The previous range of
            # 1.0-4.0 (default 3.6) made transformers reject every request.
            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.92, minimum=0.05, maximum=1.0, step=0.01)
            repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)

    do_sample.change(
        lambda checked: gr.update(open=checked),
        inputs=do_sample,
        outputs=advanced_settings,
    )

    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="💫🌠Starling Response")

    submit_button.click(
        gradio_starling,
        inputs=[user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=output_text
    )

demo.launch()