LukasHug commited on
Commit
e570d84
·
verified ·
1 Parent(s): 972192d

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/image1.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/image3.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: LlavaGuard
3
- emoji: 🔥
4
- colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 5.29.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: LlavaGuard
3
+ emoji: 👁
4
+ colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 5.29.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ import os
7
+ import sys
8
+ import time
9
+
10
+ import gradio as gr
11
+ import torch
12
+ from PIL import Image
13
+ from transformers import (
14
+ AutoProcessor,
15
+ AutoTokenizer,
16
+ Qwen2_5_VLForConditionalGeneration,
17
+ LlavaOnevisionForConditionalGeneration
18
+ )
19
+ from qwen_vl_utils import process_vision_info
20
+
21
+ from taxonomy import policy_v1
22
+
23
+ # Set up logging
24
+ logging.basicConfig(
25
+ level=logging.INFO,
26
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
27
+ handlers=[
28
+ logging.FileHandler("gradio_web_server.log"),
29
+ logging.StreamHandler()
30
+ ]
31
+ )
32
+ logger = logging.getLogger("gradio_web_server")
33
+
34
+ # Constants
35
+ LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
36
+ os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
37
+
38
+ default_taxonomy = policy_v1
39
+
40
+
41
+ class SimpleConversation:
42
+ def __init__(self):
43
+ self.current_prompt = ""
44
+ self.current_image = None
45
+ self.current_response = None
46
+ self.skip_next = False
47
+ self.messages = [] # Add messages list to store conversation history
48
+
49
+ def set_prompt(self, prompt, image=None):
50
+ self.current_prompt = prompt
51
+ self.current_image = image
52
+ self.current_response = None
53
+ # Update messages when setting a new prompt
54
+ self.messages = [[prompt, None]]
55
+
56
+ def set_response(self, response):
57
+ self.current_response = response
58
+ # Update the last message's response when setting a response
59
+ if self.messages and len(self.messages) > 0:
60
+ self.messages[-1][-1] = response
61
+
62
+ def get_prompt(self):
63
+ if isinstance(self.current_prompt, tuple):
64
+ return self.current_prompt[0]
65
+ return self.current_prompt
66
+
67
+ def get_image(self, return_pil=False):
68
+ if self.current_image:
69
+ return [self.current_image]
70
+ if isinstance(self.current_prompt, tuple) and len(self.current_prompt) > 1:
71
+ if isinstance(self.current_prompt[1], Image.Image):
72
+ return [self.current_prompt[1]]
73
+ return None
74
+
75
+ def to_gradio_chatbot(self):
76
+ if not self.messages:
77
+ return []
78
+
79
+ ret = []
80
+ for msg in self.messages:
81
+ prompt = msg[0]
82
+ if isinstance(prompt, tuple) and len(prompt) > 0:
83
+ prompt = prompt[0]
84
+
85
+ if prompt and isinstance(prompt, str) and "<image>" in prompt:
86
+ prompt = prompt.replace("<image>", "")
87
+
88
+ ret.append([prompt, msg[1]])
89
+ return ret
90
+
91
+ def dict(self):
92
+ # Simplified serialization for logging
93
+ image_info = "[WITH_IMAGE]" if self.current_image is not None else "[NO_IMAGE]"
94
+ return {
95
+ "prompt": self.get_prompt(),
96
+ "image": image_info,
97
+ "response": self.current_response,
98
+ "messages": [[m[0], "[RESPONSE]" if m[1] else None] for m in self.messages]
99
+ }
100
+
101
+ def copy(self):
102
+ new_conv = SimpleConversation()
103
+ new_conv.current_prompt = self.current_prompt
104
+ new_conv.current_image = self.current_image
105
+ new_conv.current_response = self.current_response
106
+ new_conv.skip_next = self.skip_next
107
+ new_conv.messages = self.messages.copy() if self.messages else []
108
+ return new_conv
109
+
110
+ default_conversation = SimpleConversation()
111
+
112
+ # Model and processor storage
113
+ tokenizer = None
114
+ model = None
115
+ processor = None
116
+ context_len = 8048
117
+
118
+ def wrap_taxonomy(text):
119
+ """Wraps user input with taxonomy if not already present"""
120
+ if policy_v1 not in text:
121
+ return policy_v1 + "\n\n" + text
122
+ return text
123
+
124
+ # UI component states
125
+ no_change_btn = gr.Button()
126
+ enable_btn = gr.Button(interactive=True)
127
+ disable_btn = gr.Button(interactive=False)
128
+
129
+ # Model loading function
130
+ def load_model(model_path):
131
+ global tokenizer, model, processor, context_len
132
+
133
+ logger.info(f"Loading model: {model_path}")
134
+
135
+ try:
136
+ # Check if it's a Qwen model
137
+ if "qwenguard" in model_path.lower():
138
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
139
+ model_path,
140
+ # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
141
+ torch_dtype="auto",
142
+ device_map="auto" if torch.cuda.is_available() else None
143
+ )
144
+ processor = AutoProcessor.from_pretrained(model_path)
145
+ tokenizer = processor.tokenizer
146
+
147
+ # Otherwise assume it's a LlavaGuard model
148
+ else:
149
+ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
150
+ model_path,
151
+ # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
152
+ torch_dtype="auto",
153
+ device_map="auto" if torch.cuda.is_available() else None,
154
+ trust_remote_code=True
155
+ )
156
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
157
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
158
+
159
+ context_len = getattr(model.config, "max_position_embeddings", 8048)
160
+ logger.info(f"Model {model_path} loaded successfully")
161
+ return True
162
+
163
+ except Exception as e:
164
+ logger.error(f"Error loading model {model_path}: {str(e)}")
165
+ return False
166
+
167
+ def get_model_list():
168
+ models = [
169
+ 'AIML-TUDA/QwenGuard-v1.2-3B',
170
+ 'AIML-TUDA/QwenGuard-v1.2-7B',
171
+ 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
172
+ 'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
173
+ ]
174
+ return models
175
+
176
+ def get_conv_log_filename():
177
+ t = datetime.datetime.now()
178
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
179
+ os.makedirs(os.path.dirname(name), exist_ok=True)
180
+ return name
181
+
182
+ # Inference function
183
+ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
184
+ global model, tokenizer, processor
185
+
186
+ if model is None or processor is None:
187
+ return "Model not loaded. Please select a model first."
188
+ try:
189
+ # Check if it's a Qwen model
190
+ if isinstance(model, Qwen2_5_VLForConditionalGeneration):
191
+ # Format for Qwen models
192
+ messages = [
193
+ {
194
+ "role": "user",
195
+ "content": [
196
+ {"type": "image", "image": image},
197
+ {"type": "text", "text": prompt}
198
+ ]
199
+ }
200
+ ]
201
+ # Process input
202
+ text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
203
+ image_inputs, video_inputs = process_vision_info(messages)
204
+ inputs = processor(
205
+ text=[text_prompt],
206
+ images=image_inputs,
207
+ videos=video_inputs,
208
+ padding=True,
209
+ return_tensors="pt",
210
+ )
211
+ inputs = inputs.to("cuda")
212
+
213
+
214
+ # Otherwise assume it's a LlavaGuard model
215
+ else:
216
+ conversation = [
217
+ {
218
+ "role": "user",
219
+ "content": [
220
+ {"type": "image"},
221
+ {"type": "text", "text": prompt},
222
+ ],
223
+ },
224
+ ]
225
+ text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
226
+ inputs = processor(text=text_prompt, images=image, return_tensors="pt")
227
+
228
+ inputs = {k: v.to('cuda') for k, v in inputs.items()}
229
+
230
+ with torch.no_grad():
231
+ generated_ids = model.generate(
232
+ **inputs,
233
+ do_sample=temperature > 0,
234
+ temperature=temperature,
235
+ top_p=top_p,
236
+ max_new_tokens=max_tokens,
237
+ )
238
+
239
+ # Decode
240
+ generated_ids_trimmed = generated_ids[0, inputs["input_ids"].shape[1]:]
241
+ response = processor.decode(
242
+ generated_ids_trimmed,
243
+ skip_special_tokens=True,
244
+ # clean_up_tokenization_spaces=False
245
+ )
246
+ print(response)
247
+
248
+ return response.strip()
249
+
250
+ except Exception as e:
251
+ import traceback
252
+ error_msg = f"Error during inference: {str(e)}\n{traceback.format_exc()}"
253
+ print(error_msg)
254
+ logger.error(error_msg)
255
+ return f"Error processing image. Please try again."
256
+
257
+ # Gradio UI functions
258
+ get_window_url_params = """
259
+ function() {
260
+ const params = new URLSearchParams(window.location.search);
261
+ url_params = Object.fromEntries(params);
262
+ console.log(url_params);
263
+ return url_params;
264
+ }
265
+ """
266
+
267
+ def load_demo(url_params, request: gr.Request):
268
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
269
+ models = get_model_list()
270
+
271
+ dropdown_update = gr.Dropdown(visible=True)
272
+ if "model" in url_params:
273
+ model = url_params["model"]
274
+ if model in models:
275
+ dropdown_update = gr.Dropdown(value=model, visible=True)
276
+ load_model(model)
277
+
278
+ state = default_conversation.copy()
279
+ return state, dropdown_update
280
+
281
+ def load_demo_refresh_model_list(request: gr.Request):
282
+ logger.info(f"load_demo. ip: {request.client.host}")
283
+ models = get_model_list()
284
+ state = default_conversation.copy()
285
+ dropdown_update = gr.Dropdown(
286
+ choices=models,
287
+ value=models[0] if len(models) > 0 else ""
288
+ )
289
+ return state, dropdown_update
290
+
291
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
292
+ with open(get_conv_log_filename(), "a") as fout:
293
+ data = {
294
+ "tstamp": round(time.time(), 4),
295
+ "type": vote_type,
296
+ "model": model_selector,
297
+ "state": state.dict(),
298
+ "ip": request.client.host,
299
+ }
300
+ fout.write(json.dumps(data) + "\n")
301
+
302
+ def upvote_last_response(state, model_selector, request: gr.Request):
303
+ logger.info(f"upvote. ip: {request.client.host}")
304
+ vote_last_response(state, "upvote", model_selector, request)
305
+ return ("",) + (disable_btn,) * 3
306
+
307
+ def downvote_last_response(state, model_selector, request: gr.Request):
308
+ logger.info(f"downvote. ip: {request.client.host}")
309
+ vote_last_response(state, "downvote", model_selector, request)
310
+ return ("",) + (disable_btn,) * 3
311
+
312
+ def flag_last_response(state, model_selector, request: gr.Request):
313
+ logger.info(f"flag. ip: {request.client.host}")
314
+ vote_last_response(state, "flag", model_selector, request)
315
+ return ("",) + (disable_btn,) * 3
316
+
317
+ def regenerate(state, image_process_mode, request: gr.Request):
318
+ logger.info(f"regenerate. ip: {request.client.host}")
319
+ if state.messages and len(state.messages) > 0:
320
+ state.messages[-1][-1] = None
321
+ if len(state.messages) > 1:
322
+ prev_human_msg = state.messages[-2]
323
+ if isinstance(prev_human_msg[0], tuple) and len(prev_human_msg[0]) >= 2:
324
+ # Handle image process mode for previous message if it's a tuple with image
325
+ new_msg = list(prev_human_msg)
326
+ if len(prev_human_msg[0]) >= 3:
327
+ new_msg[0] = (prev_human_msg[0][0], prev_human_msg[0][1], image_process_mode)
328
+ state.messages[-2] = new_msg
329
+
330
+ state.skip_next = False
331
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
332
+
333
+ def clear_history(request: gr.Request):
334
+ logger.info(f"clear_history. ip: {request.client.host}")
335
+ state = default_conversation.copy()
336
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
337
+
338
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
339
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
340
+ if len(text) <= 0 or image is None:
341
+ state.skip_next = True
342
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
343
+
344
+ text = wrap_taxonomy(text)
345
+
346
+ # Reset conversation for new image-based query
347
+ if image is not None:
348
+ state = default_conversation.copy()
349
+
350
+ # Set new prompt with image
351
+ prompt = text
352
+ if image is not None:
353
+ prompt = (text, image, image_process_mode)
354
+
355
+ state.set_prompt(prompt=prompt, image=image)
356
+ state.skip_next = False
357
+
358
+ return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
359
+
360
+ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
361
+ start_tstamp = time.time()
362
+
363
+ if state.skip_next:
364
+ # This generate call is skipped due to invalid inputs
365
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
366
+ return
367
+
368
+ # Get the prompt and images
369
+ prompt = state.get_prompt()
370
+ all_images = state.get_image(return_pil=True)
371
+
372
+ if not all_images:
373
+ if not state.messages:
374
+ state.messages = [["Error: No image provided", None]]
375
+ else:
376
+ state.messages[-1][-1] = "Error: No image provided"
377
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
378
+ return
379
+
380
+ # Load model if needed
381
+ if model is None or model_selector != getattr(model, "_name_or_path", ""):
382
+ load_model(model_selector)
383
+
384
+ # Run inference
385
+ output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
386
+
387
+ # Update the response in the conversation state
388
+ if not state.messages:
389
+ state.messages = [[prompt, output]]
390
+ else:
391
+ state.messages[-1][-1] = output
392
+ state.current_response = output
393
+
394
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
395
+
396
+ finish_tstamp = time.time()
397
+ logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")
398
+
399
+ try:
400
+ with open(get_conv_log_filename(), "a") as fout:
401
+ data = {
402
+ "tstamp": round(finish_tstamp, 4),
403
+ "type": "chat",
404
+ "model": model_selector,
405
+ "start": round(start_tstamp, 4),
406
+ "finish": round(finish_tstamp, 4),
407
+ "state": state.dict(),
408
+ "images": ['image'],
409
+ "ip": request.client.host,
410
+ }
411
+ fout.write(json.dumps(data) + "\n")
412
+ except Exception as e:
413
+ logger.error(f"Error writing log: {str(e)}")
414
+
415
+ # UI Components
416
+ title_markdown = """
417
+ # LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
418
+ [[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
419
+ [[Code](https://github.com/ml-research/LlavaGuard)]
420
+ [[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
421
+ [[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
422
+ [[LavaGuard](https://arxiv.org/abs/2406.05113)]
423
+ """
424
+
425
+ tos_markdown = """
426
+ ### Terms of use
427
+ By using this service, users are required to agree to the following terms:
428
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
429
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
430
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
431
+ """
432
+
433
+ learn_more_markdown = """
434
+ ### License
435
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
436
+ """
437
+
438
+ block_css = """
439
+ #buttons button {
440
+ min-width: min(120px,100%);
441
+ }
442
+ """
443
+
444
+ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
445
+ models = get_model_list()
446
+
447
+ with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
448
+ state = gr.State()
449
+
450
+ if not embed_mode:
451
+ gr.Markdown(title_markdown)
452
+
453
+ with gr.Row():
454
+ with gr.Column(scale=3):
455
+ with gr.Row(elem_id="model_selector_row"):
456
+ model_selector = gr.Dropdown(
457
+ choices=models,
458
+ value=models[0] if len(models) > 0 else "",
459
+ interactive=True,
460
+ show_label=False,
461
+ container=False)
462
+
463
+ imagebox = gr.Image(type="pil", label="Image", container=False)
464
+ image_process_mode = gr.Radio(
465
+ ["Crop", "Resize", "Pad", "Default"],
466
+ value="Default",
467
+ label="Preprocess for non-square image", visible=False)
468
+
469
+ if cur_dir is None:
470
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
471
+
472
+ gr.Examples(examples=[
473
+ [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if os.path.exists(f"{cur_dir}/examples/image{i}.png")
474
+ ], inputs=imagebox)
475
+
476
+ with gr.Accordion("Parameters", open=False) as parameter_row:
477
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
478
+ label="Temperature")
479
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P")
480
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
481
+ label="Max output tokens")
482
+
483
+ with gr.Accordion("Safety Risk Taxonomy", open=False):
484
+ taxonomy_textbox = gr.Textbox(
485
+ label="Safety Risk Taxonomy",
486
+ show_label=True,
487
+ placeholder="Enter your safety policy here",
488
+ value=default_taxonomy,
489
+ lines=20)
490
+
491
+ with gr.Column(scale=8):
492
+ chatbot = gr.Chatbot(
493
+ elem_id="chatbot",
494
+ label="LLavaGuard Safety Assessment",
495
+ height=650,
496
+ layout="panel",
497
+ )
498
+ with gr.Row():
499
+ with gr.Column(scale=8):
500
+ textbox = gr.Textbox(
501
+ show_label=False,
502
+ placeholder="Enter your message here",
503
+ container=True,
504
+ value=default_taxonomy,
505
+ lines=3,
506
+ )
507
+ with gr.Column(scale=1, min_width=50):
508
+ submit_btn = gr.Button(value="Send", variant="primary")
509
+ with gr.Row(elem_id="buttons") as button_row:
510
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
511
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
512
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
513
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
514
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
515
+
516
+ if not embed_mode:
517
+ gr.Markdown(tos_markdown)
518
+ gr.Markdown(learn_more_markdown)
519
+ url_params = gr.JSON(visible=False)
520
+
521
+ # Register listeners
522
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
523
+
524
+ upvote_btn.click(
525
+ upvote_last_response,
526
+ [state, model_selector],
527
+ [textbox, upvote_btn, downvote_btn, flag_btn]
528
+ )
529
+
530
+ downvote_btn.click(
531
+ downvote_last_response,
532
+ [state, model_selector],
533
+ [textbox, upvote_btn, downvote_btn, flag_btn]
534
+ )
535
+
536
+ flag_btn.click(
537
+ flag_last_response,
538
+ [state, model_selector],
539
+ [textbox, upvote_btn, downvote_btn, flag_btn]
540
+ )
541
+
542
+ model_selector.change(
543
+ load_model,
544
+ [model_selector],
545
+ None
546
+ )
547
+
548
+ regenerate_btn.click(
549
+ regenerate,
550
+ [state, image_process_mode],
551
+ [state, chatbot, textbox, imagebox] + btn_list
552
+ ).then(
553
+ llava_bot,
554
+ [state, model_selector, temperature, top_p, max_output_tokens],
555
+ [state, chatbot] + btn_list,
556
+ concurrency_limit=concurrency_count
557
+ )
558
+
559
+ clear_btn.click(
560
+ clear_history,
561
+ None,
562
+ [state, chatbot, textbox, imagebox] + btn_list,
563
+ queue=False
564
+ )
565
+
566
+ textbox.submit(
567
+ add_text,
568
+ [state, textbox, imagebox, image_process_mode],
569
+ [state, chatbot, textbox, imagebox] + btn_list,
570
+ queue=False
571
+ ).then(
572
+ llava_bot,
573
+ [state, model_selector, temperature, top_p, max_output_tokens],
574
+ [state, chatbot] + btn_list,
575
+ concurrency_limit=concurrency_count
576
+ )
577
+
578
+ submit_btn.click(
579
+ add_text,
580
+ [state, textbox, imagebox, image_process_mode],
581
+ [state, chatbot, textbox, imagebox] + btn_list
582
+ ).then(
583
+ llava_bot,
584
+ [state, model_selector, temperature, top_p, max_output_tokens],
585
+ [state, chatbot] + btn_list,
586
+ concurrency_limit=concurrency_count
587
+ )
588
+
589
+ demo.load(
590
+ load_demo_refresh_model_list,
591
+ None,
592
+ [state, model_selector],
593
+ queue=False
594
+ )
595
+
596
+ return demo
597
+
598
+
599
+ if __name__ == "__main__":
600
+ parser = argparse.ArgumentParser()
601
+ parser.add_argument("--host", type=str, default="0.0.0.0")
602
+ parser.add_argument("--port", type=int)
603
+ parser.add_argument("--concurrency-count", type=int, default=5)
604
+ parser.add_argument("--share", action="store_true")
605
+ parser.add_argument("--moderate", action="store_true")
606
+ parser.add_argument("--embed", action="store_true")
607
+ args = parser.parse_args()
608
+
609
+ # Create log directory if it doesn't exist
610
+ os.makedirs(LOGDIR, exist_ok=True)
611
+
612
+ # GPU Check
613
+ if torch.cuda.is_available():
614
+ logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
615
+ else:
616
+ logger.warning("CUDA not available! Models will run on CPU which may be very slow.")
617
+
618
+ # Hugging Face token handling
619
+ api_key = os.getenv("token")
620
+ if api_key:
621
+ from huggingface_hub import login
622
+ login(token=api_key)
623
+ logger.info("Logged in to Hugging Face Hub")
624
+
625
+ # Load initial model
626
+ models = get_model_list()
627
+ model_path = os.getenv("model", models[0])
628
+ logger.info(f"Initial model selected: {model_path}")
629
+ load_model(model_path)
630
+
631
+ # Launch Gradio app
632
+ try:
633
+ demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
634
+ demo.queue(
635
+ status_update_rate=10,
636
+ api_open=False
637
+ ).launch(
638
+ server_name=args.host,
639
+ server_port=args.port,
640
+ share=args.share
641
+ )
642
+ except Exception as e:
643
+ logger.error(f"Error launching demo: {e}")
644
+ sys.exit(1)
645
+
examples/image1.png ADDED

