Spaces:
Paused
Paused
File size: 4,118 Bytes
2bc99a0 927b5de 2bc99a0 a3c3064 2bc99a0 63a0917 2bc99a0 63a0917 2bc99a0 63a0917 2bc99a0 a3c3064 2bc99a0 8d8e81b 616a905 bd9f9cd 8d8e81b 63a0917 2bc99a0 8d8e81b 2bc99a0 a3c3064 2bc99a0 309e13a a3c3064 f9e3aeb 2bc99a0 a3c3064 2bc99a0 8de5029 1874bf4 2bc99a0 1874bf4 2bc99a0 616a905 2bc99a0 1874bf4 edc6972 927b5de 1874bf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import os
import math
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import sentencepiece
# --- UI copy shown at the top of the Gradio page -----------------------------
title = "Welcome to Tonic's 🐋🐳Orca-2-13B (in 8bit)!"
description = (
    "You can use [🐋🐳microsoft/Orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b) "
    "via API using Gradio by scrolling down and clicking Use 'Via API' or privately by "
    "[cloning this space on huggingface](https://huggingface.co/spaces/Tonic1/TonicsOrca2?duplicate=true) . "
    "[Join my active builders' server on discord](https://discord.gg/VqTxc76K3u). "
    "Big thanks to the HuggingFace Organisation for the Community Grant."
)

# --- Model / tokenizer setup -------------------------------------------------
# NOTE: two memory-related knobs were tried here and are currently disabled:
#   os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'
#   an on-disk offload folder at './model_weights'

# Preferred torch device: first CUDA GPU when one is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model_name = "microsoft/Orca-2-13b"

# The slow (non-fast) tokenizer implementation is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Weights are loaded in 8-bit and placed automatically via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
)
class OrcaChatBot:
    """Multi-turn chat wrapper around a causal language model.

    Prompts are assembled with ``<|im_start|>{role}`` / ``<|im_end|>`` markup:
    one ``system`` turn, the stored ``user`` / ``assistant`` turns, and a
    trailing open ``assistant`` header for the model to complete.

    Attributes:
        model: Object exposing ``generate(...)`` and a ``device`` attribute.
        tokenizer: Callable tokenizer exposing ``decode(...)`` and
            ``eos_token_id``.
        system_message: Text placed in the leading ``system`` turn.
        conversation_history: List of ``(role, message)`` tuples in
            chronological order.
    """

    # Marker that closes a turn; also used to trim anything the model emits
    # after it has finished its own turn.
    _TURN_END = "<|im_end|>"

    def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
        self.model = model
        self.tokenizer = tokenizer
        self.system_message = system_message
        self.conversation_history = []

    def update_conversation_history(self, user_message, assistant_message):
        """Append one user turn followed by one assistant turn to the history."""
        self.conversation_history.append(("user", user_message))
        self.conversation_history.append(("assistant", assistant_message))

    def format_prompt(self, pending_user_message=None):
        """Build the full prompt string for the next generation.

        Args:
            pending_user_message: Optional user text to include as the latest
                turn without storing it in ``conversation_history``.

        Returns:
            The prompt: system turn, every non-blank stored turn, the pending
            user turn (if any), and an open ``assistant`` header.
        """
        # The system message belongs in a "system" turn, not an assistant one.
        parts = [f"<|im_start|>system\n{self.system_message}{self._TURN_END}\n"]

        turns = list(self.conversation_history)
        if pending_user_message is not None:
            turns.append(("user", pending_user_message))

        for role, message in turns:
            # Blank turns carry no content and are left out of the prompt.
            if message.strip():
                parts.append(f"<|im_start|>{role}\n{message}{self._TURN_END}\n")

        # Open header for the reply (no stray space before the role name).
        parts.append("<|im_start|>assistant\n")
        return "".join(parts)

    def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
        """Generate the assistant's reply to ``user_message`` and record the turn.

        Args:
            user_message: The new user text.
            temperature: Sampling temperature forwarded to ``generate``.
            max_new_tokens: Upper bound on generated tokens (coerced to int).
            top_p: Nucleus-sampling threshold forwarded to ``generate``.
            repetition_penalty: Repetition penalty forwarded to ``generate``.

        Returns:
            Only the newly generated text (the prompt is not echoed back).
        """
        prompt = self.format_prompt(pending_user_message=user_message)
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        input_ids = inputs["input_ids"].to(self.model.device)

        generate_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "pad_token_id": self.tokenizer.eos_token_id,
            "do_sample": True,
        }
        # Forward the attention mask when the tokenizer provides one, so that
        # generate() does not have to infer it from the pad token id.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask.to(self.model.device)

        output_ids = self.model.generate(input_ids, **generate_kwargs)

        # The output sequence starts with the prompt tokens; decode only what
        # comes after them so the reply (and the stored history) does not
        # contain the whole prompt again.
        prompt_length = input_ids.shape[1]
        completion_ids = output_ids[0][prompt_length:]
        response = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
        # If the end-of-turn marker survives decoding, drop it and anything
        # generated after it.
        response = response.split(self._TURN_END, 1)[0].strip()

        # Store the exchange once, only after generation succeeded.
        self.update_conversation_history(user_message, response)
        return response
# One chatbot instance is created at import time and shared by every request,
# so its conversation history is shared as well.
Orca_bot = OrcaChatBot(model, tokenizer)


def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
    """Gradio callback: answer ``user_message`` with the shared chatbot.

    Args:
        user_message: Text from the "Your Message" box.
        system_message: Optional text from the system-prompt box. When it is
            non-blank it replaces the bot's system message for this call only,
            rather than being pasted in front of the user's text.
        max_new_tokens: Slider value; coerced to int before use.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty.

    Returns:
        The string returned by ``Orca_bot.predict``.
    """
    default_system_message = Orca_bot.system_message
    if system_message and system_message.strip():
        Orca_bot.system_message = system_message
    try:
        # Keyword arguments keep this call correct even though the UI order
        # (max_new_tokens, temperature, ...) differs from predict()'s order.
        return Orca_bot.predict(
            user_message,
            temperature=temperature,
            max_new_tokens=int(max_new_tokens),
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
    finally:
        # Always restore the default so one request's system prompt does not
        # leak into later requests. NOTE(review): the bot is shared global
        # state; concurrent requests could still interleave here.
        Orca_bot.system_message = default_system_message
# Gradio UI. The input widgets are listed in the same order as the parameters
# of gradio_predict: message, system prompt, then the four sampling controls.
iface = gr.Interface(
    fn=gradio_predict,
    title=title,
    description=description,
    inputs=[
        gr.Textbox(label="Your Message", type="text", lines=3),
        gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", type="text", lines=2),
        gr.Slider(label="Max new tokens", value=420, minimum=25, maximum=2056, step=1),
        gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0, step=0.05),
        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
        gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05),
    ],
    outputs="text",
    theme="ParityError/Anime",
)

# Start the web server (a stray trailing "|" after this call was removed; it
# made the statement a syntax error).
iface.launch()