File size: 9,142 Bytes
877ea83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from einops import einsum
from tqdm import tqdm

# Pick the GPU whenever torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_name = "microsoft/Phi-3-mini-4k-instruct"

# Load the checkpoint straight onto `device`, letting transformers choose the
# stored dtype; trust_remote_code allows the repo's custom modelling code.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_instructions(tokenizer, instructions):
    """Render conversations through the chat template and return padded token ids.

    Args:
        tokenizer: object exposing ``apply_chat_template``.
        instructions: list of conversations, each a list of
            ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The ``input_ids`` field of the encoded batch (PyTorch tensors, padded
        to a common length, never truncated). The attention mask is dropped.
    """
    encoded = tokenizer.apply_chat_template(
        instructions,
        add_generation_prompt=True,  # end each prompt with the assistant turn header
        return_dict=True,
        return_tensors="pt",
        truncation=False,
        padding=True,
    )
    return encoded.input_ids

def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
    """Compute one steering vector per hidden-state index of ``model``.

    Each vector is ``mean(target last-token state) - mean(base last-token state)``,
    where each mean is taken over *all* rows of ``target_toks`` / ``base_toks``.
    For paired, equal-sized inputs this equals the mean of per-example
    differences. Every example is weighted equally even when the last batch is
    partial, and the two inputs may have different numbers of rows.

    Args:
        model: causal LM callable as ``model(ids, output_hidden_states=True)``
            and exposing ``.device``.
        base_toks: 2-D token-id tensor, one baseline prompt per row.
        target_toks: 2-D token-id tensor, one target prompt per row.
        batch_size: number of prompts per forward pass.

    Returns:
        dict mapping hidden-state index -> CPU tensor of the hidden size, in the
        dtype of the model's hidden states.

    Raises:
        ValueError: if either token batch has no rows.
    """
    if base_toks.shape[0] == 0 or target_toks.shape[0] == 0:
        raise ValueError("find_steering_vecs needs at least one base and one target prompt")
    base_means, dtype = _mean_last_token_states(model, base_toks, batch_size)
    target_means, _ = _mean_last_token_states(model, target_toks, batch_size)
    # Subtract in float32, then return in the hidden-state dtype so callers can
    # combine the vectors with activations of that dtype.
    return {index: (target_means[index] - base_means[index]).to(dtype) for index in base_means}


def _mean_last_token_states(model, toks, batch_size):
    """Return ({index: float32 mean of hidden_states[index][:, -1, :] over all rows}, hidden dtype)."""
    # NOTE(review): position -1 is a real token for every row only if the batch
    # is left-padded, and no attention mask is passed — TODO confirm
    # tokenizer.padding_side.
    device = model.device
    totals = {}
    dtype = None
    # Inference only: parameters require grad by default, so without no_grad
    # every forward pass would also record an unused autograd graph.
    with torch.no_grad():
        for start in tqdm(range(0, toks.shape[0], batch_size)):
            hidden_states = model(toks[start:start + batch_size].to(device), output_hidden_states=True).hidden_states
            dtype = hidden_states[0].dtype
            for index, states in enumerate(hidden_states):
                # Sum in float32 on the CPU so low-precision dtypes keep accuracy.
                batch_sum = states[:, -1, :].float().sum(dim=0).cpu()
                if index in totals:
                    totals[index] += batch_sum
                else:
                    totals[index] = batch_sum
    count = toks.shape[0]
    return {index: total / count for index, total in totals.items()}, dtype

