Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import datetime | |
import hashlib | |
import json | |
import logging | |
import os | |
import sys | |
import time | |
import spaces | |
import gradio as gr | |
import torch | |
from PIL import Image | |
from transformers import ( | |
AutoProcessor, | |
AutoTokenizer, | |
Qwen2_5_VLForConditionalGeneration, | |
LlavaOnevisionForConditionalGeneration | |
) | |
from qwen_vl_utils import process_vision_info | |
from taxonomy import policy_v1 | |
# Logging: every record goes both to a log file and to the console.
_log_handlers = [
    logging.FileHandler("gradio_web_server.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger("gradio_web_server")

# Paths and defaults
_module_dir = os.path.dirname(os.path.abspath(__file__))
LOGDIR = os.path.join(_module_dir, "logs")
# Creating the sub-folder also creates LOGDIR itself when it is missing.
os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
default_taxonomy = policy_v1
class SimpleConversation:
    """Single-turn conversation state: one prompt, one optional image, one response.

    ``messages`` holds ``[prompt, response]`` pairs. ``set_prompt`` always
    replaces the whole list with a single pair, so it never grows here.
    A prompt is either a plain string or a tuple whose first element is the
    text and whose second element may be an image.
    """

    def __init__(self):
        self.current_prompt = ""
        self.current_image = None
        self.current_response = None
        self.skip_next = False
        self.messages = []  # list of [prompt, response] pairs

    def set_prompt(self, prompt, image=None):
        """Start a new turn: store prompt/image and reset response and history."""
        self.current_prompt = prompt
        self.current_image = image
        self.current_response = None
        self.messages = [[prompt, None]]

    def set_response(self, response):
        """Record the response, also on the last message pair if one exists."""
        self.current_response = response
        if self.messages:
            self.messages[-1][-1] = response

    def get_prompt(self):
        """Return the prompt text (the first element when the prompt is a tuple)."""
        if isinstance(self.current_prompt, tuple):
            return self.current_prompt[0]
        return self.current_prompt

    def get_image(self, return_pil=False):
        """Return the image wrapped in a one-element list, or None if there is none.

        ``return_pil`` is accepted for call compatibility and is not used.
        """
        # Explicit None check, matching dict(); truthiness is not a reliable
        # "is present" test for arbitrary image objects.
        if self.current_image is not None:
            return [self.current_image]
        prompt = self.current_prompt
        if (
            isinstance(prompt, tuple)
            and len(prompt) > 1
            and isinstance(prompt[1], Image.Image)
        ):
            return [prompt[1]]
        return None

    def to_gradio_chatbot(self):
        """Return the history as ``[text, response]`` pairs without "<image>" tags."""
        ret = []
        for msg in self.messages or []:
            prompt = msg[0]
            if isinstance(prompt, tuple) and len(prompt) > 0:
                prompt = prompt[0]
            if prompt and isinstance(prompt, str) and "<image>" in prompt:
                prompt = prompt.replace("<image>", "")
            ret.append([prompt, msg[1]])
        return ret

    def dict(self):
        """Return a JSON-serialisable summary for logging.

        Images are replaced by a marker string and message responses by
        "[RESPONSE]" (or None), so no image object ends up in the result.
        """
        image_info = "[WITH_IMAGE]" if self.current_image is not None else "[NO_IMAGE]"
        safe_messages = []
        for msg in self.messages:
            msg_prompt = msg[0]
            if isinstance(msg_prompt, tuple) and len(msg_prompt) > 0:
                msg_prompt = msg_prompt[0]  # keep only the text part
            safe_messages.append([msg_prompt, "[RESPONSE]" if msg[1] else None])
        return {
            # get_prompt() already unwraps tuple prompts to their text part
            "prompt": self.get_prompt(),
            "image": image_info,
            "response": self.current_response,
            "messages": safe_messages
        }

    def copy(self):
        """Return an independent copy of this conversation.

        Each ``[prompt, response]`` pair is duplicated, so writing a response
        into the copy cannot leak into the conversation it was copied from.
        Prompt and image objects themselves are shared, not cloned.
        """
        new_conv = SimpleConversation()
        new_conv.current_prompt = self.current_prompt
        new_conv.current_image = self.current_image
        new_conv.current_response = self.current_response
        new_conv.skip_next = self.skip_next
        new_conv.messages = [list(msg) for msg in (self.messages or [])]
        return new_conv
# Pristine template conversation; handlers in this file work on copies of it.
default_conversation = SimpleConversation()

# Currently loaded model artefacts, all unset until load_model() fills them in.
tokenizer = model = processor = None
# Fallback context length; load_model() overwrites it from the model config.
context_len = 8048
def wrap_taxonomy(text):
    """Return *text* prefixed with the safety policy unless it already contains it."""
    already_wrapped = policy_v1 in text
    return text if already_wrapped else policy_v1 + "\n\n" + text
# Shared button values returned by the event handlers:
#   disable_btn   -> non-interactive button
#   enable_btn    -> interactive button
#   no_change_btn -> button created with default arguments
disable_btn = gr.Button(interactive=False)
enable_btn = gr.Button(interactive=True)
no_change_btn = gr.Button()
# Model loading function
def load_model(model_path):
    """Load a guard model and publish it through the module-level globals.

    A path containing "qwenguard" (case-insensitive) is loaded as
    ``Qwen2_5_VLForConditionalGeneration``; anything else is assumed to be a
    LlavaGuard checkpoint and loaded as ``LlavaOnevisionForConditionalGeneration``.

    Everything is loaded into locals first and the globals are swapped only
    once all parts are available. If any step fails, the previously loaded
    model/processor/tokenizer stay untouched instead of being left as a
    mismatched mix of old and new objects.

    Args:
        model_path: Model identifier or path passed to ``from_pretrained``.

    Returns:
        None, on success and on failure alike (failures are logged). The
        function is used directly as a Gradio event handler with no outputs.
    """
    global tokenizer, model, processor, context_len
    logger.info(f"Loading model: {model_path}")
    try:
        if "qwenguard" in model_path.lower():
            new_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map="auto"
            )
            new_processor = AutoProcessor.from_pretrained(model_path)
            new_tokenizer = new_processor.tokenizer
        else:
            new_model = LlavaOnevisionForConditionalGeneration.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map="auto",
                trust_remote_code=True
            )
            new_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            new_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        new_context_len = getattr(new_model.config, "max_position_embeddings", 8048)
    except Exception as e:
        logger.error(f"Error loading model {model_path}: {str(e)}")
        return

    # Commit all pieces together so readers never see a half-updated state.
    tokenizer, model, processor = new_tokenizer, new_model, new_processor
    context_len = new_context_len
    logger.info(f"Model {model_path} loaded successfully")
