# NOTE: removed extraction artifacts that preceded this file
# (a "File size" banner and a git-blame hash/line-number gutter);
# they were not part of the Python source and broke the syntax.
# Hugging Face model id used for both the tokenizer and the causal-LM weights.
model_name = "berkeley-nest/Starling-LM-7B-alpha"
# Markdown shown at the top of the Gradio page.
title = """# 👋🏻Welcome to Tonic's 💫🌠Starling 7B"""
# Markdown blurb rendered under the title (model link + Discord invite).
description = """You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) or duplicate it for local use or on Hugging Face! [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."""
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import gradio as gr
import json
import os
import shutil
import requests
import accelerate
import bitsandbytes
import gc
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# Prefer the GPU when one is visible; generation inputs are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Special-token ids passed to model.generate().
# FIX: the original `bos_token_id = 1,` had a trailing comma, which made it
# the 1-tuple (1,) instead of the int 1.
bos_token_id = 1
eos_token_id = 32000
pad_token_id = 32001

# Default sampling hyper-parameters (the UI sliders below override these
# per request; predict() also declares its own keyword defaults).
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7
# FIX: the allocator config must be set BEFORE anything initializes CUDA;
# the original set it after from_pretrained() had already placed weights on
# the GPU, at which point it has no effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'

# Load tokenizer and weights. device_map="auto" lets accelerate place the
# layers; bfloat16 halves memory versus fp32.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()  # inference only: disables dropout etc.
class StarlingBot:
    """Wrapper around the module-level Starling model/tokenizer.

    Builds an OpenChat-style prompt for either "Assistant" or "Coder" mode
    and returns the decoded generation, freeing GPU memory afterwards.
    """

    def __init__(self, assistant_message="I am Starling-7B by Tonic-AI, I am ready to do anything to help my user."):
        # Default persona preamble. NOTE(review): predict() takes its own
        # assistant_message argument and never reads this attribute.
        self.assistant_message = assistant_message

    def predict(self, user_message, assistant_message, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        """Generate one reply for `user_message`.

        Parameters mirror the Gradio controls. `mode` selects the prompt
        template ("Assistant" vs anything else => "Coder"). Returns the
        decoded response text (includes the prompt, since the full sequence
        is decoded).
        """
        # Pre-bind so the finally block is safe even if encode() raises
        # (the original did `del input_ids` unconditionally, which would
        # raise NameError and mask the real exception).
        input_ids = None
        try:
            if mode == "Assistant":
                conversation = f"GPT4 Correct Assistant: {assistant_message if assistant_message else ''} GPT4 Correct User: {user_message} GPT4 Correct Assistant:"
            else:  # mode == "Coder"
                conversation = f"Code Assistant: {assistant_message if assistant_message else ''} Code User:: {user_message} Code Assistant:"
            input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
            input_ids = input_ids.to(device)
            response = model.generate(
                input_ids=input_ids,
                use_cache=True,
                early_stopping=False,
                bos_token_id=bos_token_id,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                temperature=temperature,
                # FIX: honour the caller's checkbox; the original accepted
                # `do_sample` but hard-coded do_sample=True here.
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty
            )
            response_text = tokenizer.decode(response[0], skip_special_tokens=True)
            return response_text
        finally:
            # Release the prompt tensor and any cached CUDA blocks so long
            # Gradio sessions do not accumulate GPU memory.
            if input_ids is not None:
                del input_ids
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# Canned example inputs for the demo.
# NOTE(review): this list is defined but never passed to any Gradio
# component below, and it lacks the `mode` and `do_sample` fields that
# gradio_starling expects — confirm the intended column order before
# wiring it up via `examples=`.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:", # user_message
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.", # assistant_message
        0.9, # temperature
        450, # max_new_tokens
        0.90, # top_p
        1.9, # repetition_penalty
    ]
]
# Single bot instance shared by every Gradio request.
starling_bot = StarlingBot()

def gradio_starling(user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
    """Gradio callback: forward the UI control values to StarlingBot.predict."""
    return starling_bot.predict(
        user_message,
        assistant_message,
        mode,
        do_sample,
        temperature,
        max_new_tokens,
        top_p,
        repetition_penalty,
    )
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme="ParityError/Anime") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        assistant_message = gr.Textbox(label="Optional💫🌠Starling Assistant Message", lines=2)
        user_message = gr.Textbox(label="Your Message", lines=3)
    with gr.Row():
        mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
        do_sample = gr.Checkbox(label="Advanced", value=True)
    # FIX: `open` must be a plain bool — the original passed a lambda,
    # which Gradio never calls, leaving the accordion state undefined.
    with gr.Accordion("Advanced Settings", open=True):
        with gr.Row():
            temperature = gr.Slider(label="Temperature", value=0.4, minimum=0.05, maximum=1.0, step=0.05)
            max_new_tokens = gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=800, step=1)
            # FIX: top-p is a probability mass and must lie in (0, 1];
            # the original slider (value 3.6, range 1.0–4.0) was invalid
            # and would break nucleus sampling.
            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.92, minimum=0.05, maximum=1.0, step=0.01)
            repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="💫🌠Starling Response")
    submit_button.click(
        gradio_starling,
        inputs=[user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=output_text,
    )
demo.launch()