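"""Gradio demo: activation steering on Phi-3-mini.

Builds per-layer steering vectors as the mean difference between the last-token
hidden states of a "towards" prompt set and an "away" prompt set, then applies
them with forward pre-hooks so baseline and steered generations can be compared
side by side in a chat UI.
"""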
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from einops import einsum
from tqdm import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = 'microsoft/Phi-3-mini-4k-instruct'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_instructions(tokenizer, instructions):
    # Apply the chat template to a batch of conversations and return padded
    # input ids ready for the model.
    return tokenizer.apply_chat_template(
        instructions,
        padding=True,
        truncation=False,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    ).input_ids
def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
    """Compute per-layer steering vectors as the mean difference between the
    last-token hidden states of the target and base prompts."""
    device = model.device
    num_its = len(range(0, base_toks.shape[0], batch_size))
    steering_vecs = {}
    for i in tqdm(range(0, base_toks.shape[0], batch_size)):
        base_out = model(base_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
        target_out = model(target_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
        for layer in range(len(base_out)):
            # Running mean over batches of (target - base) at the final token position.
            diff = torch.mean(target_out[layer][:, -1, :].detach().cpu() - base_out[layer][:, -1, :].detach().cpu(), dim=0) / num_its
            if i == 0:
                steering_vecs[layer] = diff
            else:
                steering_vecs[layer] += diff
    return steering_vecs
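# Optional sketch (not wired into the app): the per-layer norms of the steering
# vectors are one rough heuristic for choosing which layer to intervene on,
# since layers with larger contrast often (though not always) steer more.
def print_steering_norms(steering_vecs):
    for layer, vec in steering_vecs.items():
        print(f"layer {layer}: norm = {vec.norm():.3f}")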
def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
    """Generate from the model while (optionally) steering its activations.

    With proj=True the hook removes the component of each activation along the
    steering direction; otherwise it subtracts the scaled vector itself.
    """
    def modify_activation():
        def hook(model, input):
            if normalise:
                sv = steering_vec / steering_vec.norm()
            else:
                sv = steering_vec
            if proj:
                # Projection of each activation onto the steering direction.
                sv = einsum(input[0], sv.view(-1, 1), 'b l h, h s -> b l s') * sv
            input[0][:, :, :] = input[0][:, :, :] - scale * sv
        return hook

    handles = []
    if steering_vec is not None:
        # Hook every layer unless a specific layer is requested.
        for i in range(len(model.model.layers)):
            if layer is None or i == layer:
                handles.append(model.model.layers[i].register_forward_pre_hook(modify_activation()))

    outs_all = []
    for i in tqdm(range(0, test_toks.shape[0], batch_size)):
        outs = model.generate(test_toks[i:i+batch_size], num_beams=4, do_sample=True, max_new_tokens=60)
        outs_all.append(outs)
    outs_all = torch.cat(outs_all, dim=0)

    for handle in handles:
        handle.remove()
    return outs_all
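# Usage sketch (hypothetical tensors; `steering_vecs` as returned above):
#   out_ids = do_steering(model, input_ids.to(device), steering_vecs[5].to(device),
#                         scale=1, layer=5)
# With the defaults (proj=True, normalise=True, scale=1) this ablates the
# steering direction at layer 5 throughout generation.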
def create_steering_vector(towards, away):
    # Split the comma-separated textbox inputs into single-turn chat prompts.
    towards_data = [[{"role": "user", "content": text.strip()}] for text in towards.split(',')]
    away_data = [[{"role": "user", "content": text.strip()}] for text in away.split(',')]
    towards_toks = tokenize_instructions(tokenizer, towards_data)
    away_toks = tokenize_instructions(tokenizer, away_data)
    steering_vecs = find_steering_vecs(model, away_toks, towards_toks)
    return steering_vecs
def chat(message, history, steering_vec, layer):
    # gr.Chatbot history is a list of (user, assistant) pairs; flatten it into
    # the chat-template message format.
    history_formatted = []
    for user_msg, assistant_msg in history:
        history_formatted.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            history_formatted.append({"role": "assistant", "content": assistant_msg})
    history_formatted.append({"role": "user", "content": message})
    input_ids = tokenize_instructions(tokenizer, [history_formatted])

    generations_baseline = do_steering(model, input_ids.to(device), None)
    response_baseline = f"BASELINE: {tokenizer.decode(generations_baseline[0], skip_special_tokens=True)}"
    response = response_baseline

    if steering_vec is not None:
        generation_intervene = do_steering(model, input_ids.to(device), steering_vec[layer].to(device), scale=1)
        response_intervention = f"INTERVENTION: {tokenizer.decode(generation_intervene[0], skip_special_tokens=True)}"
        response = response_baseline + "\n\n" + response_intervention

    return history + [(message, response)]
def launch_app():
    with gr.Blocks() as demo:
        steering_vec = gr.State(None)
        layer = gr.State(None)
        away_default = [
            "Apples are a popular fruit enjoyed by people around the world.",
            "The apple tree originated in Central Asia and has been cultivated for thousands of years.",
            "There are over 7,500 known cultivars of apples.",
            "Apples are members of the rose family, Rosaceae.",
            # "The science of apple cultivation is called pomology.",
            # "Apple trees typically take 4-5 years to produce their first fruit.",
            # "The phrase 'An apple a day keeps the doctor away' originated in Wales in the 19th century.",
            # "Apples are rich in antioxidants, flavonoids, and dietary fiber.",
            # "The most popular apple variety in the United States is the Gala apple.",
            # "Apple seeds contain a compound called amygdalin, which can release cyanide when digested.",
            # "The apple is the official state fruit of New York.",
            # "Apples can be eaten raw, cooked, or pressed for juice.",
            # "The largest apple ever picked weighed 4 pounds 1 ounce.",
            # "Apples float in water because 25 percent of their volume is air.",
            # "The apple blossom is the state flower of Michigan.",
            # "China is the world's largest producer of apples.",
            # "The average apple tree can produce up to 840 pounds of apples per year.",
            # "Apples ripen six to ten times faster at room temperature than if they are refrigerated.",
            # "The first apple trees in North America were planted by pilgrims in Massachusetts Bay Colony.",
            # "Apples are harvested by hand in orchards."
        ]
        towards_default = [
            "The United States of America is the world's third-largest country by total area.",
            "America declared its independence from Great Britain on July 4, 1776.",
            "The U.S. Constitution, written in 1787, is the oldest written national constitution still in use.",
            "The United States has 50 states and one federal district, Washington D.C.",
            # "America's national motto is 'In God We Trust,' adopted in 1956.",
            # "The bald eagle is the national bird and symbol of the United States.",
            # "The Statue of Liberty, a gift from France, stands in New York Harbor as a symbol of freedom.",
            # "American culture has had a significant influence on global entertainment and technology.",
            # "The United States is home to many diverse ecosystems, from deserts to tropical rainforests.",
            # "America is often referred to as a 'melting pot' due to its diverse immigrant population.",
            # "The U.S. has the world's largest economy by nominal GDP.",
            # "American football, derived from rugby, is the most popular sport in the United States.",
            # "The Grand Canyon, located in Arizona, is one of America's most famous natural landmarks.",
            # "The U.S. sent the first humans to walk on the moon in 1969.",
            # "America's system of government is a federal republic with a presidential system.",
            # "The American flag, known as the Stars and Stripes, has 13 stripes and 50 stars.",
            # "English is the de facto national language, but the U.S. has no official language at the federal level.",
            # "The United States is home to many world-renowned universities, including Harvard and MIT.",
            # "America's national anthem, 'The Star-Spangled Banner,' was written in 1814.",
            # "The U.S. Interstate Highway System, started in 1956, is one of the largest public works projects in history."
        ]
        with gr.Row():
            towards = gr.Textbox(label="Towards (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in towards_default))
            away = gr.Textbox(label="Away from (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in away_default))
        with gr.Row():
            create_vector = gr.Button("Create Steering Vector")
            layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers)-1, step=1, label="Layer", value=0)
        status = gr.Textbox(label="Status")

        def create_vector_and_set_layer(towards, away, layer_value):
            # Return the new values instead of mutating gr.State directly, so
            # the update is applied per session.
            vectors = create_steering_vector(towards, away)
            return vectors, int(layer_value), f"Steering vector created for layer {layer_value}"

        create_vector.click(create_vector_and_set_layer, [towards, away, layer_slider], [steering_vec, layer, status])

        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot)
    demo.launch()
if __name__ == "__main__":
    launch_app()
# TODO: clean up
# TODO: nicer baseline vs intervention display
# TODO: auto-clear the message box after submitting