import spaces
import gradio as gr
import os
import sys
from typing import List

# sys.path.append(os.getcwd())

import numpy as np
from PIL import Image

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

print(f'torch version:{torch.__version__}')

# import subprocess
# import importlib, site, sys
# # Re-discover all .pth/.egg-link files
# for sitedir in site.getsitepackages():
#     site.addsitedir(sitedir)
# # Clear caches so importlib will pick up new modules
# importlib.invalidate_caches()
# def sh(cmd): subprocess.check_call(cmd, shell=True)
# sh("pip install -U xformers --index-url https://download.pytorch.org/whl/cu126")
# # tell Python to re-scan site-packages now that the egg-link exists
# import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()

import torch.utils.checkpoint
from pytorch_lightning import seed_everything
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from huggingface_hub import hf_hub_download, snapshot_download

from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline
from utils.wavelet_color_fix import wavelet_color_fix, adain_color_fix

from ram.models.ram_lora import ram
from ram import inference_ram as inference

from torchvision import transforms

from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel

VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    VLM_NAME,
    torch_dtype="auto",
    device_map="auto",  # immediately dispatches layers onto available GPUs
)
vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)


def _generate_vlm_prompt(
    vlm_model: Qwen2_5_VLForConditionalGeneration,
    vlm_processor: AutoProcessor,
    process_vision_info,
    pil_image: Image.Image,
    device: str = "cuda",
    max_new_tokens: int = 128,
) -> str:
    """Describe a single image with a Qwen2.5-VL chat model.

    Args:
        vlm_model: Loaded Qwen2.5-VL model used for generation.
        vlm_processor: Matching processor; applies the chat template and
            tokenizes text/images.
        process_vision_info: Helper (``qwen_vl_utils``) that extracts the
            image/video inputs from the chat ``messages``.
        pil_image: The image to describe.
        device: Device the processed inputs are moved to before generation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated text with the prompt tokens trimmed off and surrounding
        whitespace stripped.
    """
    # Note the trailing space in the first literal: adjacent literals are
    # concatenated verbatim, so without it the sentences would run together.
    message_text = (
        "Give a detailed description of this image. "
        "Describe each element with fine details."
    )
    messages = [
        {"role": "system", "content": message_text},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
            ],
        },
    ]

    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    generated = vlm_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the new tokens.
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    out_text = vlm_processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return out_text.strip()


# PIL -> float tensor (ToTensor scales pixel values into [0, 1]).
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

# Preprocessing for the RAM tagger: fixed 384x384 input, normalized with the
# standard ImageNet mean/std constants.
ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Fetch model weights into the local ``preset/models`` tree used below.
snapshot_download(
    repo_id="alexnasa/SEESR",
    local_dir="preset/models",
)
snapshot_download(
    repo_id="stabilityai/stable-diffusion-2-1-base",
    local_dir="preset/models/stable-diffusion-2-1-base",
)
snapshot_download(
    repo_id="xinyu1205/recognize_anything_model",
    local_dir="preset/models/",
)

# Load scheduler, tokenizer and models.
pretrained_model_path = 'preset/models/stable-diffusion-2-1-base'
seesr_model_path = 'preset/models/seesr'

# Stable Diffusion 2.1-base components plus SeeSR's UNet / ControlNet weights.
scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
unet = UNet2DConditionModel.from_pretrained(seesr_model_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(seesr_model_path, subfolder="controlnet")

# Inference only: freeze every network.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)

# unet.to("cuda")
# controlnet.to("cuda")
# unet.enable_xformers_memory_efficient_attention()
# controlnet.enable_xformers_memory_efficient_attention()

# Get the validation pipeline
validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=None,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
validation_pipeline._init_tiled_vae(encoder_tile_size=1024, decoder_tile_size=224)

weight_dtype = torch.float16
device = "cuda"

# Move all networks to the GPU in half precision.
text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)

tag_model = ram(
    pretrained='preset/models/ram_swin_large_14m.pth',
    pretrained_condition='preset/models/DAPE.pth',
    image_size=384,
    vit='swin_l',
)
tag_model.eval()
tag_model.to(device, dtype=weight_dtype)