def get_model_list():
    """Return a fresh list of the model identifiers offered in the UI."""
    return [
        'AIML-TUDA/QwenGuard-v1.2-3B',
        'AIML-TUDA/QwenGuard-v1.2-7B',
        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
    ]
def get_conv_log_filename():
    """Return today's conversation log path, creating its directory if needed."""
    today = datetime.datetime.now()
    log_path = os.path.join(
        LOGDIR,
        f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json",
    )
    log_dir = os.path.dirname(log_path)
    os.makedirs(log_dir, exist_ok=True)
    return log_path
# Inference function
def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
    """Generate the loaded model's answer for one image/prompt pair.

    Args:
        prompt: Text prompt sent together with the image.
        image: Image handed to the processor (presumably a PIL image;
            verify against the caller).
        temperature: Sampling temperature; a value of 0 (or less) selects
            greedy decoding.
        top_p: Nucleus-sampling threshold, only forwarded when sampling.
        max_tokens: Upper bound on newly generated tokens, clamped to >= 1.

    Returns:
        The decoded response with surrounding whitespace stripped, or a
        user-facing message when no model is loaded or inference failed.
        The function does not raise; failures are logged with a traceback.
    """
    if model is None or processor is None:
        return "Model not loaded. Please select a model first."
    try:
        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
            # Qwen chat format: the image travels inside the message content.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
            text_prompt = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text_prompt],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
        else:
            # Otherwise assume a LlavaGuard model: the template only carries an
            # image placeholder, the image itself goes to the processor.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt},
                    ],
                },
            ]
            text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=text_prompt, images=image, return_tensors="pt")

        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Sampling parameters are only meaningful when sampling; passing them
        # for greedy decoding just triggers generation-config warnings.
        do_sample = bool(temperature) and temperature > 0
        generation_kwargs = {
            "do_sample": do_sample,
            # The UI slider can deliver 0; ask for at least one new token.
            "max_new_tokens": max(1, int(max_tokens)),
        }
        if do_sample:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p

        with torch.no_grad():
            generated_ids = model.generate(**inputs, **generation_kwargs)

        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids_trimmed = generated_ids[0, inputs["input_ids"].shape[1]:]
        response = processor.decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
        )
        logger.info(f"response: {response}")
        return response.strip()
    except Exception as e:
        # logger.exception appends the traceback; the configured stream handler
        # already echoes it to the console, so no extra print is needed.
        logger.exception(f"Error during inference: {str(e)}")
        return "Error processing image. Please try again."
