Ahren09 committed (verified)
Commit de4a5dc · Parent(s): cd63fc2

Upload demo.py

Files changed (1): demo.py (+11, −198)
demo.py CHANGED
@@ -1,208 +1,21 @@
  import json
+ import os
  import os.path as osp

- import gradio as gr
- import numpy as np
- import torch
+ from tqdm import tqdm

- from llava.mm_utils import get_model_name_from_path
- from llava.model.builder import load_pretrained_model
- from llava_utils import prompt_wrapper, generator
- from utils import normalize, denormalize, load_image
-
- # BASE_DIR = "/workingdir/soh337/llavaguard"
- BASE_DIR = "/Users/ahren/Workspace/Multimodal/llavaguard"
- IMAGE_PATH = 'unconstrained_attack_images/adversarial_'
- MODEL_PATH = "/workingdir/models_hf/liuhaotian/llava-v1.5-13b"
-
- TEXT_SAFETY_PATCHES = {
-     "optimized": "text_patch_optimized",
-     "heuristic": "text_patch_heuristic"
- }
-
- IMAGE_SAFETY_PATCHES = {
-     "default": "safety_patch.bmp"
- }
-
-
- def rtp_read(text_file):
-     dataset = []
-     lines = open(text_file).read().split("\n")
-     for li in lines:
-         obj = json.loads(li)
-         if obj['challenging']:
-             dataset.append(obj['prompt']['text'])
-     return dataset
-
-
- model = loaded_model_name = None
-
- print('>>> Initializing Models')
-
- prompts = rtp_read(osp.join(BASE_DIR, 'harmful_corpus/rtp_prompts.jsonl'))
-
- out = []
-
-
- def generate_answer(user_message: str, image_path: str, requested_model_name: str,
-                     image_safety_patch_type: str, text_safety_patch_type: str,
-                     ):
-     global loaded_model_name
-
-     text_safety_patch = TEXT_SAFETY_PATCHES[text_safety_patch_type]
-     image_safety_patch = IMAGE_SAFETY_PATCHES[image_safety_patch_type]
-     if requested_model_name == "LLaVA":
-         if requested_model_name == loaded_model_name:
-             print(f"{requested_model_name} model already loaded.")
-         else:
-             print(f"Loading {requested_model_name} model ... ")
-             model_name = get_model_name_from_path(MODEL_PATH)
-             tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_PATH, None,
-                                                                                    model_name)
-             loaded_model_name = requested_model_name
-     my_generator = generator.Generator(model=model, tokenizer=tokenizer)
-
-     # load a randomly-sampled unconstrained attack image as an Image object
-     image = load_image(image_path)
-     # transform the image with the visual encoder (CLIP) of LLaVA 1.5; the processed image is a PyTorch tensor of shape (336, 336)
-     image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()
-
-     if image_safety_patch != None:
-         # scale the image pixel values into (0, 1)
-         image = normalize(image)
-         # load the safety patch tensor, whose values lie in (0, 1)
-         safety_patch = torch.load(image_safety_patch).cuda()
-         # apply the safety patch to the input image, clamp it to (0, 1), and denormalize it back to the original pixel values
-         safe_image = denormalize((image + safety_patch).clamp(0, 1))
-         # make sure the image values stay within (0, 1)
-         print(torch.min(image), torch.max(image), torch.min(safe_image), torch.max(safe_image))
-     else:
-         safe_image = image
-
-     model.eval()
-
-     if text_safety_patch != None:
-         # use the line below for the optimized text safety patch
-         # user_message = text_safety_patch + '\n' + user_message
-         # use the line below for the heuristic text safety patch
-         user_message += '\n' + text_safety_patch
-
-     text_prompt_template = prompt_wrapper.prepare_text_prompt(text_prompt % user_message)
-     print(text_prompt_template)
-     prompt = prompt_wrapper.Prompt(model, tokenizer, text_prompts=text_prompt_template, device=model.device)
-
-     response = my_generator.generate(prompt, safe_image).replace("[INST]", "").replace("[/INST]", "").replace(
-         "[SYS]", "").replace("[/SYS/]", "").strip()
-     if text_safety_patch != None:
-         response = response.replace(text_safety_patch, "")
-
-     print(" -- continuation: ---")
-     print(response)
-     out.append({'prompt': user_message, 'continuation': response})
-
-
- def get_list_of_examples():
-     global rtp
-     examples = []
-     for i, prompt in enumerate(prompts[:3]):  # use the first 3 prompts for simplicity
-         image_num = np.random.randint(25)  # randomly select an image number
-         image_path = f'{IMAGE_PATH}{image_num}.bmp'
-         examples.append(
-             [image_path, prompt]
-         )
-     return examples
-
-
- css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;}
- #header {text-align: center;}
- #col-chatbox {flex: 1; max-height: min(750px, 100%);}
- #label {font-size: 2em; padding: 0.5em; margin: 0;}
- .message {font-size: 1.2em;}
- .message-wrap {max-height: min(700px, 100vh);}
- """
-
-
- def get_empty_state():
-     # TODO: Not sure what this means
-     return gr.State({"arena": None})
-
-
- examples = get_list_of_examples()
-
-
- # Define a function to update inputs based on the selected example
- def update_inputs(example_id):
-     selected_example = examples[int(example_id)]
-     return selected_example['image_path'], selected_example['text']
-
-
- model_selector, image_patch_selector, text_patch_selector = None, None, None
-
-
- def process_text_and_image(user_message: str, image_path: str):
-     global model_selector, image_patch_selector, text_patch_selector
-     print(f"User Message: {user_message}")
-     # print(f"Text Safety Patch: {safety_patch}")
-     print(f"Image Path: {image_path}")
-     print(model_selector.value)
-
-     # generate_answer(user_message, image_path, "LLaVA", "heuristic", "default")
-     generate_answer(user_message, image_path, model_selector.value, image_patch_selector.value,
-                     text_patch_selector.value)
-
-
- with gr.Blocks(css=css) as demo:
-     state = get_empty_state()
-     all_components = []
-
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(
-             """# 🦙LLaVAGuard🔥<br>
-             Safeguarding your Multimodal LLM
-             **[Project Homepage](#)**""",
-             elem_id="header",
-         )
-
-         # example_selector = gr.Dropdown(choices=[f"Example {i}" for i, e in enumerate(examples)],
-         #                                label="Select an Example")
-
-         with gr.Row():
-             model_selector = gr.Dropdown(choices=["LLaVA"], label="Model", info="Select Model", value="LLaVA")
-             image_patch_selector = gr.Dropdown(choices=["default"], label="Image Patch",
-                                                info="Select Image Safety Patch", value="default")
-             text_patch_selector = gr.Dropdown(choices=["heuristic", "optimized"], label="Text Patch",
-                                               info="Select Text Safety Patch", value="heuristic")
-
-         image_and_text_uploader = gr.Interface(
-             fn=process_text_and_image,
-             inputs=[gr.Image(type="pil", label="Upload your image", interactive=True),
-                     gr.Textbox(placeholder="Input a question", label="Your Question"),
-                     ],
-             examples=examples,
-             outputs=['text'])
-
-
- # # Set the action for the generate button
- # @demo.events(generate_button)
- # def handle_generation(image, question, model, image_patch, text_patch):
- #     generate_answer(question, image, model, text_patch, image_patch)
-
- # Launch the demo
- demo.launch()
+
+
+ if __name__ == "__main__":
+     ROOT = osp.expanduser("~/Workspace/data/Multimodal")
+
+     # construct_esnli_training_data()
+     # construct_vqax_training_data()
+     # construct_aokvqa_dataset()
+
+     examples = json.load(open('playground/data/instructions_explainable_dataset.json'))
+
+     for line in tqdm(examples):
+         image_path = f"/workingdir/yjin328/data/{line['image']}"
+         assert osp.exists(image_path), image_path
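The heart of the removed demo is the patch application inside `generate_answer`, which is hard to read in diff form. Below is a minimal sketch of that logic, not the committed code: `normalize` and `denormalize` are stand-ins for the helpers imported from the repo's `utils` module (assumed here to map between the CLIP processor's pixel range and (0, 1)), and the patch path is illustrative.

```python
import torch

# Stand-ins for utils.normalize / utils.denormalize (assumption: they map
# between the CLIP processor's pixel range and (0, 1)).
def normalize(x: torch.Tensor) -> torch.Tensor:
    return (x - x.min()) / (x.max() - x.min() + 1e-8)

def denormalize(x: torch.Tensor) -> torch.Tensor:
    return x  # placeholder; the real helper inverts the CLIP normalization

def apply_image_safety_patch(image: torch.Tensor, patch_path: str) -> torch.Tensor:
    """Additively apply a (0, 1)-valued safety patch to a preprocessed image."""
    image = normalize(image)               # bring pixel values into (0, 1)
    safety_patch = torch.load(patch_path)  # saved tensor with values in (0, 1)
    # Add the patch, clamp to the valid range, and map back to pixel space.
    return denormalize((image + safety_patch).clamp(0, 1))

def apply_text_safety_patch(user_message: str, patch: str, optimized: bool = False) -> str:
    """Prepend the optimized text patch, or append the heuristic one."""
    return f"{patch}\n{user_message}" if optimized else f"{user_message}\n{patch}"
```

Because the image patch is purely additive and followed by a clamp, the perturbed input stays a valid image, which is why the removed code printed the min/max of both tensors as a sanity check.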