gpt-99 committed
Commit 877ea83 · verified · 1 Parent(s): 28a0097

Initial Gradio App

Files changed (2)
  1. requirements.txt +132 -0
  2. steering_gradio.py +180 -0
requirements.txt ADDED
@@ -0,0 +1,132 @@
+ absl-py==2.1.0
+ accelerate==1.0.0
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.4.3
+ aiohttp==3.10.9
+ aiosignal==1.3.1
+ annotated-types==0.7.0
+ anyio==4.6.2.post1
+ appnope==0.1.4
+ asttokens==2.4.1
+ astunparse==1.6.3
+ attrs==24.2.0
+ bitsandbytes==0.42.0
+ certifi==2024.8.30
+ charset-normalizer==3.3.2
+ click==8.1.7
+ comm==0.2.2
+ datasets==3.0.1
+ debugpy==1.8.6
+ decorator==5.1.1
+ diffusers==0.30.3
+ dill==0.3.8
+ einops==0.8.0
+ executing==2.1.0
+ fastapi==0.115.2
+ ffmpy==0.4.0
+ filelock==3.16.1
+ flatbuffers==24.3.25
+ frozenlist==1.4.1
+ fsspec==2024.6.1
+ gast==0.6.0
+ google-pasta==0.2.0
+ gradio==5.1.0
+ gradio_client==1.4.0
+ grpcio==1.67.0
+ h11==0.14.0
+ h5py==3.12.1
+ httpcore==1.0.6
+ httpx==0.27.2
+ huggingface-hub==0.25.1
+ idna==3.10
+ importlib_metadata==8.5.0
+ ipykernel==6.29.5
+ ipython==8.28.0
+ jedi==0.19.1
+ Jinja2==3.1.4
+ jupyter_client==8.6.3
+ jupyter_core==5.7.2
+ keras==3.6.0
+ libclang==18.1.1
+ loralib==0.1.2
+ Markdown==3.7
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ ml-dtypes==0.4.1
+ mlx==0.18.1
+ mlx-lm==0.19.0
+ mpmath==1.3.0
+ multidict==6.1.0
+ multiprocess==0.70.16
+ namex==0.0.8
+ nest-asyncio==1.6.0
+ networkx==3.3
+ numpy==1.26.4
+ opt_einsum==3.4.0
+ optree==0.13.0
+ orjson==3.10.7
+ packaging==24.1
+ pandas==2.2.3
+ parso==0.8.4
+ peft @ git+https://github.com/huggingface/peft.git@a724834ac43b9478b066d3ec8b421489151f3815
+ pexpect==4.9.0
+ pillow==10.4.0
+ platformdirs==4.3.6
+ prompt_toolkit==3.0.48
+ propcache==0.2.0
+ protobuf==4.25.5
+ psutil==6.0.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pyarrow==17.0.0
+ pydantic==2.9.2
+ pydantic_core==2.23.4
+ pydub==0.25.1
+ Pygments==2.18.0
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.12
+ pytz==2024.2
+ PyYAML==6.0.2
+ pyzmq==26.2.0
+ regex==2024.9.11
+ requests==2.32.3
+ rich==13.9.2
+ ruff==0.6.9
+ safetensors==0.4.5
+ scipy==1.14.1
+ semantic-version==2.10.0
+ sentencepiece==0.2.0
+ setuptools==75.1.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ stack-data==0.6.3
+ starlette==0.40.0
+ sympy==1.13.3
+ tensorboard==2.17.1
+ tensorboard-data-server==0.7.2
+ tensorflow==2.17.0
+ termcolor==2.5.0
+ tokenizers==0.20.0
+ tomlkit==0.12.0
+ torch==2.4.1
+ tornado==6.4.1
+ tqdm==4.66.5
+ traitlets==5.14.3
+ transformers @ git+https://github.com/huggingface/transformers.git@698b36da72ae8377fb08ade92b131069898007c2
+ typer==0.12.5
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ uvicorn==0.32.0
+ wcwidth==0.2.13
+ websockets==12.0
+ Werkzeug==3.0.4
+ wheel==0.44.0
+ wrapt==1.16.0
+ xxhash==3.5.0
+ yarl==1.14.0
+ zipp==3.20.2
+
steering_gradio.py ADDED
@@ -0,0 +1,180 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from einops import einsum
+ from tqdm import tqdm
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_name = 'microsoft/Phi-3-mini-4k-instruct'
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     device_map=device,
+     torch_dtype="auto",
+     trust_remote_code=True,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def tokenize_instructions(tokenizer, instructions):
+     return tokenizer.apply_chat_template(
+         instructions,
+         padding=True,
+         truncation=False,
+         return_tensors="pt",
+         return_dict=True,
+         add_generation_prompt=True,
+     ).input_ids
+
+ def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
+     # For each layer, average the difference between target and base hidden
+     # states at the final token position, accumulated across all batches.
+     device = model.device
+     num_its = len(range(0, base_toks.shape[0], batch_size))
+     steering_vecs = {}
+     for i in tqdm(range(0, base_toks.shape[0], batch_size)):
+         base_out = model(base_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
+         target_out = model(target_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
+         for layer in range(len(base_out)):
+             diff = torch.mean(target_out[layer][:, -1, :].detach().cpu() - base_out[layer][:, -1, :].detach().cpu(), dim=0) / num_its
+             if i == 0:
+                 steering_vecs[layer] = diff
+             else:
+                 steering_vecs[layer] += diff
+     return steering_vecs
+
+ def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
+     def modify_activation():
+         # Forward pre-hook: subtract (optionally the projection onto) the
+         # steering direction from the residual stream entering each layer.
+         def hook(model, input):
+             if normalise:
+                 sv = steering_vec / steering_vec.norm()
+             else:
+                 sv = steering_vec
+             if proj:
+                 sv = einsum(input[0], sv.view(-1, 1), 'b l h, h s -> b l s') * sv
+             input[0][:, :, :] = input[0][:, :, :] - scale * sv
+         return hook
+
+     handles = []
+     if steering_vec is not None:
+         for i in range(len(model.model.layers)):
+             if layer is None or i == layer:
+                 handles.append(model.model.layers[i].register_forward_pre_hook(modify_activation()))
+
+     outs_all = []
+     for i in tqdm(range(0, test_toks.shape[0], batch_size)):
+         outs = model.generate(test_toks[i:i+batch_size], num_beams=4, do_sample=True, max_new_tokens=60)
+         outs_all.append(outs)
+     outs_all = torch.cat(outs_all, dim=0)
+
+     for handle in handles:
+         handle.remove()
+
+     return outs_all
+
+ def create_steering_vector(towards, away):
+     towards_data = [[{"role": "user", "content": text.strip()}] for text in towards.split(',')]
+     away_data = [[{"role": "user", "content": text.strip()}] for text in away.split(',')]
+
+     towards_toks = tokenize_instructions(tokenizer, towards_data)
+     away_toks = tokenize_instructions(tokenizer, away_data)
+
+     steering_vecs = find_steering_vecs(model, away_toks, towards_toks)
+     return steering_vecs
+
+ def chat(message, history, steering_vec, layer):
+     # history arrives from gr.Chatbot as a list of (user, assistant) pairs.
+     history_formatted = []
+     for user_msg, assistant_msg in history:
+         history_formatted.append({"role": "user", "content": user_msg})
+         if assistant_msg is not None:
+             history_formatted.append({"role": "assistant", "content": assistant_msg})
+     history_formatted.append({"role": "user", "content": message})
+
+     input_ids = tokenize_instructions(tokenizer, [history_formatted])
+
+     generations_baseline = do_steering(model, input_ids.to(device), None)
+     for j in range(generations_baseline.shape[0]):
+         response_baseline = f"BASELINE: {tokenizer.decode(generations_baseline[j], skip_special_tokens=True)}"
+
+     response = response_baseline
+     if steering_vec is not None:
+         generation_intervene = do_steering(model, input_ids.to(device), steering_vec[layer].to(device), scale=1)
+         for j in range(generation_intervene.shape[0]):
+             response_intervention = f"INTERVENTION: {tokenizer.decode(generation_intervene[j], skip_special_tokens=True)}"
+         response = response_baseline + "\n\n" + response_intervention
+
+     return [(message, response)]
+
+ def launch_app():
+     with gr.Blocks() as demo:
+         steering_vec = gr.State(None)
+         layer = gr.State(None)
+
+         away_default = [
+             "Apples are a popular fruit enjoyed by people around the world.",
+             "The apple tree originated in Central Asia and has been cultivated for thousands of years.",
+             "There are over 7,500 known cultivars of apples.",
+             "Apples are members of the rose family, Rosaceae.",
+             # "The science of apple cultivation is called pomology.",
+             # "Apple trees typically take 4-5 years to produce their first fruit.",
+             # "The phrase 'An apple a day keeps the doctor away' originated in Wales in the 19th century.",
+             # "Apples are rich in antioxidants, flavonoids, and dietary fiber.",
+             # "The most popular apple variety in the United States is the Gala apple.",
+             # "Apple seeds contain a compound called amygdalin, which can release cyanide when digested.",
+             # "The apple is the official state fruit of New York.",
+             # "Apples can be eaten raw, cooked, or pressed for juice.",
+             # "The largest apple ever picked weighed 4 pounds 1 ounce.",
+             # "Apples float in water because 25 percent of their volume is air.",
+             # "The apple blossom is the state flower of Michigan.",
+             # "China is the world's largest producer of apples.",
+             # "The average apple tree can produce up to 840 pounds of apples per year.",
+             # "Apples ripen six to ten times faster at room temperature than if they are refrigerated.",
+             # "The first apple trees in North America were planted by pilgrims in Massachusetts Bay Colony.",
+             # "Apples are harvested by hand in orchards."
+         ]
+
+         towards_default = [
+             "The United States of America is the world's third-largest country by total area.",
+             "America declared its independence from Great Britain on July 4, 1776.",
+             "The U.S. Constitution, written in 1787, is the oldest written national constitution still in use.",
+             "The United States has 50 states and one federal district, Washington D.C.",
+             # "America's national motto is 'In God We Trust,' adopted in 1956.",
+             # "The bald eagle is the national bird and symbol of the United States.",
+             # "The Statue of Liberty, a gift from France, stands in New York Harbor as a symbol of freedom.",
+             # "American culture has had a significant influence on global entertainment and technology.",
+             # "The United States is home to many diverse ecosystems, from deserts to tropical rainforests.",
+             # "America is often referred to as a 'melting pot' due to its diverse immigrant population.",
+             # "The U.S. has the world's largest economy by nominal GDP.",
+             # "American football, derived from rugby, is the most popular sport in the United States.",
+             # "The Grand Canyon, located in Arizona, is one of America's most famous natural landmarks.",
+             # "The U.S. sent the first humans to walk on the moon in 1969.",
+             # "America's system of government is a federal republic with a presidential system.",
+             # "The American flag, known as the Stars and Stripes, has 13 stripes and 50 stars.",
+             # "English is the de facto national language, but the U.S. has no official language at the federal level.",
+             # "The United States is home to many world-renowned universities, including Harvard and MIT.",
+             # "America's national anthem, 'The Star-Spangled Banner,' was written in 1814.",
+             # "The U.S. Interstate Highway System, started in 1956, is one of the largest public works projects in history."
+         ]
+
+         with gr.Row():
+             towards = gr.Textbox(label="Towards (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in towards_default))
+             away = gr.Textbox(label="Away from (comma-separated)", value=", ".join(sentence.replace(",", "") for sentence in away_default))
+
+         with gr.Row():
+             create_vector = gr.Button("Create Steering Vector")
+             layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers)-1, step=1, label="Layer", value=0)
+
+         status = gr.Textbox(label="Status")
+
+         def create_vector_and_set_layer(towards, away, layer_value):
+             # Return the new values so Gradio updates the State components;
+             # assigning to .value inside a callback does not update session state.
+             vectors = create_steering_vector(towards, away)
+             return vectors, int(layer_value), f"Steering vector created for layer {layer_value}"
+
+         create_vector.click(create_vector_and_set_layer, [towards, away, layer_slider], [steering_vec, layer, status])
+
+         chatbot = gr.Chatbot()
+         msg = gr.Textbox()
+
+         msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot)
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     launch_app()
+
+
+ # clean up
+ # nicer baseline vs intervention
+ # auto clear after message