# Hugging Face Spaces status: Running on Zero
import spaces | |
import gradio as gr | |
import torch | |
from transformers import pipeline, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face Hub id of the checkpoint this app serves.
MODEL_ID = "marcelbinz/Llama-3.1-Minitaur-8B"

# bitsandbytes settings: 4-bit NF4 weights with double quantisation and
# bfloat16 as the compute dtype.
bnb_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the causal LM with the quantisation config above. Flash-Attention 2 is
# requested as the attention implementation and device placement is left to
# device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_4bit_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# Text-generation pipeline wrapping the already-loaded model and tokenizer.
pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    device_map="auto",
)
@spaces.GPU
def infer(prompt):
    """Sample a single new token from the model as a continuation of *prompt*.

    The function is wrapped in ``spaces.GPU`` because the Hugging Face ZeroGPU
    documentation requires GPU-dependent functions to carry that decorator so
    a GPU is attached while they run; the same documentation describes the
    decorator as effect-free outside ZeroGPU environments.

    Args:
        prompt: Text handed to the text-generation pipeline.

    Returns:
        The ``generated_text`` field of the first pipeline result. Because
        ``return_full_text=False`` is passed, this is only the newly
        generated text, not the prompt.
    """
    outputs = pipe(
        prompt,
        max_new_tokens=1,  # exactly one token is generated per call
        do_sample=True,
        temperature=1.0,
        return_full_text=False,
    )
    return outputs[0]["generated_text"]
# Initial text shown in the prompt box.
default_experiment = 'text'

# Page CSS: fixed heights for the two text areas and a bottom margin under
# the info banner.
_PAGE_CSS = """
#prompt-box textarea {height:200px}
#answer-box textarea {height:320px}
#info-box {margin-bottom: 1rem} /* a little spacing */
"""

# Markdown body of the "About this Space" banner.
_ABOUT_MARKDOWN = """
### About this Space
- **Model:** Llama-3.1 Minitaur-8B quantised to 4-bit
- **Speed-up:** Flash-Attention 2 & automatic device-mapping
- **Memory:** Fits on a single consumer GPU (~6 GB VRAM)
"""

# Address of the logo displayed under the banner.
_LOGO_URL = "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"

with gr.Blocks(css=_PAGE_CSS, fill_width=True) as demo:
    # Informational banner at the top of the page.
    gr.Markdown(_ABOUT_MARKDOWN, elem_id="info-box")

    # Optional logo / hero image.
    # NOTE(review): "mx-auto" is meant to centre the image, but no rule for
    # that class is defined in the CSS above - confirm it has an effect.
    gr.Image(
        value=_LOGO_URL,
        height=80,
        show_label=False,
        elem_classes="mx-auto",
    )

    # Prompt on the left, model response on the right, equal width.
    with gr.Row(equal_height=True):
        inp = gr.Textbox(
            value=default_experiment,
            label="Prompt",
            elem_id="prompt-box",
            lines=12,
            max_lines=12,
            scale=3,
        )
        outp = gr.Textbox(
            label="Response",
            elem_id="answer-box",
            lines=1,
            interactive=False,
            scale=3,
        )

    # NOTE(review): the button is placed below the row; the original nesting
    # could not be recovered from the source formatting - confirm.
    run = gr.Button("Run")
    # Clicking "Run" feeds the prompt text to infer() and shows its result.
    run.click(fn=infer, inputs=inp, outputs=outp)

# Enable request queuing, then start the app.
demo.queue()
demo.launch()