import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess

# Install flash-attn at runtime, skipping the local CUDA kernel build.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image
from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from huggingface_hub import snapshot_download
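
# Download the BAGEL-7B-MoT checkpoint (configs, weights, tokenizer and code files) from the Hugging Face Hub into ./model.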
save_dir = "./model"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = save_dir + "/cache"

snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)

# Model Initialization
model_path = "./model"  # Downloaded above from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT

llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
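
# Assemble the joint BAGEL config: visual generation (via the VAE latent space) and visual
# understanding (via the SigLIP ViT) on top of the Qwen2-MoT language model.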
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)
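
# Instantiate the model skeleton on the meta device (no memory is allocated here);
# the real weights are loaded and dispatched further below.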
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

# Model Loading and Multi-GPU Inference Preparation
device_map = infer_auto_device_map(
    model,
    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

# These modules must be placed on the same device as the token embeddings.
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed',
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
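
# Load the EMA weights in bfloat16 and dispatch them across the available GPUs according to the device map.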
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()

# Inferencer Preparation
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)
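

# The inferencer streams its output: thinking/answer text arrives as str chunks and generated
# images arrive as PIL images. The mode-specific helper functions below consume that stream.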

def set_seed(seed):
    """Set random seeds for reproducibility. A seed of 0 leaves the RNGs untouched (non-deterministic)."""
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    return seed


# Text-to-Image function with thinking option and hyperparameters
def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
                  timestep_shift=3.0, num_timesteps=50,
                  cfg_renorm_min=1.0, cfg_renorm_type="global",
                  max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                  seed=0, image_ratio="1:1"):
    # Set seed for reproducibility
    set_seed(seed)

    # Map the requested aspect ratio to (height, width); the longer side is 1024 px.
    if image_ratio == "1:1":
        image_shapes = (1024, 1024)
    elif image_ratio == "4:3":
        image_shapes = (768, 1024)
    elif image_ratio == "3:4":
        image_shapes = (1024, 768)
    elif image_ratio == "16:9":
        image_shapes = (576, 1024)
    elif image_ratio == "9:16":
        image_shapes = (1024, 576)
    else:
        image_shapes = (1024, 1024)  # Fallback for an unrecognized ratio

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
        image_shapes=image_shapes,
    )

    result = {"text": "", "image": None}
    # Call the inferencer with or without the think parameter based on user choice;
    # it yields thinking text as str chunks and the generated image as a PIL image.
    for i in inferencer(text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", "")


# Image Understanding function with thinking option and hyperparameters
def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
                        do_sample=False, text_temperature=0.3, max_new_tokens=512):
    if image is None:
        yield "Please upload an image for understanding."
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    # Set hyperparameters
    inference_hyper = dict(
        do_sample=do_sample,
        temperature=text_temperature,
        max_think_token_n=max_new_tokens,  # Maximum length of the generated text
    )

    result_text = ""
    # Use the show_thinking parameter to control the thinking process.
    # With understanding_output=True the inferencer is expected to yield only text, so no image branch is needed.
    for i in inferencer(image=image, text=prompt, think=show_thinking,
                        understanding_output=True, **inference_hyper):
        if isinstance(i, str):
            result_text += i
            yield result_text

    yield result_text  # Ensure the final text is yielded


# Image Editing function with thinking option and hyperparameters
def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
               cfg_img_scale=2.0, cfg_interval=0.0,
               timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
               cfg_renorm_type="text_channel", max_think_token_n=1024,
               do_sample=False, text_temperature=0.3, seed=0):
    # Set seed for reproducibility
    set_seed(seed)

    if image is None:
        yield None, "Please upload an image for editing."  # Yield an (image, text) tuple
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_img_scale=cfg_img_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
    )

    result = {"text": "", "image": None}
    # Include the thinking parameter based on user choice; the inferencer yields
    # thinking text as str chunks and the edited image as a PIL image.
    for i in inferencer(image=image, text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", "")  # Yield an (image, text) tuple


# Helper function to load example images
def load_example_image(image_path):
    try:
        return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image: {e}")
        return None


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("""
<div>
  <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
</div>

# BAGEL Multimodal Chatbot

Interact with BAGEL to generate images from text, edit existing images, or understand image content.
""")

    # Chatbot display area
    chatbot = gr.Chatbot(
        label="Chat History",
        height=500,
        avatar_images=(None, "https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/BAGEL_favicon.png"),
    )
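    # Each chat turn is stored as [user_message, bot_response]; image results are passed to the
    # Chatbot as (image, caption) tuples by the handler defined below.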

    # Input area
    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Optional: Upload an Image (for Image Understanding / Image Edit)",
            scale=1,
            value=None,
        )
        with gr.Column(scale=3):
            user_prompt = gr.Textbox(label="Your Message", placeholder="Type your prompt here...", lines=3)
            with gr.Row():
                mode_selector = gr.Radio(
                    choices=["Text to Image", "Image Understanding", "Image Edit"],
                    value="Text to Image",
                    label="Select Mode",
                    interactive=True,
                )
                submit_btn = gr.Button("Send", variant="primary")

    # Global/Shared Hyperparameters
    with gr.Accordion("General Settings & Hyperparameters", open=False) as general_accordion:
        with gr.Row():
            show_thinking_global = gr.Checkbox(label="Show Thinking Process", value=False, info="Enable to see the model's intermediate thinking text.")
            seed_global = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, label="Seed", info="0 for a random seed, positive values for reproducible results.")

        # Container for thinking-specific parameters; visibility is controlled by show_thinking_global
        thinking_params_container = gr.Group(visible=False)
        with thinking_params_container:
            gr.Markdown("#### Thinking Process Parameters (affect text generation)")
            with gr.Row():
                common_do_sample = gr.Checkbox(label="Enable Sampling", value=False, info="Enable sampling for text generation (otherwise greedy decoding).")
                common_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, label="Text Temperature", info="Controls randomness in text generation (higher = more random).")
                common_max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Think Tokens / Max New Tokens", info="Maximum number of tokens for thinking (T2I/Edit) or generated text (Understanding).")

    # T2I Hyperparameters
    t2i_params_accordion = gr.Accordion("Text to Image Specific Parameters", open=False)
    with t2i_params_accordion:
        gr.Markdown("#### Text to Image Parameters")
        with gr.Row():
            t2i_image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"], value="1:1", label="Image Ratio", info="The longer side is fixed at 1024 pixels.")
        with gr.Row():
            t2i_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0 recommended).")
            t2i_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, label="CFG Interval", info="Start of the classifier-free guidance interval (the end is fixed at 1.0).")
        with gr.Row():
            t2i_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"], value="global", label="CFG Renorm Type", info="Normalization type for CFG. Use 'global' if the generated image is blurry.")
            t2i_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Renorm Min", info="Minimum value for CFG renormalization (1.0 disables CFG-Renorm).")
        with gr.Row():
            t2i_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Timesteps", info="Total denoising steps for image generation.")
            t2i_timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, label="Timestep Shift", info="Higher values favor layout control, lower values favor fine details.")

    # Image Edit Hyperparameters
    edit_params_accordion = gr.Accordion("Image Edit Specific Parameters", open=False)
    with edit_params_accordion:
        gr.Markdown("#### Image Edit Parameters")
        with gr.Row():
            edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, label="CFG Text Scale", info="Controls how strongly the model follows the text prompt for editing.")
            edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, label="CFG Image Scale", info="Controls how much the model preserves input image details during editing.")
        with gr.Row():
            edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Interval", info="Start of the CFG application interval for editing (the end is fixed at 1.0).")
            edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"], value="text_channel", label="CFG Renorm Type", info="Normalization type for CFG during editing. Use 'global' if the output is blurry.")
        with gr.Row():
            edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Renorm Min", info="Minimum value for CFG renormalization during editing (1.0 disables CFG-Renorm).")
        with gr.Row():
            edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Timesteps", info="Total denoising steps for image editing.")
            edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, label="Timestep Shift", info="Higher values favor layout control, lower values favor fine details during editing.")

    # Main chat processing function; dispatches to the mode-specific helpers above.
    # Apply the GPU decorator to the combined function so a GPU is attached while it runs.
    @spaces.GPU
    def process_chat_message(history, prompt, uploaded_image, mode,
                             show_thinking_global_val, seed_global_val,
                             common_do_sample_val, common_text_temperature_val, common_max_think_token_n_val,
                             t2i_cfg_text_scale_val, t2i_cfg_interval_val, t2i_timestep_shift_val,
                             t2i_num_timesteps_val, t2i_cfg_renorm_min_val, t2i_cfg_renorm_type_val,
                             t2i_image_ratio_val,
                             edit_cfg_text_scale_val, edit_cfg_img_scale_val, edit_cfg_interval_val,
                             edit_timestep_shift_val, edit_num_timesteps_val, edit_cfg_renorm_min_val,
                             edit_cfg_renorm_type_val):
        # Append the user message to the chat history
        history.append([prompt, None])

        # Define common parameters shared by all inference functions
        common_infer_params = dict(
            show_thinking=show_thinking_global_val,
            do_sample=common_do_sample_val,
            text_temperature=common_text_temperature_val,
        )

        try:
            if mode == "Text to Image":
                # Add T2I-specific parameters, including max_think_token_n and seed
                t2i_params = {
                    **common_infer_params,
                    "max_think_token_n": common_max_think_token_n_val,
                    "seed": seed_global_val,
                    "cfg_text_scale": t2i_cfg_text_scale_val,
                    "cfg_interval": t2i_cfg_interval_val,
                    "timestep_shift": t2i_timestep_shift_val,
                    "num_timesteps": t2i_num_timesteps_val,
                    "cfg_renorm_min": t2i_cfg_renorm_min_val,
                    "cfg_renorm_type": t2i_cfg_renorm_type_val,
                    "image_ratio": t2i_image_ratio_val,
                }
                for img, txt in text_to_image(prompt=prompt, **t2i_params):
                    # For Text to Image, show the image once it is ready; stream thinking text until then
                    if img is not None:
                        history[-1] = [prompt, (img, txt)]
                    elif txt:  # Only update text if the image is not ready yet
                        history[-1] = [prompt, txt]
                    yield history, gr.update(value="")  # Update the chatbot and clear the input box

            elif mode == "Image Understanding":
                if uploaded_image is None:
                    history[-1] = [prompt, "Please upload an image for Image Understanding."]
                    yield history, gr.update(value="")
                    return

                # Understanding-specific parameters (max_new_tokens maps to common_max_think_token_n).
                # Note: seed is not used by image_understanding.
                understand_params = {
                    **common_infer_params,
                    "max_new_tokens": common_max_think_token_n_val,
                }
                for txt in image_understanding(image=uploaded_image, prompt=prompt, **understand_params):
                    history[-1] = [prompt, txt]
                    yield history, gr.update(value="")

            elif mode == "Image Edit":
                if uploaded_image is None:
                    history[-1] = [prompt, "Please upload an image for Image Editing."]
                    yield history, gr.update(value="")
                    return

                # Add Edit-specific parameters, including max_think_token_n and seed
                edit_params = {
                    **common_infer_params,
                    "max_think_token_n": common_max_think_token_n_val,
                    "seed": seed_global_val,
                    "cfg_text_scale": edit_cfg_text_scale_val,
                    "cfg_img_scale": edit_cfg_img_scale_val,
                    "cfg_interval": edit_cfg_interval_val,
                    "timestep_shift": edit_timestep_shift_val,
                    "num_timesteps": edit_num_timesteps_val,
                    "cfg_renorm_min": edit_cfg_renorm_min_val,
                    "cfg_renorm_type": edit_cfg_renorm_type_val,
                }
                for img, txt in edit_image(image=uploaded_image, prompt=prompt, **edit_params):
                    # For Image Edit, show the image once it is ready; stream thinking text until then
                    if img is not None:
                        history[-1] = [prompt, (img, txt)]
                    elif txt:  # Only update text if the image is not ready yet
                        history[-1] = [prompt, txt]
                    yield history, gr.update(value="")

        except Exception as e:
            history[-1] = [prompt, f"An error occurred: {e}"]
            yield history, gr.update(value="")  # Report the error in the chat and clear the input

    # Event handlers for dynamic UI updates and submission

    # Control visibility of the thinking-specific parameters
    show_thinking_global.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[show_thinking_global],
        outputs=[thinking_params_container],
    )

    # Clear the image input when the mode switches to Text to Image
    mode_selector.change(
        fn=lambda mode: gr.update(value=None) if mode == "Text to Image" else gr.update(),
        inputs=[mode_selector],
        outputs=[image_input],
    )

    # List of all input components whose values are passed to process_chat_message
    inputs_list = [
        chatbot, user_prompt, image_input, mode_selector,
        show_thinking_global, seed_global,
        common_do_sample, common_text_temperature, common_max_think_token_n,
        t2i_cfg_text_scale, t2i_cfg_interval, t2i_timestep_shift,
        t2i_num_timesteps, t2i_cfg_renorm_min, t2i_cfg_renorm_type,
        t2i_image_ratio,
        edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
        edit_timestep_shift, edit_num_timesteps, edit_cfg_renorm_min,
        edit_cfg_renorm_type,
    ]
# Link submit button and text input 'Enter' key to the processing function | |
submit_btn.click( | |
fn=process_chat_message, | |
inputs=inputs_list, | |
outputs=[chatbot, user_prompt], | |
scroll_to_output=True, | |
queue=False, # Set to True if long generation times cause issues, but might affect responsiveness | |
) | |
user_prompt.submit( # Allows pressing Enter in textbox to submit | |
fn=process_chat_message, | |
inputs=inputs_list, | |
outputs=[chatbot, user_prompt], | |
scroll_to_output=True, | |
queue=False, | |
) | |
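
# Launch the Gradio app. When running locally (outside Spaces), demo.launch(share=True) can
# optionally be used to get a temporary public link.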
demo.launch()