Starling / app.py
Tonic's picture
Update app.py
03c59e6
raw
history blame
3.65 kB
import gc
import json
import os
import shutil

import gradio as gr
import optimum
import requests
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
# Page heading and Markdown blurb; both are handed to gr.Interface below
# as its `title` and `description` arguments.
title = "👋🏻Welcome to Tonic's 💫🌠Starling 7B"
description = (
    "You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha)"
    " or duplicate it for local use or on Hugging Face!"
    " [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
)
# Example rows offered by the Gradio interface. Each row must supply one value
# per input component, in the same order as the interface's `inputs` list
# (user message, assistant message, system prompt, checkbox, then the four
# sliders). The original row omitted the system prompt and checkbox entries,
# which shifted every numeric value onto the wrong component.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",  # user_message
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",  # assistant_message
        "The following dialogue is a conversation",  # system_prompt (same text as StarlingBot's default)
        False,  # do_sample checkbox (matches the checkbox's default value)
        0.9,  # temperature
        256,  # max_new_tokens (was 450, which exceeds the slider's maximum of 256)
        0.90,  # top_p
        1.9,  # repetition_penalty
    ]
]
# Hugging Face Hub identifier of the checkpoint loaded below.
model_name = "berkeley-nest/Starling-LM-7B-alpha"

# Target device for input tensors: CUDA when PyTorch can see a GPU, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Module-level generation settings.
# NOTE(review): nothing in the visible code reads these; StarlingBot.predict
# declares its own defaults -- confirm whether they are still needed.
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7
# Fetch the tokenizer that belongs to the `model_name` checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the causal-LM weights with 4-bit loading requested, bfloat16 as the
# torch dtype, and automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Put the module in evaluation mode; this app only runs generation.
model.eval()
class StarlingBot:
    """Adapter between the Gradio form values and a ``model.generate`` call.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` objects.
    """

    def __init__(self, system_prompt="The following dialogue is a conversation"):
        # Fallback system prompt, used whenever a call does not supply its own.
        self.system_prompt = system_prompt

    def _build_prompt(self, user_message, assistant_message, system_prompt=None):
        """Return the prompt string that is tokenized for generation.

        Args:
            user_message: Text placed after the ``[/INST]`` marker.
            assistant_message: Text placed after the second ``[INST]`` marker;
                ``None`` or an empty string is rendered as nothing.
            system_prompt: Per-call system prompt. When ``None`` or empty, the
                instance's ``self.system_prompt`` is used instead.
        """
        effective_system = system_prompt if system_prompt else self.system_prompt
        assistant_part = assistant_message if assistant_message else ''
        return f" <s> [INST] {effective_system} [INST] {assistant_part} </s> [/INST] {user_message} </s> "

    def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        """Generate text for one form submission and return it decoded.

        Args:
            user_message: The user's message.
            assistant_message: Optional assistant message or instructions.
            system_prompt: Optional system prompt; falls back to
                ``self.system_prompt`` when empty. (Previously this argument
                was accepted but never used.)
            do_sample: Accepted so the positional inputs line up with the UI.
                NOTE(review): currently ignored -- sampling is always enabled,
                as in the original code. Confirm the intended meaning of the
                "Advanced" checkbox before wiring it through.
            temperature: Sampling temperature passed to ``model.generate``.
            max_new_tokens: Upper bound on generated tokens; coerced to ``int``
                in case the UI delivers it as a float.
            top_p: Nucleus-sampling threshold passed to ``model.generate``.
            repetition_penalty: Repetition penalty passed to ``model.generate``.

        Returns:
            The first returned sequence decoded with special tokens skipped.
            The whole sequence is decoded, so the output presumably includes
            the prompt text as well as the continuation.
        """
        conversation = self._build_prompt(user_message, assistant_message, system_prompt)

        # Pre-bind the names released in `finally` so that cleanup cannot raise
        # and mask either the result or the real error. The original deleted
        # names that were never assigned, which made every call fail.
        input_ids = None
        response = None
        try:
            input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
            input_ids = input_ids.to(device)
            response = model.generate(
                input_ids=input_ids,
                use_cache=False,
                early_stopping=False,
                bos_token_id=model.config.bos_token_id,
                eos_token_id=model.config.eos_token_id,
                # The EOS id doubles as the padding id.
                pad_token_id=model.config.eos_token_id,
                temperature=temperature,
                do_sample=True,
                max_new_tokens=int(max_new_tokens),
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            return tokenizer.decode(response[0], skip_special_tokens=True)
        finally:
            # Drop tensor references, then ask Python and the CUDA caching
            # allocator to release memory between requests.
            del input_ids, response
            gc.collect()
            torch.cuda.empty_cache()
# Single bot instance whose bound `predict` method backs the interface.
starling_bot = StarlingBot()

# Gradio form. The input components are listed in the same order as the
# parameters of StarlingBot.predict (after `self`), since their values are
# passed positionally.
iface = gr.Interface(
    fn=starling_bot.predict,
    inputs=[
        gr.Textbox(lines=5, type="text", label="🌟🤩User Message"),  # user_message
        gr.Textbox(lines=2, label="💫🌠Starling Assistant Message or Instructions "),  # assistant_message
        gr.Textbox(lines=2, label="💫🌠Starling System Prompt or Instruction"),  # system_prompt
        gr.Checkbox(value=False, label="Advanced"),  # do_sample
        gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.7, label="Temperature"),  # temperature
        gr.Slider(minimum=25, maximum=256, step=1, value=100, label="Max new tokens"),  # max_new_tokens
        gr.Slider(minimum=0.01, maximum=0.99, step=0.05, value=0.90, label="Top-p (nucleus sampling)"),  # top_p
        gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.9, label="Repetition penalty"),  # repetition_penalty
    ],
    outputs="text",
    examples=examples,
    title=title,
    description=description,
    theme="ParityError/Anime",
)