import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# from optimum.bettertransformer import BetterTransformer  # optional inference speed-up
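# Gradio chat demo for berkeley-nest/Starling-LM-7B-alpha, loaded in 4-bit via bitsandbytes.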


title = "👋🏻Welcome to Tonic's 💫🌠Starling 7B"
description = "You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) here, or duplicate this Space to run it locally or on Hugging Face! [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
# Each example row supplies one value per Interface input, in the same order.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",  # user_message
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",  # assistant_message
        "The following dialogue is a conversation",  # system_prompt
        True,  # do_sample
        0.9,  # temperature
        250,  # max_new_tokens (kept within the slider's 25-256 range)
        0.90,  # top_p
        1.9,  # repetition_penalty
    ]
]

model_name = "berkeley-nest/Starling-LM-7B-alpha"
# base_model = "meta-llama/Llama-2-7b-chat-hf"


device = "cuda" if torch.cuda.is_available() else "cpu"

# Fallback generation defaults; the Gradio sliders below supply the values actually used.
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7


tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the weights in 4-bit (requires the bitsandbytes package) and compute in bfloat16;
# passing load_in_4bit directly to from_pretrained is deprecated in recent transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
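# 4-bit weights keep this 7B model at roughly 4 GB of GPU memory, plus activation overhead.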
# model = BetterTransformer.transform(model)
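# eval() disables dropout and other training-only behaviour for inference.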
model.eval()

class StarlingBot:
    def __init__(self, system_prompt="The following dialogue is a conversation"):
        self.system_prompt = system_prompt

    def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        # Starling-LM-7B-alpha expects the OpenChat prompt format from its model card:
        # "GPT4 Correct User: ...<|end_of_turn|>GPT4 Correct Assistant: ..."
        system = system_prompt if system_prompt else self.system_prompt
        conversation = f"{system} "
        if assistant_message:
            conversation += f"GPT4 Correct Assistant: {assistant_message}<|end_of_turn|>"
        conversation += f"GPT4 Correct User: {user_message}<|end_of_turn|>GPT4 Correct Assistant:"
        input_ids = tokenizer.encode(conversation, return_tensors="pt").to(device)
        output = model.generate(
            input_ids=input_ids,
            do_sample=do_sample,  # honour the UI checkbox instead of hard-coding True
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)

starling_bot = StarlingBot()
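# Gradio calls this shared instance's predict method once per request.
# Hypothetical direct call, bypassing the UI (argument order matches predict's signature):
#   starling_bot.predict("Say hello in French.", "", "", True, 0.7, 64, 0.9, 1.3)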

iface = gr.Interface(
    fn=starling_bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=[
        gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
        gr.Textbox(label="💫🌠Starling Assistant Message or Instructions", lines=2),
        gr.Textbox(label="💫🌠Starling System Prompt or Instruction", lines=2),
        gr.Checkbox(label="Do sample", value=True),
        gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05),
        gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=256, step=1),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
        gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05),
    ],
    outputs="text",
    theme="ParityError/Anime",
)

# Without launch() the script defines the UI but never serves it.
iface.launch()