Art3B-chat / app.py
freeCS-dot-org's picture
Update app.py
bacf4cd verified
raw
history blame
6.11 kB
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread
# Optional Hugging Face access token taken from the environment (None when
# the variable is unset). NOTE(review): not referenced anywhere else in the
# visible part of this file.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Hub repository id handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
MODEL = "AGI-0/Art-v0-3B"
# HTML heading rendered at the top of the page through gr.HTML(TITLE).
TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>"""
# HTML passed as the `placeholder` of the gr.Chatbot component.
PLACEHOLDER = """
<center>
<p>Hi! How can I help you today?</p>
</center>
"""
# Custom stylesheet passed to gr.Blocks(css=...): styles elements carrying the
# "duplicate-button" class and centers <h3> headings.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
class ConversationManager:
    """Keep two parallel transcripts of the same conversation.

    ``model_history`` stores ``(user_message, raw_response)`` pairs, i.e. the
    assistant text exactly as generated (original tags included) so it can be
    fed back to the model. ``user_history`` stores
    ``(user_message, formatted_response)`` pairs, i.e. the markdown/HTML
    version meant for display.
    """

    def __init__(self):
        # Display-side transcript (formatted responses).
        self.user_history = []
        # Model-side transcript (unmodified responses).
        self.model_history = []

    def add_exchange(self, user_message, assistant_response, formatted_response):
        """Record one user turn together with both forms of the reply."""
        raw_turn = (user_message, assistant_response)
        display_turn = (user_message, formatted_response)
        self.model_history.append(raw_turn)
        self.user_history.append(display_turn)

    def get_model_history(self):
        """Return the live list of ``(user_message, raw_response)`` pairs."""
        return self.model_history

    def get_user_history(self):
        """Return the live list of ``(user_message, formatted_response)`` pairs."""
        return self.user_history
# Single module-level conversation store; stream_chat reads the model-side
# history from it and appends each finished exchange to it.
# NOTE(review): being module-level, this object is shared by every caller of
# stream_chat in this process.
conversation_manager = ConversationManager()
# NOTE(review): `device` is not referenced anywhere else in the visible file;
# placement is driven by device_map below and by model.device in stream_chat.
device = "cuda" # for GPU usage or "cpu" for CPU usage
# Tokenizer and causal-LM weights are both loaded from the MODEL repo id at
# import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
MODEL,
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Token id of "<|im_end|>"; stream_chat passes it to generate() as eos_token_id.
end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
def format_response(response):
    """Format a raw model response for display to the user.

    If the text contains the ``<|end_reasoning|>`` marker, everything before
    the first marker is wrapped in a collapsible ``<details>`` element and
    everything after it is shown as the visible answer. Text without the
    marker is returned unchanged.

    Args:
        response: Raw (possibly partial, while streaming) model output.

    Returns:
        The display string.
    """
    marker = "<|end_reasoning|>"
    # partition() splits on the FIRST marker only, so text following any
    # later marker is kept in `rest` instead of being silently discarded
    # (which is what split(...)[1] did).
    reasoning, found, rest = response.partition(marker)
    if not found:
        return response
    return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.2,
    max_new_tokens: int = 4096,
    top_p: float = 1.0,
    top_k: int = 1,
    penalty: float = 1.1,
):
    """Generate a reply to ``message`` and stream it back incrementally.

    The prompt is built from the model-side transcript kept in the
    module-level ``conversation_manager`` (raw responses, original tags
    intact), not from ``history``, whose assistant turns hold the
    display-formatted text.

    Args:
        message: The new user message.
        history: Chat history supplied by the UI. Only its emptiness is
            inspected: an empty history means the visible chat was just
            started or cleared, so the stored transcript is reset too.
        system_prompt: Optional system message; ignored when blank.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        penalty: Repetition penalty forwarded to ``generate``.

    Yields:
        The formatted response accumulated so far (see ``format_response``),
        once per chunk produced by the streamer.
    """
    # NOTE(review): conversation_manager is module-level state, so separate
    # browser sessions served by this process share one transcript. Resetting
    # on an empty UI history at least keeps a cleared/new chat from silently
    # replaying stale context.
    if not history:
        conversation_manager.model_history.clear()
        conversation_manager.user_history.clear()

    conversation = []
    # The system prompt used to be accepted but never sent to the model.
    if system_prompt and system_prompt.strip():
        conversation.append({"role": "system", "content": system_prompt})
    for prompt, answer in conversation_manager.get_model_history():
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # UI sliders may deliver whole numbers as int (e.g. 2 instead of 2.0);
    # coerce so downstream float/int type checks do not reject them.
    temperature = float(temperature)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature > 0,
        top_p=float(top_p),
        top_k=int(top_k),
        temperature=temperature,
        repetition_penalty=float(penalty),
        eos_token_id=[end_of_sentence],
        streamer=streamer,
    )

    # Generation runs in a worker thread while this generator drains the
    # streamer. No torch.no_grad() wrapper here: grad mode is thread-local,
    # so wrapping thread.start() never affected the worker.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    raw_response = ""
    formatted_buffer = ""
    for new_text in streamer:
        raw_response += new_text
        formatted_buffer = format_response(raw_response)
        yield formatted_buffer

    # Record the exchange exactly once, after the stream is exhausted.
    # Polling thread.is_alive() inside the loop was racy: the thread could
    # still be alive on the last chunk (exchange never saved) or already
    # finished while chunks were still queued (partial text saved repeatedly).
    thread.join()
    conversation_manager.add_exchange(
        message,
        raw_response,      # original text, fed back to the model
        formatted_buffer,  # formatted text, for display
    )
# Chat display component, created up front so it can be handed to ChatInterface.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

# Page layout: title, "duplicate space" button, then the chat interface whose
# additional inputs map positionally onto stream_chat's parameters after
# (message, history): system_prompt, temperature, max_new_tokens, top_p,
# top_k, penalty.
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button"
    )
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(
            label="⚙️ Parameters",
            open=False,
            render=False
        ),
        additional_inputs=[
            gr.Textbox(
                value="",
                label="System Prompt",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.2,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=4096,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=1,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                # The repetition penalty must be strictly positive; a minimum
                # of 0.0 let users pick a value that generate() rejects.
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.1,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()