def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
    """Generate from ``test_toks`` while optionally editing decoder-layer inputs.

    If ``steering_vec`` is given, a forward pre-hook is registered on every
    module in ``model.model.layers`` (or only on index ``layer`` when it is not
    None). The hook edits the layer's first positional input ``x`` in place:

    * ``proj=True``: ``x -= scale * (x . sv) * sv``. This subtracts the
      component of ``x`` along ``sv``, which fully removes that direction when
      ``sv`` is unit-length and ``scale == 1``.
    * ``proj=False``: ``x -= scale * sv``.

    ``sv`` is ``steering_vec / ||steering_vec||`` when ``normalise`` is true,
    else ``steering_vec`` itself. The hooks are always removed before
    returning, even if generation raises.

    Args:
        model: causal LM with ``model.model.layers`` and ``generate``.
        test_toks: 2-D token-id tensor of prompts, already on the model device.
        steering_vec: 1-D tensor on the model device, or None for no steering.
        scale: multiplier applied to the subtracted vector/projection.
        normalise: unit-normalise ``steering_vec`` first.
        layer: index of the single layer to hook; None hooks every layer.
        proj: subtract the projection onto ``sv`` instead of ``sv`` itself.
        batch_size: prompts per ``generate`` call.

    Returns:
        Generated ids (prompt + continuation, as returned by ``model.generate``)
        with all batches concatenated along dim 0.
    """
    direction = None
    if steering_vec is not None:
        # Normalise once here instead of on every hooked forward call.
        direction = steering_vec / steering_vec.norm() if normalise else steering_vec

    def hook(module, args):
        # NOTE(review): assumes the hidden states reach the layer as the first
        # positional argument (not as a keyword) — confirm for the model code.
        hidden = args[0]
        if proj:
            # (b, l, h) x (h, 1) -> (b, l, 1) coefficients; times direction -> projection.
            delta = einsum(hidden, direction.view(-1, 1), 'b l h, h s -> b l s') * direction
        else:
            delta = direction
        hidden[:, :, :] = hidden[:, :, :] - scale * delta

    handles = []
    try:
        if direction is not None:
            for index, decoder_layer in enumerate(model.model.layers):
                if layer is None or index == layer:
                    handles.append(decoder_layer.register_forward_pre_hook(hook))

        outputs = []
        for start in tqdm(range(0, test_toks.shape[0], batch_size)):
            batch = test_toks[start:start + batch_size]
            outputs.append(model.generate(batch, num_beams=4, do_sample=True, max_new_tokens=60))
        # NOTE(review): torch.cat requires every batch to come back with the
        # same sequence length.
        return torch.cat(outputs, dim=0)
    finally:
        # Never leave hooks attached: they would silently alter every later call.
        for handle in handles:
            handle.remove()

def create_steering_vector(towards, away):
    """Build per-layer steering vectors from two comma-separated prompt lists.

    Each comma-separated entry becomes a single-turn user conversation.
    Entries that are blank after stripping (e.g. from a trailing comma) are
    skipped. The "away" prompts are the baseline and the "towards" prompts are
    the target passed to ``find_steering_vecs``.

    Args:
        towards: comma-separated prompts to steer towards.
        away: comma-separated prompts to steer away from.

    Returns:
        Whatever ``find_steering_vecs`` returns for these prompts.

    Raises:
        ValueError: if either list contains no non-blank prompt.
    """
    def to_conversations(text):
        parts = (part.strip() for part in text.split(','))
        return [[{"role": "user", "content": part}] for part in parts if part]

    towards_data = to_conversations(towards)
    away_data = to_conversations(away)
    if not towards_data or not away_data:
        raise ValueError("Both 'towards' and 'away' need at least one non-empty comma-separated prompt.")

    towards_toks = tokenize_instructions(tokenizer, towards_data)
    away_toks = tokenize_instructions(tokenizer, away_data)

    return find_steering_vecs(model, away_toks, towards_toks)

