import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# from optimum.bettertransformer import BetterTransformer  # optional optimum speed-up
title = "👋🏻Welcome to Tonic's 💫🌠Starling 7B"
description = "You can chat with [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) here, or duplicate this Space to run it locally or on Hugging Face! [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",  # user_message
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",  # assistant_message
        "The following dialogue is a conversation",  # system_prompt
        True,  # do_sample
        0.9,   # temperature
        450,   # max_new_tokens
        0.90,  # top_p
        1.9,   # repetition_penalty
    ]
]
model_name = "berkeley-nest/Starling-LM-7B-alpha"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fallback sampling defaults; the Gradio sliders below override these per request.
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,  # requires the bitsandbytes package
)
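# Note: recent transformers releases express 4-bit loading through a
# quantization config rather than the bare `load_in_4bit` flag. A minimal
# sketch of that variant (assuming bitsandbytes and accelerate are installed):
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, device_map="auto", quantization_config=bnb_config
#   )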
# model = BetterTransformer.transform(model)
model.eval()
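# BetterTransformer (commented above) is an optional inference speed-up; on
# recent PyTorch/transformers stacks the model already uses fused
# scaled-dot-product attention, so it mostly helps on older versions.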
class StarlingBot:
    def __init__(self, system_prompt="The following dialogue is a conversation"):
        self.system_prompt = system_prompt

    def predict(self, user_message, assistant_message, system_prompt, do_sample,
                temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        # Starling-LM-7B-alpha is trained on the OpenChat prompt format
        # (see the model card), not the Llama-2 [INST] template.
        system = system_prompt if system_prompt else self.system_prompt
        conversation = f"GPT4 Correct User: {system} {user_message}<|end_of_turn|>GPT4 Correct Assistant:"
        if assistant_message:
            # Optionally seed the assistant's reply.
            conversation += f" {assistant_message}"
        input_ids = tokenizer(conversation, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                do_sample=do_sample,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
starling_bot = StarlingBot()
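# Quick local smoke test (hypothetical prompt; assumes a GPU with enough
# memory for the 4-bit checkpoint):
#   print(starling_bot.predict("Say hello in French.", "", "", True, 0.7, 64, 0.9, 1.2))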
iface = gr.Interface(
    fn=starling_bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=[
        gr.Textbox(label="🌟🤩User Message", lines=5),
        gr.Textbox(label="💫🌠Starling Assistant Message or Instructions", lines=2),
        gr.Textbox(label="💫🌠Starling System Prompt or Instruction", lines=2),
        gr.Checkbox(label="Sampling (uncheck for greedy decoding)", value=True),
        gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05),
        gr.Slider(label="Max new tokens", value=256, minimum=25, maximum=1024, step=1),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
        gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05),
    ],
    outputs="text",
    theme="ParityError/Anime",
)

if __name__ == "__main__":
    iface.launch()