@spaces.GPU(duration=120)
def process(
    input_image: Image.Image,
    user_prompt: str,
    use_KDS: bool,
    bandwidth: float,
    patch_size: int,
    num_particles: int,
    positive_prompt: str,
    negative_prompt: str,
    num_inference_steps: int,
    scale_factor: int,
    cfg_scale: float,
    seed: int,
    latent_tiled_size: int,
    latent_tiled_overlap: int,
    sample_times: int,
) -> List[np.ndarray]:
    """Super-resolve ``input_image`` with the SeeSR pipeline.

    The RAM tagger produces tag words (prepended to ``positive_prompt`` and
    optionally ``user_prompt``) and image embeddings used as extra conditioning.
    The image is upscaled by ``scale_factor`` (and at least to a 512-px shorter
    side), snapped to multiples of 8, run through the diffusion pipeline,
    color-corrected, and resized to ``original_size * scale_factor``.

    Args:
        input_image: Low-quality input image.
        user_prompt: Optional text prepended to the generated prompt ("" = none).
        use_KDS, bandwidth, num_particles, patch_size: Kernel Density Steering
            options, forwarded unchanged to the pipeline.
        positive_prompt, negative_prompt: Text conditioning.
        num_inference_steps: Diffusion steps.
        scale_factor: Super-resolution factor.
        cfg_scale: Classifier-free guidance scale.
        seed: RNG seed; it seeds both the global RNGs and the pipeline generator.
        latent_tiled_size, latent_tiled_overlap: Latent tiling options.
        sample_times: Number of samples to produce.

    Returns:
        One numpy array per sample. A sample whose pipeline call raised is
        replaced by a black 512x512 placeholder.

    Raises:
        gr.Error: If no image was provided.
    """
    import traceback

    if input_image is None:
        raise gr.Error("Please upload an input image.")
    # Alpha / palette / grayscale inputs would give ToTensor a channel count
    # other than 3, which RAM and the pipeline do not expect.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    process_size = 512
    resize_preproc = transforms.Compose([
        transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # NOTE(review): this caption is only printed; it is not part of
    # validation_prompt below. Confirm whether it was meant to be used.
    prompt_tag = _generate_vlm_prompt(
        vlm_model=vlm_model,
        vlm_processor=vlm_processor,
        process_vision_info=process_vision_info,
        pil_image=input_image,
        device=device,
    )
    print(f'oh lala, prompt tag:{prompt_tag}')

    # seed_everything returns the seed it actually applied. The explicit
    # generator must be seeded too, otherwise the pipeline noise ignores
    # the user's seed.
    seed = seed_everything(int(seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    # RAM tagging: tag words for the prompt + image embeddings for conditioning.
    # No gradients are needed here.
    lq = tensor_transforms(input_image).unsqueeze(0).to(device).half()
    lq = ram_transforms(lq)
    with torch.no_grad():
        res = inference(lq, tag_model)
        ram_encoder_hidden_states = tag_model.generate_image_embeds(lq)

    validation_prompt = f"{res[0]}, {positive_prompt},"
    if user_prompt != '':
        validation_prompt = f"{user_prompt}, {validation_prompt}"

    ori_width, ori_height = input_image.size
    rscale = scale_factor
    # PIL requires integer sizes; scale_factor may arrive as a float from gr.Number.
    target_size = (int(ori_width * rscale), int(ori_height * rscale))

    input_image = input_image.resize(target_size)
    if min(input_image.size) < process_size:
        input_image = resize_preproc(input_image)
    # Snap to multiples of 8 (presumably the VAE downsampling factor).
    input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
    width, height = input_image.size

    images = []
    for _ in range(int(sample_times)):
        try:
            with torch.autocast("cuda"):
                image = validation_pipeline(
                    validation_prompt,
                    input_image,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    height=height,
                    width=width,
                    guidance_scale=cfg_scale,
                    conditioning_scale=1,
                    start_point='lr',
                    start_steps=999,
                    ram_encoder_hidden_states=ram_encoder_hidden_states,
                    latent_tiled_size=latent_tiled_size,
                    latent_tiled_overlap=latent_tiled_overlap,
                    use_KDS=use_KDS,
                    bandwidth=bandwidth,
                    num_particles=num_particles,
                    patch_size=patch_size,
                ).images[0]
                # Transfer the low-frequency colors of the conditioning image.
                image = wavelet_color_fix(image, input_image)
                # Undo the process_size / multiple-of-8 adjustments.
                image = image.resize(target_size)
        except Exception:
            # Best effort: keep going with a placeholder, but keep the full
            # traceback for debugging.
            traceback.print_exc()
            image = Image.new(mode="RGB", size=(512, 512))
        images.append(np.array(image))
    return images


MARKDOWN = """
## SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution

[GitHub](https://github.com/cswry/SeeSR) | [Paper](https://arxiv.org/abs/2311.16518)

If SeeSR is helpful for you, please help star the GitHub Repo. Thanks!
"""

# Shared defaults, used both by the textboxes and by the examples.
DEFAULT_POSITIVE_PROMPT = "clean, high-resolution, 8k, best quality, masterpiece"
DEFAULT_NEGATIVE_PROMPT = (
    "dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, "
    "missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
)
EXAMPLE_IMAGE = "preset/datasets/test_datasets/woman.png"


def _example_row(use_KDS: bool, patch_size: int) -> list:
    """Build one gr.Examples row (ordered like ``inputs``); only KDS and patch size vary."""
    return [
        EXAMPLE_IMAGE,
        "",     # user_prompt
        use_KDS,
        0.1,    # bandwidth
        patch_size,
        4,      # num_particles
        DEFAULT_POSITIVE_PROMPT,
        DEFAULT_NEGATIVE_PROMPT,
        50,     # num_inference_steps
        4,      # scale_factor
        7.5,    # cfg_scale
        123,    # seed
        320,    # latent_tiled_size
        4,      # latent_tiled_overlap
        1,      # sample_times
    ]


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            num_particles = gr.Slider(label="Num of Particles", minimum=1, maximum=16, step=1, value=10)
            bandwidth = gr.Slider(label="Bandwidth", minimum=0.1, maximum=0.8, step=0.1, value=0.1)
            patch_size = gr.Slider(label="Patch Size", minimum=1, maximum=16, step=1, value=16)
            use_KDS = gr.Checkbox(label="Use Kernel Density Steering")
            run_button = gr.Button("Run")
            with gr.Accordion("Options", open=True):
                user_prompt = gr.Textbox(label="User Prompt", value="")
                positive_prompt = gr.Textbox(label="Positive Prompt", value=DEFAULT_POSITIVE_PROMPT)
                negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
                # step=0 is not a valid slider increment; 0.1 allows the 7.5 default.
                cfg_scale = gr.Slider(
                    label="Classifier Free Guidance Scale (Set to 1.0 in sd-turbo)",
                    minimum=1, maximum=10, value=7.5, step=0.1,
                )
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=2, maximum=100, value=50, step=1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
                sample_times = gr.Slider(label="Sample Times", minimum=1, maximum=10, step=1, value=1)
                latent_tiled_size = gr.Slider(label="Diffusion Tile Size", minimum=128, maximum=480, value=320, step=1)
                latent_tiled_overlap = gr.Slider(label="Diffusion Tile Overlap", minimum=4, maximum=16, value=4, step=1)
                scale_factor = gr.Number(label="SR Scale", value=4)
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery")

    # Order must match process()'s positional parameters.
    inputs = [
        input_image,
        user_prompt,
        use_KDS,
        bandwidth,
        patch_size,
        num_particles,
        positive_prompt,
        negative_prompt,
        num_inference_steps,
        scale_factor,
        cfg_scale,
        seed,
        latent_tiled_size,
        latent_tiled_overlap,
        sample_times,
    ]

    examples = gr.Examples(
        examples=[
            _example_row(use_KDS=False, patch_size=4),
            _example_row(use_KDS=True, patch_size=16),
            _example_row(use_KDS=True, patch_size=4),
        ],
        inputs=inputs,
        outputs=[result_gallery],
        fn=process,
        cache_examples=True,
    )

    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

block.launch(share=True)