def chat(message, history, steering_vec, layer):
    """Answer ``message`` unsteered and (if available) steered, and show both.

    Args:
        message: the new user message.
        history: previous Chatbot turns, accepted either as ``[user, assistant]``
            pairs (the shape this function returns) or as
            ``{"role": ..., "content": ...}`` dicts.
        steering_vec: dict of steering vectors, or None if none was created yet.
        layer: key into ``steering_vec`` selecting the vector to apply.

    Returns:
        A one-item chat history ``[(message, response)]``. ``response`` holds
        the "BASELINE" decode and, when ``steering_vec`` is set, the
        "INTERVENTION" decode below it.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        # Tuple-format history: each entry is one (user, assistant) exchange.
        user_msg, bot_msg = turn
        if user_msg is not None:
            conversation.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenize_instructions(tokenizer, [conversation]).to(device)

    # A single conversation was submitted, so decode its (last) generated row.
    baseline = do_steering(model, input_ids, None)
    response = f"BASELINE: {tokenizer.decode(baseline[-1], skip_special_tokens=True)}"

    # Without a steering vector there is nothing to compare against; show only
    # the baseline rather than failing.
    if steering_vec is not None:
        steered = do_steering(model, input_ids, steering_vec[layer].to(device), scale=1)
        response += "\n\n" + f"INTERVENTION: {tokenizer.decode(steered[-1], skip_special_tokens=True)}"

    return [(message, response)]

def launch_app():
    """Build and launch the Gradio UI for creating steering vectors and chatting.

    The layout has:
    * "Towards" / "Away from" text boxes, pre-filled with the default
      sentences. Commas are removed from each sentence so the comma-separated
      split recovers them.
    * A button that builds the steering vectors, and a layer slider.
    * A status box.
    * A chatbot whose messages are answered by ``chat``.

    The vectors and the chosen layer are kept in ``gr.State`` components.
    They are updated by returning new values from the button handler.
    """
    with gr.Blocks() as demo:
        steering_vec = gr.State(None)
        layer = gr.State(None)

        away_default = [
            "Apples are a popular fruit enjoyed by people around the world.",
            "The apple tree originated in Central Asia and has been cultivated for thousands of years.",
            "There are over 7,500 known cultivars of apples.",
            "Apples are members of the rose family, Rosaceae.",
            # "The science of apple cultivation is called pomology.",
            # "Apple trees typically take 4-5 years to produce their first fruit.",
            # "The phrase 'An apple a day keeps the doctor away' originated in Wales in the 19th century.",
            # "Apples are rich in antioxidants, flavonoids, and dietary fiber.",
            # "The most popular apple variety in the United States is the Gala apple.",
            # "Apple seeds contain a compound called amygdalin, which can release cyanide when digested.",
            # "The apple is the official state fruit of New York.",
            # "Apples can be eaten raw, cooked, or pressed for juice.",
            # "The largest apple ever picked weighed 4 pounds 1 ounce.",
            # "Apples float in water because 25 percent of their volume is air.",
            # "The apple blossom is the state flower of Michigan.",
            # "China is the world's largest producer of apples.",
            # "The average apple tree can produce up to 840 pounds of apples per year.",
            # "Apples ripen six to ten times faster at room temperature than if they are refrigerated.",
            # "The first apple trees in North America were planted by pilgrims in Massachusetts Bay Colony.",
            # "Apples are harvested by hand in orchards."
        ]

        towards_default = [
            "The United States of America is the world's third-largest country by total area.",
            "America declared its independence from Great Britain on July 4, 1776.",
            "The U.S. Constitution, written in 1787, is the oldest written national constitution still in use.",
            "The United States has 50 states and one federal district, Washington D.C.",
            # "America's national motto is 'In God We Trust,' adopted in 1956.",
            # "The bald eagle is the national bird and symbol of the United States.",
            # "The Statue of Liberty, a gift from France, stands in New York Harbor as a symbol of freedom.",
            # "American culture has had a significant influence on global entertainment and technology.",
            # "The United States is home to many diverse ecosystems, from deserts to tropical rainforests.",
            # "America is often referred to as a 'melting pot' due to its diverse immigrant population.",
            # "The U.S. has the world's largest economy by nominal GDP.",
            # "American football, derived from rugby, is the most popular sport in the United States.",
            # "The Grand Canyon, located in Arizona, is one of America's most famous natural landmarks.",
            # "The U.S. sent the first humans to walk on the moon in 1969.",
            # "America's system of government is a federal republic with a presidential system.",
            # "The American flag, known as the Stars and Stripes, has 13 stripes and 50 stars.",
            # "English is the de facto national language, but the U.S. has no official language at the federal level.",
            # "The United States is home to many world-renowned universities, including Harvard and MIT.",
            # "America's national anthem, 'The Star-Spangled Banner,' was written in 1814.",
            # "The U.S. Interstate Highway System, started in 1956, is one of the largest public works projects in history."
        ]

        with gr.Row():
            towards = gr.Textbox(label="Towards (comma-separated)", value= ", ".join(sentence.replace(",", "") for sentence in towards_default))
            away = gr.Textbox(label="Away from (comma-separated)", value= ", ".join(sentence.replace(",", "") for sentence in away_default))

        with gr.Row():
            create_vector = gr.Button("Create Steering Vector")
            layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers)-1, step=1, label="Layer", value=0)

        status = gr.Textbox()

        def create_vector_and_set_layer(towards, away, layer_value):
            vectors = create_steering_vector(towards, away)
            # Gradio updates session state through values returned as
            # outputs. Assigning to ``State.value`` inside a handler only
            # changes the component's default, which all sessions share.
            return vectors, int(layer_value), f"Steering vector created for layer {layer_value}"

        create_vector.click(create_vector_and_set_layer, [towards, away, layer_slider], [steering_vec, layer, status])

        chatbot = gr.Chatbot()
        msg = gr.Textbox()

        msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot)

    demo.launch()

if __name__ == "__main__":
    # Outstanding TODOs:
    #   - clean up
    #   - nicer baseline vs intervention display
    #   - auto clear the message box after sending a message
    launch_app()