File size: 3,626 Bytes
4df759f
3dc4061
d75506d
9530952
d75506d
 
 
 
 
 
 
 
 
 
3dc4061
4df759f
 
 
 
 
 
 
 
 
e7f15bf
 
3dc4061
 
d75506d
 
 
 
 
fae0e14
58fe6bc
e4a1a3c
7cbb200
58fe6bc
e4a1a3c
d75506d
 
e4a1a3c
03c59e6
d75506d
03c59e6
 
42eab30
 
 
 
 
 
 
 
03c59e6
42eab30
 
 
 
 
 
 
 
 
 
664a2c2
 
 
 
d992640
ea3b3e9
b01335d
 
ea3b3e9
b01335d
 
657cd12
ea3b3e9
d75506d
 
ea3b3e9
d75506d
 
 
 
657cd12
d75506d
e4a1a3c
ea3b3e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
model_name = "berkeley-nest/Starling-LM-7B-alpha"

# Heading and blurb rendered at the top of the Gradio page.
title = "👋🏻Welcome to Tonic's 💫🌠Starling 7B"
description = (
    "You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) "
    "or duplicate it for local use or on Hugging Face! "
    "[Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
)

# Sample inputs; each row is laid out as:
#   user_message, assistant_message, temperature, max_new_tokens, top_p, repetition_penalty
# NOTE(review): this list is not passed to the gr.Interface built later in this file.
# Its rows also carry six values, while that Interface declares eight inputs (the system
# prompt and the checkbox are missing). And 450 exceeds the "Max new tokens" slider
# maximum of 256. Reconcile before wiring it up.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",
        0.9,
        450,
        0.90,
        1.9,
    ],
]


import gc
import json
import os
import shutil

import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import gradio as gr
import requests
import accelerate
import bitsandbytes

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Module-level generation defaults.
# NOTE(review): StarlingBot.predict declares its own parameters with these same names,
# so nothing visible here reads these values. Confirm before relying on them.
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7

# Load the tokenizer and the causal-LM weights once at import time. The weights are
# placed on `device`, and the dtype is taken from the checkpoint ("auto").
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map=device,
)
# Inference only: switch off training-mode layers such as dropout.
model.eval()

class StarlingBot:
    """Wrap the module-level Starling ``model``/``tokenizer`` for use as a Gradio callback.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` globals.
    """

    def __init__(self, system_prompt="I am Starling-7B by Tonic-AI, I ready to do anything to help my user."):
        # Default system prompt, used whenever predict() receives an empty one.
        self.system_prompt = system_prompt

    def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        """Build a prompt from the UI inputs, run generation, and return only the new text.

        Args:
            user_message: The user's turn, inserted at the end of the prompt.
            assistant_message: Optional assistant text/instructions placed before the
                user turn; an empty value is replaced by ''.
            system_prompt: System prompt for this call. Falls back to
                ``self.system_prompt`` when empty or None.
            do_sample: Accepted because the UI's "Advanced" checkbox is bound to it, but
                currently ignored: sampling is always enabled.
                NOTE(review): decide whether this flag should switch to greedy decoding.
            temperature: Sampling temperature forwarded to ``model.generate``.
            max_new_tokens: Upper bound on generated tokens; coerced to int.
            top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
            repetition_penalty: Repetition penalty forwarded to ``model.generate``.

        Returns:
            The decoded continuation only. Prompt tokens and special tokens are stripped.
        """
        active_system_prompt = system_prompt or self.system_prompt
        # NOTE(review): the Starling-LM-7B-alpha model card documents a
        # "GPT4 Correct User: ...<|end_of_turn|>GPT4 Correct Assistant:" template.
        # This [INST]-style layout is kept unchanged from the original; verify output quality.
        conversation = f" <s> [INST] {active_system_prompt} [INST]  {assistant_message if assistant_message else ''} </s> [/INST]  {user_message}  </s> "

        # Bind every tensor name up front so the cleanup in `finally` can never hit an
        # unbound name, even when tokenization or generation raises.
        encoded = input_ids = attention_mask = output_ids = None
        try:
            encoded = tokenizer(conversation, return_tensors="pt", add_special_tokens=False)
            input_ids = encoded["input_ids"].to(device)
            # Pass the mask explicitly: pad_token_id == eos_token_id below, so generate()
            # cannot infer padding from the ids on its own.
            attention_mask = encoded["attention_mask"].to(device)
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                # Reuse the KV cache instead of re-running the whole sequence per token.
                use_cache=True,
                early_stopping=False,
                bos_token_id=model.config.bos_token_id,
                eos_token_id=model.config.eos_token_id,
                pad_token_id=model.config.eos_token_id,
                temperature=temperature,
                do_sample=True,
                # Gradio sliders may deliver floats; generate() expects an int here.
                max_new_tokens=int(max_new_tokens),
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            # generate() returns prompt + continuation; keep only the continuation.
            new_tokens = output_ids[0, input_ids.shape[-1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)
        finally:
            # Drop tensor references before collecting so GPU memory can be released.
            del encoded, input_ids, attention_mask, output_ids
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

# A single bot instance serves every request coming from the UI.
starling_bot = StarlingBot()

# The input components are handed to predict positionally. Their order must match its
# parameters after `self`: user_message, assistant_message, system_prompt, do_sample
# (the "Advanced" checkbox), temperature, max_new_tokens, top_p, repetition_penalty.
iface = gr.Interface(
    fn=starling_bot.predict,
    title=title,
    description=description,
    inputs=[
        gr.Textbox(lines=5, type="text", label="🌟🤩User Message"),
        gr.Textbox(lines=2, label="💫🌠Starling Assistant Message or Instructions "),
        gr.Textbox(lines=2, label="💫🌠Starling System Prompt or Instruction"),
        gr.Checkbox(value=False, label="Advanced"),
        gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.7, label="Temperature"),
        gr.Slider(minimum=25, maximum=256, step=1, value=100, label="Max new tokens"),
        gr.Slider(minimum=0.01, maximum=0.99, step=0.05, value=0.90, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.9, label="Repetition penalty"),
    ],
    outputs="text",
    theme="ParityError/Anime",
)