Centaur / app.py
marcelbinz's picture
Update app.py
2909910 verified
raw
history blame
3.52 kB
import spaces
import gradio as gr
import torch
from transformers import pipeline, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub repository id of the checkpoint served by this app.
MODEL_ID = "marcelbinz/Llama-3.1-Minitaur-8B"

# bitsandbytes settings: 4-bit NF4 weights with double quantization,
# computing in bfloat16.
bnb_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Tokenizer and quantized model are both loaded once, at import time, from the
# same repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    attn_implementation="flash_attention_2",
    quantization_config=bnb_4bit_config,
)

# Text-generation pipeline wrapping the objects above; `infer` calls it.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)
@spaces.GPU
def infer(prompt):
    """Generate a continuation of ``prompt`` and return only the new text.

    The module-level ``pipe`` is called with sampling enabled
    (``temperature=1.0``) and ``max_new_tokens=1``. Because
    ``return_full_text=False`` is passed, the prompt itself is not part of
    the returned string.

    Args:
        prompt: Text to continue.

    Returns:
        The ``"generated_text"`` entry of the first result produced by the
        pipeline.
    """
    generations = pipe(
        prompt,
        max_new_tokens=1,
        do_sample=True,
        temperature=1.0,
        return_full_text=False,
    )
    first_result = generations[0]
    return first_result["generated_text"]
# Default prompt pre-filled in the prompt textbox.
#
# The text must stop exactly at the opening "<<" marker, with no trailing
# newline: the app generates a continuation of whatever is in the textbox, and
# choices are written as "<<choice>>". Closing the literal on its own line
# would append "\n" after "<<" and ask the model to continue after a line break
# instead of filling in the choice.
default_experiment = """You will take part in a Social Prediction Game.
You will observe a Player playing against an Opponent in different games.
In each game, the Player and the Opponent simultaneously choose between option J and option Z.
The Player and the Opponent win points based on their choices.
The rules change between games, and you will be informed about them before each game.
The Player varies between blocks but is consistent across games within a block.
The Opponent switches in each game.
Your task is to predict the choices made by the Player.
You get feedback after each game on whether your prediction was correct or not.
Block 1 starts now.
The rules of the game are as follows:
If Player chooses option J and Opponent chooses option J, then Player wins 10 points and Opponent wins 10 points.
If Player chooses option J and Opponent chooses option Z, then Player wins 3 points and Opponent wins 12 points.
If Player chooses option Z and Opponent chooses option J, then Player wins 12 points and Opponent wins 3 points.
If Player chooses option Z and Opponent chooses option Z, then Player wins 5 points and Opponent wins 5 points.
You predict that Player will choose option <<"""
# Stylesheet passed to gr.Blocks; the selectors match the elem_id values given
# to the two textboxes and the info banner below.
_PAGE_CSS = """
#prompt-box textarea {height:400px}
#answer-box textarea {height:400px}
#info-box {margin-bottom: 1rem} /* a little spacing */
"""

# Markdown shown in the banner above the prompt/response row.
_PROMPTING_NOTES = """
### How to prompt:
- We did not employ a particular prompt template – just phrase everything in natural language.
- Human choices are encapsulated by "<<" and ">>" tokens.
- Most experiments in the training data are framed in terms of button presses. If possible, it is recommended to use that style.
- You can find examples in the Supporting Information of our paper.
"""

with gr.Blocks(fill_width=True, css=_PAGE_CSS) as demo:
    # Logo at the top of the page; the "mx-auto" class is meant to centre it.
    gr.Image(
        value="https://marcelbinz.github.io/imgs/centaur.png",
        show_label=False,
        height=180,
        container=False,
        elem_classes="mx-auto",
    )

    # Usage hints for people writing their own prompts.
    gr.Markdown(_PROMPTING_NOTES, elem_id="info-box")

    # Editable prompt and read-only response, side by side.
    with gr.Row(equal_height=True):
        prompt_box = gr.Textbox(
            label="Prompt",
            elem_id="prompt-box",
            lines=20,
            max_lines=24,
            scale=3,
            value=default_experiment,
        )
        answer_box = gr.Textbox(
            label="Response",
            elem_id="answer-box",
            lines=1,
            interactive=False,
            scale=3,
        )

    # NOTE(review): the button is assumed to sit below the row rather than
    # inside it -- confirm against the intended layout.
    run_button = gr.Button("Run")

    # Clicking the button sends the prompt text through `infer` and writes the
    # returned string into the response box.
    run_button.click(fn=infer, inputs=prompt_box, outputs=answer_box)

demo.queue().launch()