import gradio as gr import torch import spaces from PIL import Image, ImageDraw, ImageFont # from src.condition import Condition from diffusers.pipelines import FluxPipeline import numpy as np import requests from huggingface_hub import hf_hub_download from safetensors.torch import load_file import torch.multiprocessing as mp ### import argparse import logging import math import os import re import random import shutil from contextlib import nullcontext from pathlib import Path from PIL import Image import accelerate import datasets import numpy as np import torch import torch.nn.functional as F from torch import Tensor, nn import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor from transformers.utils import ContextManagers from omegaconf import OmegaConf from copy import deepcopy import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr from diffusers.utils import check_min_version, deprecate, make_image_grid from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module from einops import rearrange from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack from src.flux.util import (configs, load_ae, load_clip, load_flow_model2, load_t5, save_image, tensor_to_pil_image, load_checkpoint) from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor, IPDoubleStreamBlockProcessor, IPSingleStreamBlockProcessor, ImageProjModel from src.flux.xflux_pipeline import XFluxSampler from image_datasets.dataset import loader, eval_image_pair_loader, image_resize from safetensors.torch import load_file import json # logger = get_logger(__name__, log_level="INFO") def get_models(name: str, device, offload: bool, is_schnell: bool): t5 = load_t5(device, max_length=256 if is_schnell else 512) clip = load_clip(device) clip.requires_grad_(False) model = load_flow_model2(name, device="cpu") vae = load_ae(name, device="cpu" if offload else device) return model, vae, t5, clip args = OmegaConf.load("inference_configs/inference.yaml") #OmegaConf.load(parse_args()) is_schnell = args.model_name == "flux-schnell" set_seed(args.seed) # logging_dir = os.path.join(args.output_dir, args.logging_dir) device = "cuda" dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell) # # load image encoder # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to( # # accelerator.device, dtype=torch.bfloat16 # device, dtype=torch.bfloat16 # ) # ip_clip_image_processor = CLIPImageProcessor() if args.use_ip: sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj) elif args.use_spatial_condition: sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding) else: sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None) # @spaces.GPU def generate(image, edit_prompt): print("hello?????????!!!!!") # accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) # accelerator = Accelerator( # gradient_accumulation_steps=1, # mixed_precision=args.mixed_precision, # log_with=args.report_to, # project_config=accelerator_project_config, # ) # Make one log on every process with the configuration for debugging. # logging.basicConfig( # format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", # datefmt="%m/%d/%Y %H:%M:%S", # level=logging.INFO, # ) # logger.info(accelerator.state, main_process_only=False) # if accelerator.is_local_main_process: # datasets.utils.logging.set_verbosity_warning() # transformers.utils.logging.set_verbosity_warning() # diffusers.utils.logging.set_verbosity_info() # else: # datasets.utils.logging.set_verbosity_error() # transformers.utils.logging.set_verbosity_error() # diffusers.utils.logging.set_verbosity_error() # if accelerator.is_main_process: # if args.output_dir is not None: # os.makedirs(args.output_dir, exist_ok=True) # gpt_eval_path = os.path.join(args.output_dir,"Eval") # os.makedirs(gpt_eval_path, exist_ok=True) # dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell) # dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell) if args.use_lora: lora_attn_procs = {} if args.use_ip: ip_attn_procs = {} if args.double_blocks is None: double_blocks_idx = list(range(19)) else: double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")] if args.single_blocks is None: single_blocks_idx = list(range(38)) elif args.single_blocks is not None: single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")] if args.use_lora: for name, attn_processor in dit.attn_processors.items(): match = re.search(r'\.(\d+)\.', name) if match: layer_index = int(match.group(1)) if name.startswith("double_blocks") and layer_index in double_blocks_idx: # if accelerator.is_main_process: # print("setting LoRA Processor for", name) lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( dim=3072, rank=args.rank ) elif name.startswith("single_blocks") and layer_index in single_blocks_idx: # if accelerator.is_main_process: # print("setting LoRA Processor for", name) lora_attn_procs[name] = SingleStreamBlockLoraProcessor( dim=3072, rank=args.rank ) else: lora_attn_procs[name] = attn_processor dit.set_attn_processor(lora_attn_procs) # if args.use_ip: # # unpack checkpoint # checkpoint = load_checkpoint(args.ip_local_path, args.ip_repo_id, args.ip_name) # prefix = "double_blocks." # # blocks = {} # proj = {} # for key, value in checkpoint.items(): # # if key.startswith(prefix): # # blocks[key[len(prefix):].replace('.processor.', '.')] = value # if key.startswith("ip_adapter_proj_model"): # proj[key[len("ip_adapter_proj_model."):]] = value # # # load image encoder # # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to( # # # accelerator.device, dtype=torch.bfloat16 # # device, dtype=torch.bfloat16 # # ) # # ip_clip_image_processor = CLIPImageProcessor() # # setup image embedding projection model # ip_improj = ImageProjModel(4096, 768, 4) # ip_improj.load_state_dict(proj) # # ip_improj = ip_improj.to(accelerator.device, dtype=torch.bfloat16) # ip_improj = ip_improj.to(device, dtype=torch.bfloat16) # ip_attn_procs = {} # for name, _ in dit.attn_processors.items(): # ip_state_dict = {} # for k in checkpoint.keys(): # if name in k: # ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] # if ip_state_dict: # ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) # ip_attn_procs[name].load_state_dict(ip_state_dict) # ip_attn_procs[name].to(accelerator.device, dtype=torch.bfloat16) # else: # ip_attn_procs[name] = dit.attn_processors[name] # dit.set_attn_processor(ip_attn_procs) vae.requires_grad_(False) t5.requires_grad_(False) clip.requires_grad_(False) # weight_dtype = torch.float32 # if accelerator.mixed_precision == "fp16": # weight_dtype = torch.float16 # args.mixed_precision = accelerator.mixed_precision # elif accelerator.mixed_precision == "bf16": # weight_dtype = torch.bfloat16 # args.mixed_precision = accelerator.mixed_precision # print(f"Resuming from checkpoint {args.ckpt_dir}") # dit_stat_dict = load_file(args.ckpt_dir) # Get path from Hub model_path = hf_hub_download( repo_id="Boese0601/ByteMorpher", filename="dit.safetensors" ) state_dict = load_file(model_path) dit.load_state_dict(state_dict) dit = dit.to(weight_dtype) dit.eval() # test_dataloader = loader(**args.data_config) test_dataloader = eval_image_pair_loader(**args.data_config) # from deepspeed import initialize dit = accelerator.prepare(dit) # if accelerator.is_main_process: # accelerator.init_trackers(args.tracker_project_name, {"test": None}) # logger.info("***** Running Evaluation *****") # logger.info(f" Instantaneous batch size = {args.eval_batch_size}") # progress_bar = tqdm( # range(0, len(test_dataloader)), # initial=0, # desc="Steps", # disable=not accelerator.is_local_main_process, # ) # for step, batch in enumerate(test_dataloader): # with accelerator.accumulate(dit): # img, tgt_image, prompt, edit_prompt, img_name, edit_name = batch img = image_resize(image, 512) w, h = img.size new_w = (w // 32) * 32 new_h = (h // 32) * 32 img = img.resize((new_w, new_h)) img = torch.from_numpy((np.array(img) / 127.5) - 1) img = img.permute(2, 0, 1).unsqueeze(0) edit_prompt = edit_prompt # if args.use_ip: # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj) # elif args.use_spatial_condition: # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding) # else: # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None) with torch.no_grad(): result = sampler(prompt=edit_prompt, width=args.sample_width, height=args.sample_height, num_steps=args.sample_steps, image_prompt=None, # ip_adapter true_gs=args.cfg_scale, seed=args.seed, ip_scale=args.ip_scale if args.use_ip else 1.0, source_image=img if args.use_spatial_condition else None, ) gen_img = result # progress_bar.update(1) # accelerator.wait_for_everyone() # accelerator.end_training() return gen_img def get_samples(): sample_list = [ { "image": "assets/0_camera_zoom/20486354.png", "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.", }, ] return [ [ Image.open(sample["image"]).resize((512, 512)), sample["edit_prompt"], ] for sample in sample_list ] header = """ # ByteMoprh
arXiv HuggingFace GitHub
""" def create_app(): with gr.Blocks() as app: gr.Markdown(header, elem_id="header") with gr.Row(equal_height=False): with gr.Column(variant="panel", elem_classes="inputPanel"): original_image = gr.Image( type="pil", label="Condition Image", width=300, elem_id="input" ) edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt") submit_btn = gr.Button("Run", elem_id="submit_btn") with gr.Column(variant="panel", elem_classes="outputPanel"): output_image = gr.Image(type="pil", elem_id="output") with gr.Row(): examples = gr.Examples( examples=get_samples(), inputs=[original_image, edit_prompt], label="Examples", ) submit_btn.click( fn=generate, inputs=[original_image, edit_prompt], outputs=output_image, ) gr.HTML( """
* This demo's template was modified from OminiControl.
""" ) return app if __name__ == "__main__": print("CUDA available:", torch.cuda.is_available()) print("CUDA version:", torch.version.cuda) print("GPU device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None") # mp.set_start_method("spawn", force=True) create_app().launch(debug=False, share=True, ssr_mode=False)