# Zero-5 / app.py
# (Hugging Face file-viewer header preserved as comments so the file parses:)
# Staticaliza's picture
# Update app.py
# d46a062 verified
# raw | history blame | 2.46 kB
import subprocess, sys, pathlib, os, shlex, json, threading, gradio as gr, spaces
from transformers import AutoTokenizer, TextIteratorStreamer
# HF repo id used only to load the tokenizer (for its chat template); weights
# themselves are run via bitnet.cpp, not transformers.
model_id = "microsoft/bitnet-b1.58-2B-4T"
# Local checkout of microsoft/BitNet, which provides run_inference.py.
repo_dir = pathlib.Path("BitNet")
# Quantized i2_s GGUF weights file consumed by the inference script.
gguf_file = pathlib.Path("ggml-model-i2_s.gguf")
# Use all available cores; fall back to 8 when cpu_count() is undetermined.
threads = os.cpu_count() or 8
# One-time setup at import: shallow-clone the inference repo if missing...
if not repo_dir.exists():
    subprocess.run(["git","clone","--depth","1","--recursive",
                    "https://github.com/microsoft/BitNet.git"], check=True)
# ...and download only the GGUF weights from the companion -gguf model repo
# into the current directory.
if not gguf_file.exists():
    subprocess.run(["huggingface-cli","download",
                    "microsoft/bitnet-b1.58-2B-4T-gguf",
                    "--local-dir",".",
                    "--include","ggml-model-i2_s.gguf",
                    "--repo-type","model"], check=True)
# Tokenizer is only used for apply_chat_template below; no model is loaded here.
tokenizer = AutoTokenizer.from_pretrained(model_id)
def bitnet_cpp_generate(prompt, n_predict, temperature):
    """Yield output lines from the bitnet.cpp inference script as they are produced.

    Args:
        prompt: Fully rendered chat prompt string (may contain newlines/quotes).
        n_predict: Maximum number of tokens the script should generate.
        temperature: Sampling temperature forwarded to the script.

    Yields:
        str: One line of subprocess stdout at a time, trailing newline stripped.
    """
    # BUG FIX: the original built a shell-style string with json.dumps(prompt) and
    # re-split it with shlex.split(). JSON escape sequences (e.g. "\n", present in
    # every chat-templated prompt) pass through shlex verbatim, so the subprocess
    # received a corrupted prompt. Passing an argv list sidesteps quoting entirely.
    # sys.executable replaces a bare "python", which may not exist on PATH.
    cmd = [
        sys.executable, "BitNet/run_inference.py",
        "-m", str(gguf_file),
        "-p", prompt,
        "-n", str(n_predict),
        "-t", str(threads),
        "-temp", str(temperature),
        "-cnv",
    ]
    # text=True + bufsize=1 gives line-buffered reads, so callers see output
    # as soon as each line is flushed by the inference script.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as proc:
        for line in proc.stdout:
            yield line.rstrip("\n")
@spaces.GPU(duration=15)
def gpu():
    """No-op ZeroGPU task — presumably requested only to keep the Space's GPU
    allocation warm (the actual inference runs on CPU via bitnet.cpp)."""
    print("[GPU] | GPU maintained.")
def respond(msg, hist, sys_msg, max_tokens, temp, _top_p_unused):
    """Stream a chat reply for gradio's ChatInterface.

    Args:
        msg: Latest user message.
        hist: Prior turns as (user, assistant) pairs; empty entries are skipped.
        sys_msg: System prompt inserted as the first message.
        max_tokens: Generation budget forwarded to the inference script.
        temp: Sampling temperature forwarded to the inference script.
        _top_p_unused: Present for UI parity only; deliberately ignored.

    Yields:
        str: The accumulated assistant reply after each new output line.
    """
    msgs = [{"role": "system", "content": sys_msg}]
    for user_turn, bot_turn in hist:
        if user_turn:
            msgs.append({"role": "user", "content": user_turn})
        if bot_turn:
            msgs.append({"role": "assistant", "content": bot_turn})
    msgs.append({"role": "user", "content": msg})
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False,
                                           add_generation_prompt=True)
    # BUG FIX: the original pushed decoded *text* lines into
    # TextIteratorStreamer.put(), which expects token-id tensors (it reads
    # value.shape) and raises AttributeError on a str. The subprocess generator
    # already streams line by line, so consume it directly — no worker thread
    # or streamer needed.
    out = ""
    for line in bitnet_cpp_generate(prompt, max_tokens, temp):
        # Lines arrive with their trailing newline stripped; re-insert a
        # separator so consecutive lines aren't glued together in the reply.
        out = line if not out else out + "\n" + line
        yield out
# Chat UI. ChatInterface supplies the message box and history; the widgets in
# additional_inputs map positionally onto respond()'s trailing parameters
# (sys_msg, max_tokens, temp, _top_p_unused).
demo=gr.ChatInterface(
    fn=respond,
    title="bitnet-b1.58-2b-4t (cpp)",
    description="fast cpu chat via bitnet.cpp",
    # Example row order: [message, system message, max tokens, temperature, top-p].
    examples=[["hello","you are helpful",256,0.7,0.95]],
    additional_inputs=[
        gr.Textbox(value="you are helpful",label="system message"),
        gr.Slider(1,8192,1024,1,label="max new tokens"),
        gr.Slider(0.1,4,0.7,0.1,label="temperature"),
        # Kept for UI symmetry only; respond() ignores this value (see its
        # _top_p_unused parameter).
        gr.Slider(0.0,1.0,0.95,0.05,label="top-p (ui only, not passed)"),
    ],
)
if __name__=="__main__": demo.launch()