Zero-5 / app.py
Staticaliza's picture
Update app.py
0f7e3e6 verified
raw
history blame
1.7 kB
import importlib
import os
import subprocess
import sys
import threading

# Install the BitNet-capable transformers fork BEFORE importing transformers.
# A pip install performed after `from transformers import ...` cannot affect the
# package object already loaded in this process.
# `sys.executable -m pip` targets the interpreter actually running this app,
# rather than whichever `pip` is first on PATH.
_TRANSFORMERS_FORK = "git+https://github.com/shumingma/transformers.git"
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", _TRANSFORMERS_FORK],
    check=False,
)
if _pip_result.returncode != 0:
    # Best effort (the original ignored the exit status): keep going with
    # whatever transformers version is already installed.
    print(
        f"[setup] | pip install of {_TRANSFORMERS_FORK} failed "
        f"(exit code {_pip_result.returncode}); continuing."
    )
# Drop the import system's cached directory listings so a freshly installed
# package is discoverable by the imports below.
importlib.invalidate_caches()

import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import spaces

# Make TorchDynamo fall back to eager execution instead of raising on compile errors.
torch._dynamo.config.suppress_errors = True

model_id = "microsoft/bitnet-b1.58-2B-4T"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Load fully on CPU in float32; the whole model is mapped to the "cpu" device.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},
    trust_remote_code=True,
)
model.to("cpu")
@spaces.GPU(duration=15)
def gpu():
    """Placeholder ``@spaces.GPU`` task that only logs a status line.

    No caller is visible in this file.
    NOTE(review): presumably this exists so the Space registers at least one
    ZeroGPU function; confirm before removing.
    """
    status_line = "[GPU] | GPU maintained."
    print(status_line)
def respond_simple(message: str, max_tokens: int, temperature: float, top_p: float):
    """Generate a sampled completion for ``message`` and return it as one string.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``. This function drains the streamer, waits for the
    thread to finish, and returns the concatenated decoded text.

    Args:
        message: Prompt text to complete.
        max_tokens: Upper bound on newly generated tokens. It is cast to ``int``
            because ``max_new_tokens`` must be an integer and UI widgets may
            deliver a float.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        The text emitted by the streamer, with special tokens skipped.
        ``skip_prompt`` is not set, so the prompt is presumably echoed at the
        start (TextIteratorStreamer default) — confirm if that is intended.

    Raises:
        Exception: any error raised by ``model.generate`` on the worker thread is
            re-raised here, instead of leaving the caller blocked on the streamer.
    """
    inputs = tokenizer(message, return_tensors="pt").to("cpu")
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    generate_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_tokens),
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
    }
    worker_errors = []

    def _generate():
        # Runs generation; on failure, records the error and closes the stream.
        # If generate() raises, it never signals end-of-stream, so without
        # streamer.end() the consumer below would wait forever.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()
    output = "".join(streamer)
    thread.join()
    if worker_errors:
        raise worker_errors[0]
    return output
# Build the single-page completion UI: generation controls, a prompt box, and
# an output box. Pressing Enter in the prompt box runs respond_simple.
with gr.Blocks() as demo:
    gr.Markdown("## bitnet-b1.58-2b-4t completion")
    max_tokens_slider = gr.Slider(
        minimum=1, maximum=8192, value=2048, step=1, label="max new tokens"
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="temperature"
    )
    top_p_slider = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="top-p"
    )
    prompt_box = gr.Textbox(label="prompt", lines=2)
    completion_box = gr.Textbox(label="completion", lines=10)
    # Input order must match respond_simple(message, max_tokens, temperature, top_p).
    prompt_box.submit(
        fn=respond_simple,
        inputs=[prompt_box, max_tokens_slider, temperature_slider, top_p_slider],
        outputs=completion_box,
    )

# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()