import gradio as gr
import torch
import spaces
from PIL import Image, ImageDraw, ImageFont
# from src.condition import Condition
from diffusers.pipelines import FluxPipeline
import numpy as np
import requests
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch.multiprocessing as mp
###
import argparse
import logging
import math
import os
import re
import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from PIL import Image
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from transformers.utils import ContextManagers
from omegaconf import OmegaConf
from copy import deepcopy
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
from diffusers.utils import check_min_version, deprecate, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from einops import rearrange
from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from src.flux.util import (configs, load_ae, load_clip,
                           load_flow_model2, load_t5, save_image, tensor_to_pil_image, load_checkpoint)
from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor, IPDoubleStreamBlockProcessor, IPSingleStreamBlockProcessor, ImageProjModel
from src.flux.xflux_pipeline import XFluxSampler
from image_datasets.dataset import loader, eval_image_pair_loader, image_resize
from safetensors.torch import load_file
import json
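# NOTE: the block of imports after "###" appears to be carried over from the
# original training/evaluation script; most of it is unused by this demo.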
# logger = get_logger(__name__, log_level="INFO")

def get_models(name: str, device, offload: bool, is_schnell: bool):
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    clip.requires_grad_(False)
    model = load_flow_model2(name, device="cpu")
    vae = load_ae(name, device="cpu" if offload else device)
    return model, vae, t5, clip
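
# Module-level setup: load the inference config, build the FLUX components
# (T5, CLIP, VAE, DiT) once at startup, and construct a single XFluxSampler
# that every Gradio request reuses.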
args = OmegaConf.load("inference_configs/inference.yaml")  # OmegaConf.load(parse_args())
is_schnell = args.model_name == "flux-schnell"
set_seed(args.seed)
# logging_dir = os.path.join(args.output_dir, args.logging_dir)
device = "cuda"
dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)

# # load image encoder
# ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
#     # accelerator.device, dtype=torch.bfloat16
#     device, dtype=torch.bfloat16
# )
# ip_clip_image_processor = CLIPImageProcessor()

if args.use_ip:
    # NOTE: this branch requires ip_clip_image_processor, ip_image_encoder and ip_improj,
    # which are only built in the commented-out block above; enable that block before
    # setting use_ip in the config, otherwise this raises a NameError.
    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
elif args.use_spatial_condition:
    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None, share_position_embedding=args.share_position_embedding)
else:
    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
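
# generate() is the Gradio callback: it (optionally) swaps in LoRA attention
# processors, loads the ByteMorpher DiT weights from the Hub, preprocesses the
# uploaded image, and runs the sampler to produce the edited image.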
# @spaces.GPU
def generate(image, edit_prompt):
    print("hello?????????!!!!!")
    # accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    # accelerator = Accelerator(
    #     gradient_accumulation_steps=1,
    #     mixed_precision=args.mixed_precision,
    #     log_with=args.report_to,
    #     project_config=accelerator_project_config,
    # )
    # Make one log on every process with the configuration for debugging.
    # logging.basicConfig(
    #     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    #     datefmt="%m/%d/%Y %H:%M:%S",
    #     level=logging.INFO,
    # )
    # logger.info(accelerator.state, main_process_only=False)
    # if accelerator.is_local_main_process:
    #     datasets.utils.logging.set_verbosity_warning()
    #     transformers.utils.logging.set_verbosity_warning()
    #     diffusers.utils.logging.set_verbosity_info()
    # else:
    #     datasets.utils.logging.set_verbosity_error()
    #     transformers.utils.logging.set_verbosity_error()
    #     diffusers.utils.logging.set_verbosity_error()
    # if accelerator.is_main_process:
    #     if args.output_dir is not None:
    #         os.makedirs(args.output_dir, exist_ok=True)
    #         gpt_eval_path = os.path.join(args.output_dir, "Eval")
    #         os.makedirs(gpt_eval_path, exist_ok=True)
    # dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell)
    # dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)
    if args.use_lora:
        lora_attn_procs = {}
    if args.use_ip:
        ip_attn_procs = {}
    if args.double_blocks is None:
        double_blocks_idx = list(range(19))
    else:
        double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")]
    if args.single_blocks is None:
        single_blocks_idx = list(range(38))
    else:
        single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")]

    if args.use_lora:
        for name, attn_processor in dit.attn_processors.items():
            match = re.search(r'\.(\d+)\.', name)
            if match:
                layer_index = int(match.group(1))
                if name.startswith("double_blocks") and layer_index in double_blocks_idx:
                    # if accelerator.is_main_process:
                    #     print("setting LoRA Processor for", name)
                    lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
                        dim=3072, rank=args.rank
                    )
                elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
                    # if accelerator.is_main_process:
                    #     print("setting LoRA Processor for", name)
                    lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
                        dim=3072, rank=args.rank
                    )
                else:
                    lora_attn_procs[name] = attn_processor
        dit.set_attn_processor(lora_attn_procs)
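
    # IP-Adapter loading (kept from the original script, disabled in this demo):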
    # if args.use_ip:
    #     # unpack checkpoint
    #     checkpoint = load_checkpoint(args.ip_local_path, args.ip_repo_id, args.ip_name)
    #     prefix = "double_blocks."
    #     # blocks = {}
    #     proj = {}
    #     for key, value in checkpoint.items():
    #         # if key.startswith(prefix):
    #         #     blocks[key[len(prefix):].replace('.processor.', '.')] = value
    #         if key.startswith("ip_adapter_proj_model"):
    #             proj[key[len("ip_adapter_proj_model."):]] = value
    #     # # load image encoder
    #     # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
    #     #     # accelerator.device, dtype=torch.bfloat16
    #     #     device, dtype=torch.bfloat16
    #     # )
    #     # ip_clip_image_processor = CLIPImageProcessor()
    #     # setup image embedding projection model
    #     ip_improj = ImageProjModel(4096, 768, 4)
    #     ip_improj.load_state_dict(proj)
    #     # ip_improj = ip_improj.to(accelerator.device, dtype=torch.bfloat16)
    #     ip_improj = ip_improj.to(device, dtype=torch.bfloat16)
    #     ip_attn_procs = {}
    #     for name, _ in dit.attn_processors.items():
    #         ip_state_dict = {}
    #         for k in checkpoint.keys():
    #             if name in k:
    #                 ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k]
    #         if ip_state_dict:
    #             ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
    #             ip_attn_procs[name].load_state_dict(ip_state_dict)
    #             ip_attn_procs[name].to(accelerator.device, dtype=torch.bfloat16)
    #         else:
    #             ip_attn_procs[name] = dit.attn_processors[name]
    #     dit.set_attn_processor(ip_attn_procs)
    vae.requires_grad_(False)
    t5.requires_grad_(False)
    clip.requires_grad_(False)

    # The Accelerator-driven mixed-precision selection below is disabled in this
    # demo; bf16 is assumed here for FLUX inference.
    weight_dtype = torch.bfloat16
    # weight_dtype = torch.float32
    # if accelerator.mixed_precision == "fp16":
    #     weight_dtype = torch.float16
    #     args.mixed_precision = accelerator.mixed_precision
    # elif accelerator.mixed_precision == "bf16":
    #     weight_dtype = torch.bfloat16
    #     args.mixed_precision = accelerator.mixed_precision

    # print(f"Resuming from checkpoint {args.ckpt_dir}")
    # dit_stat_dict = load_file(args.ckpt_dir)

    # Get the ByteMorpher DiT checkpoint from the Hub and load it into the FLUX transformer.
    model_path = hf_hub_download(
        repo_id="Boese0601/ByteMorpher",
        filename="dit.safetensors"
    )
    state_dict = load_file(model_path)
    dit.load_state_dict(state_dict)
    dit = dit.to(weight_dtype)
    dit.eval()

    # test_dataloader = loader(**args.data_config)
    test_dataloader = eval_image_pair_loader(**args.data_config)
    # from deepspeed import initialize
    # dit = accelerator.prepare(dit)  # no Accelerator is constructed in this demo, so prepare() is skipped
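
    # The original tracker setup and batched evaluation loop are kept below for
    # reference; the demo only processes the single uploaded image.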
    # if accelerator.is_main_process:
    #     accelerator.init_trackers(args.tracker_project_name, {"test": None})
    # logger.info("***** Running Evaluation *****")
    # logger.info(f" Instantaneous batch size = {args.eval_batch_size}")
    # progress_bar = tqdm(
    #     range(0, len(test_dataloader)),
    #     initial=0,
    #     desc="Steps",
    #     disable=not accelerator.is_local_main_process,
    # )
    # for step, batch in enumerate(test_dataloader):
    #     with accelerator.accumulate(dit):
    #         img, tgt_image, prompt, edit_prompt, img_name, edit_name = batch

    # Preprocess the input image: resize, snap both sides to multiples of 32,
    # and normalize to [-1, 1] in NCHW layout.
    img = image_resize(image, 512)
    w, h = img.size
    new_w = (w // 32) * 32
    new_h = (h // 32) * 32
    img = img.resize((new_w, new_h))
    img = torch.from_numpy((np.array(img) / 127.5) - 1)
    img = img.permute(2, 0, 1).unsqueeze(0)
    edit_prompt = edit_prompt  # no-op retained from the batch loop above

    # if args.use_ip:
    #     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
    # elif args.use_spatial_condition:
    #     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None, share_position_embedding=args.share_position_embedding)
    # else:
    #     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
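
    # Sample the edited image without tracking gradients; the source image is
    # passed as a spatial condition when enabled in the config.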
    with torch.no_grad():
        result = sampler(prompt=edit_prompt,
                         width=args.sample_width,
                         height=args.sample_height,
                         num_steps=args.sample_steps,
                         image_prompt=None,  # ip_adapter
                         true_gs=args.cfg_scale,
                         seed=args.seed,
                         ip_scale=args.ip_scale if args.use_ip else 1.0,
                         source_image=img if args.use_spatial_condition else None,
                         )
    gen_img = result
    # progress_bar.update(1)
    # accelerator.wait_for_everyone()
    # accelerator.end_training()
    return gen_img
def get_samples():
    sample_list = [
        {
            "image": "assets/0_camera_zoom/20486354.png",
            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
        },
    ]
    return [
        [
            Image.open(sample["image"]).resize((512, 512)),
            sample["edit_prompt"],
        ]
        for sample in sample_list
    ]
header = """ | |
# ByteMoprh | |
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;"> | |
<a href=""><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a> | |
<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a> | |
<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a> | |
</div> | |
""" | |
def create_app():
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
                submit_btn = gr.Button("Run", elem_id="submit_btn")
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")
        with gr.Row():
            examples = gr.Examples(
                examples=get_samples(),
                inputs=[original_image, edit_prompt],
                label="Examples",
            )
        submit_btn.click(
            fn=generate,
            inputs=[original_image, edit_prompt],
            outputs=output_image,
        )
        gr.HTML(
            """
            <div style="text-align: center;">
            * This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
            </div>
            """
        )
    return app
if __name__ == "__main__":
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA version:", torch.version.cuda)
    print("GPU device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
    # mp.set_start_method("spawn", force=True)
    create_app().launch(debug=False, share=True, ssr_mode=False)