gpt-99 committed on
Commit 0c97850 · verified · 1 Parent(s): 0cc1c4f

Delete steering_gradio.py

Files changed (1)
  1. steering_gradio.py +0 -138
steering_gradio.py DELETED
@@ -1,138 +0,0 @@
- import gradio as gr
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from einops import einsum
- from tqdm import tqdm
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_name = 'microsoft/Phi-3-mini-4k-instruct'
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     device_map=device,
-     torch_dtype="auto",
-     trust_remote_code=True,
- )
-
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- def tokenize_instructions(tokenizer, instructions):
-     return tokenizer.apply_chat_template(
-         instructions,
-         padding=True,
-         truncation=False,
-         return_tensors="pt",
-         return_dict=True,
-         add_generation_prompt=True,
-     ).input_ids
-
- def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
-     device = model.device
-     num_its = len(range(0, base_toks.shape[0], batch_size))
-     steering_vecs = {}
-     for i in tqdm(range(0, base_toks.shape[0], batch_size)):
-         base_out = model(base_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
-         target_out = model(target_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
-         # running difference-in-means of the last-token hidden state, one vector per layer
-         for layer in range(len(base_out)):
-             if i == 0:
-                 steering_vecs[layer] = torch.mean(target_out[layer][:,-1,:].detach().cpu() - base_out[layer][:,-1,:].detach().cpu(), dim=0)/num_its
-             else:
-                 steering_vecs[layer] += torch.mean(target_out[layer][:,-1,:].detach().cpu() - base_out[layer][:,-1,:].detach().cpu(), dim=0)/num_its
-     return steering_vecs
-
- def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
-     def modify_activation():
-         def hook(model, input):
-             if normalise:
-                 sv = steering_vec / steering_vec.norm()
-             else:
-                 sv = steering_vec
-             if proj:
-                 # replace the steering vector by the projection of the activations onto it
-                 sv = einsum(input[0], sv.view(-1,1), 'b l h, h s -> b l s') * sv
-             input[0][:,:,:] = input[0][:,:,:] - scale * sv
-         return hook
-
-     handles = []
-     if steering_vec is not None:
-         # hook every decoder layer, or only the single requested one
-         for i in range(len(model.model.layers)):
-             if layer is None or i == layer:
-                 handles.append(model.model.layers[i].register_forward_pre_hook(modify_activation()))
-
-     outs_all = []
-     for i in tqdm(range(0, test_toks.shape[0], batch_size)):
-         outs = model.generate(test_toks[i:i+batch_size], num_beams=4, do_sample=True, max_new_tokens=60)
-         outs_all.append(outs)
-     outs_all = torch.cat(outs_all, dim=0)
-
-     # always remove the hooks so later calls start from a clean model
-     for handle in handles:
-         handle.remove()
-
-     return outs_all
-
- def create_steering_vector(towards, away):
-     towards_data = [[{"role": "user", "content": text.strip()}] for text in towards.split(',')]
-     away_data = [[{"role": "user", "content": text.strip()}] for text in away.split(',')]
-
-     towards_toks = tokenize_instructions(tokenizer, towards_data)
-     away_toks = tokenize_instructions(tokenizer, away_data)
-
-     steering_vecs = find_steering_vecs(model, away_toks, towards_toks)
-     return steering_vecs
-
- def chat(message, history, steering_vec, layer):
-     # gr.Chatbot history is a list of (user, assistant) pairs, not a flat list
-     history_formatted = []
-     for user_msg, assistant_msg in history:
-         history_formatted.append({"role": "user", "content": user_msg})
-         if assistant_msg is not None:
-             history_formatted.append({"role": "assistant", "content": assistant_msg})
-     history_formatted.append({"role": "user", "content": message})
-
-     input_ids = tokenize_instructions(tokenizer, [history_formatted])
-
-     generations_baseline = do_steering(model, input_ids.to(device), None)
-     response_baseline = f"BASELINE: {tokenizer.decode(generations_baseline[0], skip_special_tokens=True)}"
-
-     response = response_baseline
-     if steering_vec is not None:
-         generation_intervene = do_steering(model, input_ids.to(device), steering_vec[layer].to(device), scale=1)
-         response_intervention = f"INTERVENTION: {tokenizer.decode(generation_intervene[0], skip_special_tokens=True)}"
-         response = response_baseline + "\n\n" + response_intervention
-
-     return history + [(message, response)]
-
- def launch_app():
-     with gr.Blocks() as demo:
-         steering_vec = gr.State(None)
-         layer = gr.State(None)
-
-         away_default = ['hate', 'i hate this', 'hating the', 'hater', 'hating', 'hated in']
-
-         towards_default = ['love', 'i love this', 'loving the', 'lover', 'loving', 'loved in']
-
-         with gr.Row():
-             towards = gr.Textbox(label="Towards (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in towards_default))
-             away = gr.Textbox(label="Away from (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in away_default))
-
-         with gr.Row():
-             create_vector = gr.Button("Create Steering Vector")
-             layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers)-1, step=1, label="Layer", value=0)
-
-         status = gr.Textbox(label="Status")
-
-         def create_vector_and_set_layer(towards_text, away_text, layer_value):
-             # return new values for the State components; mutating .value directly
-             # does not update a running session
-             vectors = create_steering_vector(towards_text, away_text)
-             return vectors, int(layer_value), f"Steering vector created for layer {int(layer_value)}"
-
-         create_vector.click(create_vector_and_set_layer, [towards, away, layer_slider], [steering_vec, layer, status])
-
-         chatbot = gr.Chatbot()
-         msg = gr.Textbox()
-
-         msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot)
-
-         demo.launch()
-
- if __name__ == "__main__":
-     launch_app()
-
-
- # clean up
- # nicer baseline vs intervention
- # auto clear after message
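
For context, a minimal sketch of how the deleted script's helpers fit together outside the Gradio UI. The steering_gradio import path, the example prompts, and the layer index are illustrative assumptions, not part of this commit:

# Hypothetical usage of the deleted helpers; assumes the file is restored
# locally as steering_gradio.py so its module-level model/tokenizer load.
from steering_gradio import (model, tokenizer, tokenize_instructions,
                             find_steering_vecs, do_steering)

# Contrastive chat prompts: vectors point from "away" towards "towards".
towards = [[{"role": "user", "content": t}] for t in ["love", "i love this"]]
away = [[{"role": "user", "content": t}] for t in ["hate", "i hate this"]]

towards_toks = tokenize_instructions(tokenizer, towards)
away_toks = tokenize_instructions(tokenizer, away)

# One difference-in-means vector per hidden layer.
steering_vecs = find_steering_vecs(model, away_toks, towards_toks)

prompt = tokenize_instructions(tokenizer, [[{"role": "user", "content": "How was your day?"}]])
layer = 16  # arbitrary mid-depth choice; worth sweeping in practice
out = do_steering(model, prompt.to(model.device), steering_vecs[layer].to(model.device), scale=1)
print(tokenizer.decode(out[0], skip_special_tokens=True))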