Centaur / app.py
marcelbinz's picture
Update app.py
41a278c verified
raw
history blame
3.1 kB
import spaces
import gradio as gr
import torch
from transformers import pipeline, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Hub identifiers: the 4-bit base checkpoint and the adapter loaded on top of it.
MODEL_ID = "unsloth/Meta-Llama-3.1-70B-bnb-4bit"
ADAPTER_ID = "marcelbinz/Llama-3.1-Centaur-70B-adapter"

# bitsandbytes settings: 4-bit NF4 weights with double quantization, with
# computation carried out in bfloat16.
bnb_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# The tokenizer comes from the base checkpoint (not from the adapter repo).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the quantized base model, then wrap it with the adapter weights.
# device_map="auto" is passed at every step, as in the original setup.
model_base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_4bit_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
model = PeftModel.from_pretrained(model_base, ADAPTER_ID, device_map="auto")

# Text-generation pipeline built around the adapter-wrapped model; `infer`
# calls this object for every request.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)
@spaces.GPU
def infer(prompt):
    """Return ``prompt`` with one freshly sampled token appended.

    The module-level ``pipe`` is asked for exactly one new token
    (``max_new_tokens=1``) using sampling at temperature 1.0. Because
    ``return_full_text=False`` the pipeline yields only the continuation,
    which is concatenated onto the unmodified prompt so the caller gets the
    full text back.

    Args:
        prompt: The text to continue.

    Returns:
        The original prompt followed by the generated continuation.
    """
    generations = pipe(
        prompt,
        max_new_tokens=1,
        do_sample=True,
        temperature=1.0,
        return_full_text=False,
    )
    # The pipeline returns a list of candidates; only the first is used.
    continuation = generations[0]["generated_text"]
    return prompt + continuation
# Default prompt pre-filled in the prompt textbox: a two-player social
# prediction game in which the model predicts the Player's choice.
#
# The text must end directly on the opening "<<" marker, with no trailing
# newline: `infer` appends a single sampled token to the end of the prompt,
# and choices are written between "<<" and ">>". A newline after "<<" would
# condition the model on "<<\n" and put the predicted choice on a new line
# instead of immediately after the marker.
default_experiment = """You will take part in a Social Prediction Game.
You will observe a Player playing against an Opponent.
The Player and the Opponent simultaneously choose between option J and option Z.
Both parties win points based on their choices.
Your task is to predict the choices made by the Player.
The rules of the game are as follows:
If Player chooses option J and Opponent chooses option J, then Player wins 10 points and Opponent wins 10 points.
If Player chooses option J and Opponent chooses option Z, then Player wins 3 points and Opponent wins 12 points.
If Player chooses option Z and Opponent chooses option J, then Player wins 12 points and Opponent wins 3 points.
If Player chooses option Z and Opponent chooses option Z, then Player wins 5 points and Opponent wins 5 points.
You predict that Player will choose option <<"""
# Single-page UI: logo, prompting guidance, an editable prompt box and a Run
# button. The custom CSS targets the two elem_ids assigned below.
with gr.Blocks(
    fill_width=True,
    css="""
#prompt-box textarea {height:400px}
#info-box {margin-bottom: 1rem} /* a little spacing */
""",
) as demo:
    # Logo / hero image at the top of the page.
    gr.Image(
        value="https://marcelbinz.github.io/imgs/centaur.png",
        show_label=False,
        height=180,
        container=False,
        elem_classes="mx-auto",  # intended to centre the image
    )

    # Short guidance on how prompts should be written.
    gr.Markdown(
        """
### How to prompt:
- We did not employ a particular prompt template – just phrase everything in natural language.
- Human choices are encapsulated by "<<" and ">>" tokens.
- Most experiments in the training data are framed in terms of button presses. If possible, it is recommended to use that style.
- You can find examples in the Supporting Information of our paper.
""",
        elem_id="info-box",
    )

    # Prompt editor, pre-filled with the default experiment text.
    inp = gr.Textbox(
        label="Prompt",
        elem_id="prompt-box",
        lines=20,
        max_lines=24,
        scale=3,
        value=default_experiment,
    )

    run = gr.Button("Run")
    # The textbox is both input and output: each click replaces its content
    # with whatever `infer` returns for the current text.
    run.click(fn=infer, inputs=inp, outputs=inp)

# Enable request queuing and start the app.
demo.queue().launch()