# Hugging Face Space: bitnet-b1.58-2B-4T chat demo (bitnet.cpp CPU backend).
# NOTE(review): the original "Spaces: / Paused / Paused" lines were scraped
# page chrome from the Space listing, not Python code — replaced with this
# header so the module parses.
# Stdlib imports first, one per line, then third-party — the trailing " | |"
# tokens on the original lines were extraction artifacts (SyntaxError).
import json
import os
import pathlib
import shlex
import subprocess
import sys
import threading

import gradio as gr
import spaces

from transformers import AutoTokenizer, TextIteratorStreamer

# HF model repo providing the tokenizer / chat template (weights come from
# the separate -gguf repo below).
model_id = "microsoft/bitnet-b1.58-2B-4T"
# Local clone of the bitnet.cpp inference code.
repo_dir = pathlib.Path("BitNet")
# Quantized i2_s GGUF weights consumed by run_inference.py.
gguf_file = pathlib.Path("ggml-model-i2_s.gguf")
# Use every available core; fall back to 8 when the count is unknown.
threads = os.cpu_count() or 8
# One-time setup: clone bitnet.cpp and fetch the quantized weights if absent.
# Both steps are skipped on warm restarts where the files already exist.
if not repo_dir.exists():
    subprocess.run(
        ["git", "clone", "--depth", "1", "--recursive",
         "https://github.com/microsoft/BitNet.git"],
        check=True,
    )
if not gguf_file.exists():
    subprocess.run(
        ["huggingface-cli", "download",
         "microsoft/bitnet-b1.58-2B-4T-gguf",
         "--local-dir", ".",
         "--include", "ggml-model-i2_s.gguf",
         "--repo-type", "model"],
        check=True,
    )
# The tokenizer is only used to render the chat template; actual generation
# happens inside the C++ subprocess.
tokenizer = AutoTokenizer.from_pretrained(model_id)
def bitnet_cpp_generate(prompt, n_predict, temperature):
    """Stream output lines from the bitnet.cpp inference script.

    Args:
        prompt: Fully templated chat prompt forwarded to ``run_inference.py``.
        n_predict: Maximum number of tokens to generate.
        temperature: Sampling temperature.

    Yields:
        str: One line of subprocess stdout at a time, trailing newline removed.
    """
    # BUGFIX: the original built a shell string with json.dumps(prompt) and
    # re-split it with shlex — JSON escaping is not shell escaping, so prompts
    # containing backslashes or quotes could be mangled or break the split.
    # Passing argv as a list (shell=False) delivers the prompt byte-exact and
    # avoids injection entirely. sys.executable replaces a bare "python" so
    # the child runs under the same interpreter.
    cmd = [
        sys.executable, str(repo_dir / "run_inference.py"),
        "-m", str(gguf_file),
        "-p", prompt,
        "-n", str(n_predict),
        "-t", str(threads),
        "-temp", str(temperature),
        "-cnv",
    ]
    # bufsize=1 + text=True gives line-buffered text output for live streaming.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as p:
        for line in p.stdout:
            yield line.rstrip("\n")
def gpu():
    """Print a GPU keep-alive marker; does no actual work (log-only no-op)."""
    print("[GPU] | GPU maintained.")
def respond(msg, hist, sys_msg, max_tokens, temp, _top_p_unused):
    """Gradio chat handler: yield the growing assistant reply as it streams.

    Args:
        msg: Latest user message.
        hist: Prior chat history as (user, assistant) string pairs.
        sys_msg: System prompt placed first in the conversation.
        max_tokens: Maximum tokens to generate.
        temp: Sampling temperature.
        _top_p_unused: Matches the UI slider; deliberately not forwarded.

    Yields:
        str: The accumulated response text after each streamed chunk.
    """
    msgs = [{"role": "system", "content": sys_msg}]
    for user_turn, bot_turn in hist:
        if user_turn:
            msgs.append({"role": "user", "content": user_turn})
        if bot_turn:
            msgs.append({"role": "assistant", "content": bot_turn})
    msgs.append({"role": "user", "content": msg})
    prompt = tokenizer.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True
    )
    # BUGFIX: the previous version pushed *strings* into a
    # TextIteratorStreamer via a worker thread — streamer.put() expects
    # token-id tensors (it calls tokenizer.decode on them), so the first
    # chunk would crash. bitnet_cpp_generate already yields decoded text
    # lazily, so we stream it directly; no thread or streamer needed.
    out = ""
    for chunk in bitnet_cpp_generate(prompt, max_tokens, temp):
        out += chunk
        yield out
# Chat UI. additional_inputs map positionally onto respond()'s trailing
# parameters (sys_msg, max_tokens, temp, _top_p_unused); each example row is
# [message, *additional_inputs]. The top-p slider is shown but never passed
# to the backend — its label says so.
demo = gr.ChatInterface(
    fn=respond,
    title="bitnet-b1.58-2b-4t (cpp)",
    description="fast cpu chat via bitnet.cpp",
    examples=[["hello", "you are helpful", 256, 0.7, 0.95]],
    additional_inputs=[
        gr.Textbox(value="you are helpful", label="system message"),
        gr.Slider(1, 8192, 1024, 1, label="max new tokens"),
        gr.Slider(0.1, 4, 0.7, 0.1, label="temperature"),
        gr.Slider(0.0, 1.0, 0.95, 0.05, label="top-p (ui only, not passed)"),
    ],
)

if __name__ == "__main__":
    demo.launch()