Git LFS Details

  • SHA256: 175d7e4de57e3dfebdf64164509f4dbef027e34d514bfd432a823659698e5e6b
  • Pointer size: 131 Bytes
  • Size of remote file: 750 kB
examples/image2.png ADDED
examples/image3.png ADDED

Git LFS Details

  • SHA256: 07ca2140c4b506486bfd39f9370ea97803d0ec972f6e19ba37f2a6342ce2bbec
  • Pointer size: 131 Bytes
  • Size of remote file: 973 kB
examples/image4.png ADDED
examples/image5.png ADDED
examples/image6.png ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ gradio
3
+ qwen_vl_utils
4
+ torch
5
+ accelerate
taxonomy.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ policy_v1 = """Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
2
+ O1: Hate, Humiliation, Harassment:
3
+ Should not:
4
+ - Contain content that expresses hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
5
+ - Contain content that encourages discrimination or harassment against people based on those characteristics.
6
+ Can:
7
+ - Contain educational content that provides information on discrimination, humiliation, and harassment.
8
+ - Discuss historical events involving hate and harassment.
9
+
10
+ O2: Violence, Harm, or Cruelty:
11
+ Should not:
12
+ - Contain content of people planning or engaging in violence.
13
+ - Contain content that encourages people to engage in violence or cruelty, including harm to others.
14
+ Can:
15
+ - Contain educational content that provides information on violence, harm, and cruelty.
16
+ - Discuss historical events involving violence and cruelty.
17
+
18
+ O3: Sexual Content:
19
+ Should not:
20
+ - Contain sexually explicit (i.e., erotic) content.
21
+ Can:
22
+ - Contain content that is related to sexuality and sexual education.
23
+ - Provide general information regarding sexual activities.
24
+ - Provide guidance on topics related to sexual health.
25
+ - Contain content of people in a loving and affectionate manner if it is not sexually explicit content.
26
+
27
+ O4: Nudity Content:
28
+ Should not:
29
+ - Contain visible genitalia (e.g., exposed breast, nipple, vagina, vulva, or penis) or explicit nudity content.
30
+ Can:
31
+ - Contain artistic nudity if the content does not include visible genitalia.
32
+ - Contain educational content and general information on the human body.
33
+ - Provide guidance on topics related to sexual health.
34
+
35
+ O5: Criminal Planning:
36
+ Should not:
37
+ - Contain content of people planning or engaging in criminal activities like theft, kidnapping, or financial crimes.
38
+ - Contain content that encourages people to engage in criminal activities or condone criminal behavior.
39
+ Can:
40
+ - Contain content that helps people understand how crimes are committed, provided that the content is not so detailed that it could be used to provide a roadmap for future crimes.
41
+ - Contain content that educates on crime prevention and understanding criminal behavior without providing guidance for criminal activities.
42
+
43
+ O6: Weapons or Substance Abuse:
44
+ Should not:
45
+ - Contain content of illegal weapons with the intent to promote violence or harm.
46
+ - Contain content of trafficking, consumption, or use of illegal drugs.
47
+ - Promote the use of tobacco, alcohol, and cannabis.
48
+ - Contain content about the creation or production of illegal substances.
49
+ Can:
50
+ - Contain weapons used in a legal and safe context without any indication of violence or intent to harm.
51
+ - Discussions on firearm ownership and educational content on regulated substances without promoting illegal activities.
52
+ - Contain content that provides factual information about legal drugs, including their effects and chemical composition, without promoting their use.
53
+ - Contain content that offers educational insights into the manufacturing processes of substances, ensuring that such information is general and does not include detailed instructions that could facilitate their production.
54
+
55
+ O7: Self-Harm:
56
+ Should not:
57
+ - Contain content that could encourage people to harm themselves.
58
+ - Contain content that romanticizes or condones self-harm or suicide.
59
+ Can:
60
+ - Contain educational content that informs about the prevalence, causes, and effects of self-harm and suicide, such as statistical data and discussions on mental health challenges.
61
+ - Contain resources and information that guide individuals toward professional help and support services for mental health, self-harm, and suicide prevention.
62
+
63
+ O8: Animal Cruelty:
64
+ Should not:
65
+ - Contain content that could encourage people to harm animals.
66
+ - Contain content that romanticizes or condones harming animals.
67
+ - Contain explicit animal cruelty.
68
+ Can:
69
+ - Contain content that provides general information regarding animal cruelty.
70
+
71
+ O9: Disasters or Emergencies:
72
+ Should not:
73
+ - Contain content about individuals who are deceased or have suffered harm as a result of disasters or emergencies.
74
+ - Contain content about suffering, cruelty, or exploitation of individuals during such events.
75
+ Can:
76
+ - Contain respectful news coverage and reports on emergencies, prioritizing informational value over sensationalism.
77
+ - Contain educational content that provides factual information on historical and current disasters or emergencies.
78
+
79
+ Assessment Steps:
80
+ 1. Select a safety "rating". This should be "Unsafe" if the content violates the safety policy; otherwise, it should be "Safe".
81
+ 2. Assign a policy "category". If none of the policy categories are applicable, select "NA: None applying".
82
+ 3. Provide a "rationale". Describe the user content and justify why it is considered safe or unsafe, referring to the specific policy category and its associated guidelines to illustrate any violations or compliance.
83
+ To provide your assessment use the following json template:
84
+ {
85
+ "rating": str<"Unsafe"|"Safe">,
86
+ "category": str<"O1: Hate, Humiliation, Harassment"|"O2: Violence, Harm, or Cruelty"|"O3: Sexual Content"|"O4: Nudity Content"|"O5: Criminal Planning"|"O6: Weapons or Substance Abuse"|"O7: Self-Harm"|"O8: Animal Cruelty"|"O9: Disasters or Emergencies"|"NA: None applying">,
87
+ "rationale": str,
88
+ }
89
+ """