File size: 2,153 Bytes
2e01190
 
 
6c84060
 
85e8a39
6c84060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9704a98
65872c0
 
6c84060
 
65872c0
 
9704a98
 
bb856a6
b916df8
9704a98
081efe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186c6d4
081efe7
 
 
 
 
 
 
 
186c6d4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import spaces
import gradio as gr
import torch
from transformers import pipeline, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub repository the tokenizer and model weights are loaded from.
MODEL_ID = "marcelbinz/Llama-3.1-Minitaur-8B"

# bitsandbytes 4-bit settings: NF4 quantisation type, bfloat16 compute dtype,
# and double quantisation enabled.
bnb_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Device placement is decided here, once, at load time.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    attn_implementation="flash_attention_2",
    quantization_config=bnb_4bit_config,
)

# The model is handed over already instantiated and dispatched, so no device
# placement argument is repeated on the pipeline itself.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

@spaces.GPU
def infer(prompt):
    """Generate a continuation for ``prompt`` and return it as a string.

    The module-level ``pipe`` is called with sampling enabled
    (``temperature=1.0``) and ``max_new_tokens=1``; ``return_full_text=False``
    is passed so the prompt is not echoed back in the result.

    Args:
        prompt: Text passed to the text-generation pipeline.

    Returns:
        The ``"generated_text"`` field of the first result from the pipeline.
    """
    generations = pipe(
        prompt,
        max_new_tokens=1,
        do_sample=True,
        temperature=1.0,
        return_full_text=False,
    )
    first_result = generations[0]
    return first_result["generated_text"]

# Initial contents of the prompt textbox.
default_experiment = 'text'

# Page layout: info banner and logo on top, then a row holding the editable
# prompt box and the read-only response box, then the "Run" button.
with gr.Blocks(
    fill_width=True,
    css="""
    #prompt-box textarea {height:200px}
    #answer-box textarea {height:320px}
    #info-box {margin-bottom: 1rem}      /* a little spacing */
    """
) as demo:

    # ---------- Info banner ----------
    gr.Markdown(
        """
        ### About this Space
        - **Model:** Llama-3.1 Minitaur-8B quantised to 4-bit  
        - **Speed-up:** Flash-Attention 2 & automatic device-mapping  
        - **Memory:** Fits on a single consumer GPU (~6 GB VRAM)
        """,
        elem_id="info-box",
    )

    # Logo / hero image shown under the banner.
    gr.Image(
        value="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png",
        show_label=False,
        height=80,
        # Intended to centre the image. NOTE(review): the css above defines no
        # `.mx-auto` rule - confirm the theme provides one.
        elem_classes="mx-auto",
    )
    # ---------------------------------

    with gr.Row(equal_height=True):
        # Seed the prompt from the module-level default instead of repeating
        # the literal, so the two cannot drift apart.
        inp = gr.Textbox(
            label="Prompt", elem_id="prompt-box",
            lines=12, max_lines=12, scale=3, value=default_experiment
        )
        outp = gr.Textbox(
            label="Response", elem_id="answer-box",
            lines=1, interactive=False, scale=3
        )

    # Clicking "Run" sends the prompt text to `infer` and shows its return
    # value in the response box.
    run = gr.Button("Run")
    run.click(infer, inp, outp)

demo.queue().launch()