# Gradio UI functions
# JavaScript snippet that returns the page's URL query parameters as a plain
# object. `url_params` is declared with `const` so it no longer becomes an
# implicit global on `window`.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    """Create the initial state, preselecting the model named in the URL if known.

    When ``url_params`` carries a "model" entry that is one of the offered
    models, that model is selected in the dropdown and loaded right away.
    Returns ``(state, dropdown_update)``.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    offered = get_model_list()
    selector_update = gr.Dropdown(visible=True)
    requested = url_params["model"] if "model" in url_params else None
    if requested in offered:
        selector_update = gr.Dropdown(value=requested, visible=True)
        load_model(requested)
    return default_conversation.copy(), selector_update
def load_demo_refresh_model_list(request: gr.Request):
    """Return a fresh conversation state and a dropdown refilled with all models.

    The first model is preselected; an empty string is used when the list
    is empty.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    offered = get_model_list()
    preselected = offered[0] if offered else ""
    fresh_state = default_conversation.copy()
    return fresh_state, gr.Dropdown(choices=offered, value=preselected)
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one feedback record as a JSON line to today's conversation log.

    The record holds a timestamp, the vote type, the selected model, the
    serialised conversation state and the client host.
    """
    log_path = get_conv_log_filename()
    with open(log_path, "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        line = json.dumps(record)
        log_file.write(line + "\n")
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an "upvote" for the current state.

    Returns an empty string followed by three disabled-button values.
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a "downvote" for the current state.

    Returns an empty string followed by three disabled-button values.
    """
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a "flag" for the current state.

    Returns an empty string followed by three disabled-button values.
    """
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def regenerate(state, image_process_mode, request: gr.Request):
    """Clear the last response so the follow-up handler can produce it again.

    If an earlier message exists whose prompt is a tuple with at least three
    elements, its third element is replaced by ``image_process_mode``.
    Returns the state, the chatbot history, an empty string, None and five
    disabled-button values.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    history = state.messages
    if history:
        history[-1][-1] = None
        if len(history) > 1:
            earlier = history[-2]
            earlier_prompt = earlier[0]
            # Only prompts shaped like (text, image, mode) carry a mode to update.
            if isinstance(earlier_prompt, tuple) and len(earlier_prompt) >= 3:
                refreshed = list(earlier)
                refreshed[0] = (earlier_prompt[0], earlier_prompt[1], image_process_mode)
                history[-2] = refreshed
    state.skip_next = False
    outputs = (state, state.to_gradio_chatbot(), "", None)
    return outputs + (disable_btn,) * 5
def clear_history(request: gr.Request):
    """Replace the conversation with a fresh copy of the default one.

    Returns the new state, its (empty) chatbot history, an empty string, None
    and five disabled-button values.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    reset_outputs = (fresh, fresh.to_gradio_chatbot(), "", None)
    return reset_outputs + (disable_btn,) * 5
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Validate the submission and start a new conversation for it.

    Without text or without an image the current state is only marked
    ``skip_next`` and returned unchanged. Otherwise a fresh conversation is
    created whose prompt is ``(policy-wrapped text, image, image_process_mode)``.
    The third returned value is the default taxonomy.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if image is None or len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

    # An image is guaranteed from here on, so every accepted submission
    # begins a brand-new conversation.
    wrapped_text = wrap_taxonomy(text)
    new_state = default_conversation.copy()
    new_state.set_prompt(prompt=(wrapped_text, image, image_process_mode), image=image)
    new_state.skip_next = False
    return (new_state, new_state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Run the model on the pending prompt and yield the updated UI values.

    Generator used as a Gradio event handler. It yields exactly once: a tuple
    of the state, the chatbot history and five button values. After a
    successful run it appends a "chat" record to the daily conversation log.

    Args:
        state: Conversation holding the pending prompt and image.
        model_selector: Identifier of the model chosen in the UI.
        temperature: Forwarded to ``run_inference``.
        top_p: Forwarded to ``run_inference``.
        max_new_tokens: Forwarded to ``run_inference`` as its token limit.
        request: Gradio request, used for the client host in the log record.
    """
    start_tstamp = time.time()
    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    prompt = state.get_prompt()
    all_images = state.get_image(return_pil=True)
    if not all_images:
        if not state.messages:
            state.messages = [["Error: No image provided", None]]
        else:
            state.messages[-1][-1] = "Error: No image provided"
        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
        return

    # Reload only when a different checkpoint is requested. Transformers models
    # expose the loaded identifier as `name_or_path` (the config keeps it in
    # `_name_or_path`); the model object has no `_name_or_path`, so comparing
    # against that attribute forced a full reload on every request.
    loaded_name = getattr(model, "name_or_path", None) or getattr(
        getattr(model, "config", None), "_name_or_path", ""
    )
    if model is None or model_selector != loaded_name:
        load_model(model_selector)

    output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)

    # Store the response in the conversation state
    if not state.messages:
        state.messages = [[prompt, output]]
    else:
        state.messages[-1][-1] = output
    state.current_response = output
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")
    try:
        with open(get_conv_log_filename(), "a") as fout:
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_selector,
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "images": ['image'],  # placeholder; the image itself is not logged
                "ip": request.client.host,
            }
            fout.write(json.dumps(data) + "\n")
    except Exception as e:
        # Logging is best-effort: a failed write must not break the response.
        logger.error(f"Error writing log: {str(e)}")
# UI Components
# Header shown above the demo (link label "LavaGuard" corrected to "LlavaGuard").
title_markdown = """
# LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
[[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
[[Code](https://github.com/ml-research/LlavaGuard)]
[[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
[[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
[[LlavaGuard](https://arxiv.org/abs/2406.05113)]
"""
# Terms-of-use text shown below the demo.
tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""
# License notice shown below the terms of use.
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""
# Extra CSS for the Blocks app: minimum width for the feedback buttons.
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Assemble the Gradio Blocks app and wire up all event handlers.

    Changes compared with the previous version:
      * the "Max output tokens" slider starts at 64 instead of 0, so the UI
        can no longer request a generation of zero tokens;
      * the "Top P" slider uses a 0.05 step so its default of 0.95 lies on
        the step grid;
      * the unused ``as parameter_row`` / ``as button_row`` bindings are gone.

    NOTE(review): the container nesting below was reconstructed because the
    available source had lost its indentation; confirm the layout.

    Args:
        embed_mode: When true, the title, terms-of-use and license texts are
            left out.
        cur_dir: Directory searched for ``examples/image1.png`` ..
            ``examples/image5.png``; defaults to this file's directory.
        concurrency_count: ``concurrency_limit`` for the inference step.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    models = get_model_list()
    with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model choice, image input, examples and settings
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)
                imagebox = gr.Image(type="pil", label="Image", container=False)
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)
                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                # Only example images that actually exist on disk are offered.
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if
                    os.path.exists(f"{cur_dir}/examples/image{i}.png")
                ], inputs=imagebox)
                with gr.Accordion("Parameters", open=False):
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
                                            label="Temperature")
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True,
                                      label="Top P")
                    max_output_tokens = gr.Slider(minimum=64, maximum=1024, value=512, step=64, interactive=True,
                                                  label="Max output tokens")
                with gr.Accordion("Safety Risk Taxonomy", open=False):
                    # Display only: no listener below reads this textbox.
                    taxonomy_textbox = gr.Textbox(
                        label="Safety Risk Taxonomy",
                        show_label=True,
                        placeholder="Enter your safety policy here",
                        value=default_taxonomy,
                        lines=20)

            # Right column: chat output, prompt input and feedback buttons
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLavaGuard Safety Assessment",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter your message here",
                            container=True,
                            value=default_taxonomy,
                            lines=3,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        # Hidden component; no listener in this function uses it.
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        model_selector.change(
            load_model,
            [model_selector],
            None
        )
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            llava_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )
        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )
        # Enter in the textbox and the Send button trigger the same two steps:
        # add_text prepares the state, llava_bot then runs the model.
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            llava_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )
        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            llava_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )
        # Initialise the per-session state and the model dropdown on page load.
        demo.load(
            load_demo_refresh_model_list,
            None,
            [state, model_selector],
            queue=False
        )
    return demo
def launch_demo(args):
    """Build and launch the Gradio app; meant to run in the child process.

    Defined at module level (not under the ``__main__`` guard) so the child
    can still find it under the "spawn" and "forkserver" start methods, which
    re-import this module instead of forking it.

    The CUDA probe lives here rather than in the parent: querying CUDA in the
    parent before the child starts can initialise the CUDA driver there,
    which is what running the app in a subprocess is meant to avoid.

    Args:
        args: Parsed command-line arguments (embed, concurrency_count, host,
            port, share).
    """
    try:
        # GPU Check
        if torch.cuda.is_available():
            logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
        else:
            logger.warning("CUDA not available! Models will run on CPU which may be very slow.")
        demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share
        )
    except Exception as e:
        logger.exception(f"Error launching demo: {e}")
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")  # accepted but not used below
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()

    # Create log directory if it doesn't exist
    os.makedirs(LOGDIR, exist_ok=True)

    # Hugging Face token handling: log in when the "token" env variable is set
    api_key = os.getenv("token")
    if api_key:
        from huggingface_hub import login
        login(token=api_key)
        logger.info("Logged in to Hugging Face Hub")

    # Launch Gradio app in a subprocess to avoid CUDA initialization in the main process
    from torch.multiprocessing import Process
    p = Process(target=launch_demo, args=(args,))
    p.start()
    p.join()
    # Propagate the child's status so a failed launch is visible to the caller.
    sys.exit(p.exitcode)