import optimum
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

import gc
import torch
import gradio as gr
import json
import os
import shutil
import requests


title = "👋🏻Welcome to Tonic's 💫🌠Starling 7B"
description = "You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) or duplicate it for local use or on Hugging Face! [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",  # user_message
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",  # assistant_message
        "The following dialogue is a conversation",  # system_prompt
        True,  # do_sample
        0.9,  # temperature
        450,  # max_new_tokens
        0.90,  # top_p
        1.9,  # repetition_penalty
    ]
]

model_name = "berkeley-nest/Starling-LM-7B-alpha"

device = "cuda" if torch.cuda.is_available() else "cpu"
temperature=0.4
max_new_tokens=240
top_p=0.92
repetition_penalty=1.7


tokenizer = AutoTokenizer.from_pretrained(model_name)
# 4-bit loading requires the bitsandbytes package; device_map="auto" places the model on the GPU when available.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)
model.eval()

class StarlingBot:
    def __init__(self, system_prompt="The following dialogue is a conversation"):
        self.system_prompt = system_prompt

    def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        input_ids = None
        response = None
        try:
            # Prefer the system prompt supplied through the UI; fall back to the default one.
            active_system_prompt = system_prompt if system_prompt else self.system_prompt
            conversation = f" <s> [INST] {active_system_prompt} [INST]  {assistant_message if assistant_message else ''} </s> [/INST]  {user_message}  </s> "
            input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
            input_ids = input_ids.to(device)
            response = model.generate(
                input_ids=input_ids,
                use_cache=False,
                early_stopping=False,
                bos_token_id=model.config.bos_token_id,
                eos_token_id=model.config.eos_token_id,
                pad_token_id=model.config.eos_token_id,
                temperature=temperature,
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty
            )
            response_text = tokenizer.decode(response[0], skip_special_tokens=True)
            return response_text
        finally:
            # Free GPU memory between requests.
            del input_ids, response
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

starling_bot = StarlingBot()

iface = gr.Interface(
    fn=starling_bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=[
        gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
        gr.Textbox(label="💫🌠Starling Assistant Message or Instructions ", lines=2),
        gr.Textbox(label="💫🌠Starling System Prompt or Instruction", lines=2),
        gr.Checkbox(label="Advanced", value=False),
        gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05),
        gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=256, step=1),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
        gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
    ],
    outputs="text",
    theme="ParityError/Anime"
)
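
# Assumed entry point: serve the interface with default launch options (adjust, e.g. share=True, as needed).
iface.launch()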