File size: 2,862 Bytes
2e01190
 
 
1453861
298c3b7
6c84060
475a64e
298c3b7
6c84060
1453861
 
 
 
 
 
 
 
6c84060
 
 
 
 
 
 
 
298c3b7
6c84060
 
 
1453861
6c84060
 
9704a98
2cdcac5
298c3b7
65872c0
 
6c84060
 
65872c0
 
9704a98
 
bb856a6
7c7d1cd
9704a98
d9f0e5b
 
 
48f96e6
d9f0e5b
 
 
 
 
081efe7
 
 
 
1b0edcf
0c9991b
081efe7
601ff72
 
 
48f96e6
601ff72
d9f0e5b
601ff72
 
081efe7
 
f8048eb
 
 
 
 
081efe7
 
 
 
41a278c
 
 
0c9991b
 
41a278c
 
 
186c6d4
 
41a278c
186c6d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import spaces
import gradio as gr
import torch
from transformers import pipeline, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import PeftModel

# Base checkpoint (the repo name suggests a pre-quantized bitsandbytes 4-bit
# Llama 3.1 70B) and the Centaur PEFT adapter that is loaded on top of it.
MODEL_ID = "unsloth/Meta-Llama-3.1-70B-bnb-4bit"
ADAPTER_ID = "marcelbinz/Llama-3.1-Centaur-70B-adapter"

# Start from the checkpoint's own config, then replace its RoPE settings with
# YaRN scaling: factor 4 over an 8192-token original window -> 32768 tokens.
cfg = AutoConfig.from_pretrained(MODEL_ID)
cfg.rope_scaling = dict(
    type="yarn",
    factor=4.0,
    original_max_position_embeddings=8192,
)
cfg.max_position_embeddings = 32768

# bitsandbytes settings: 4-bit NF4 weights with double quantization and a
# bfloat16 compute dtype.
# NOTE(review): MODEL_ID looks like an already-quantized checkpoint; transformers
# may prefer the checkpoint's stored quantization config over this one — confirm.
bnb_4bit_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    load_in_4bit=True,
)

# Tokenizer comes from the same base checkpoint as the model weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Base model built from the RoPE-overridden `cfg` and the 4-bit settings,
# using FlashAttention-2 and automatic device placement.
model_base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    config=cfg,
    quantization_config=bnb_4bit_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# Attach the Centaur adapter weights to the base model.
model = PeftModel.from_pretrained(model_base, ADAPTER_ID, device_map="auto")

# Text-generation pipeline over the adapted model; `infer` calls it per request.
pipe = pipeline(
    "text-generation",
    tokenizer=tokenizer,
    model=model,
    device_map="auto",
)

@spaces.GPU
def infer(prompt):
    """Extend *prompt* by one sampled token using the Centaur pipeline.

    Exactly one new token is generated, with sampling enabled at temperature
    1.0. Because ``return_full_text=True``, the result is the full text (the
    prompt followed by the generated token), so a prompt ending in ``<<``
    comes back with the model's chosen response appended.

    Args:
        prompt: Free-form natural-language experiment text.

    Returns:
        The prompt text with the single generated token appended.
    """
    # NOTE(review): `@spaces.GPU` is presumably the HF Spaces (ZeroGPU) hook
    # that allocates a GPU for the duration of this call — confirm.
    generation_kwargs = dict(
        max_new_tokens=1,
        do_sample=True,
        temperature=1.0,
        return_full_text=True,
    )
    outputs = pipe(prompt, **generation_kwargs)
    # One dict per generated sequence; only a single sequence is requested.
    first_sequence = outputs[0]
    return first_sequence["generated_text"]

# Example prompt pre-filled in the prompt box: an odd-one-out triplet task
# with four completed trials. Human choices are wrapped in "<<" and ">>", and
# the final trial stops right after "<<", so the single token generated by
# `infer` is the model's continuation at the choice position.
default_experiment = """You will be presented with triplets of objects, which will be assigned to the keys H, Y, and E.
In each trial, please indicate which object you think is the odd one out by pressing the corresponding key.
In other words, please choose the object that is the least similar to the other two.

H: plant, Y: chainsaw, and E: periscope. You press <<H>>.
H: tostada, Y: leaf, and E: sail. You press <<H>>.
H: clock, Y: crystal, and E: grate. You press <<Y>>.
H: barbed wire, Y: kale, and E: sweater. You press <<E>>.
H: raccoon, Y: toothbrush, and E: ice. You press <<"""

# Build the Gradio UI: header image, prompting tips, an editable prompt box
# pre-filled with `default_experiment`, and a Run button wired to `infer`.
# Components render in the order they are created inside this block.
with gr.Blocks(
    fill_width=True,
    # Fixed height for the prompt textarea (matches elem_id="prompt-box" below).
    css="""
    #prompt-box textarea {height:256px}
    """,
) as demo:
    # Header image. NOTE(review): "mx-auto" is presumably meant to center it;
    # verify that class is actually styled in this Gradio theme.
    gr.Image(
        value="https://marcelbinz.github.io/imgs/centaur.png",
        show_label=False,
        height=180,
        container=False,
        elem_classes="mx-auto",
    )
    
    # Prompting guidance shown above the prompt box.
    gr.Markdown(
        """
        ### How to prompt:
        - We did not employ a particular prompt template – just phrase everything in natural language.  
        - Human choices are encapsulated by "<<" and ">>" tokens.
        - Most experiments in the training data are framed in terms of button presses. If possible, it is recommended to use that style.
        - You can find examples in the Supporting Information of our paper.
        """,
        elem_id="info-box",
    )

    # Prompt editor. It is both the input and the output of the Run action.
    inp = gr.Textbox(
        label="Prompt",
        elem_id="prompt-box",
        lines=16,
        max_lines=16,
        scale=3,
        value=default_experiment,
    )

    # Run calls infer(prompt) and writes its result (the prompt plus one
    # generated token) back into the same textbox, replacing its contents.
    run = gr.Button("Run")
    run.click(infer, inp, inp)

# Enable request queuing and start the app server.
demo.queue().launch()