alexnasa committed
Commit c213ea5 · verified · 1 Parent(s): e14280e

Upload gradio_seesr.py

Files changed (1):
  1. gradio_seesr.py +207 -0

gradio_seesr.py ADDED
import gradio as gr
import os
import sys
from typing import List
# sys.path.append(os.getcwd())

import numpy as np
from PIL import Image

import torch
import torch.utils.checkpoint
from pytorch_lightning import seed_everything
from diffusers import AutoencoderKL, DDPMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline

from utils.wavelet_color_fix import wavelet_color_fix, adain_color_fix

from ram.models.ram_lora import ram
from ram import inference_ram as inference
from torchvision import transforms
from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel
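
# Preprocessing transforms: ToTensor for the low-quality input, plus a 384x384
# resize with ImageNet normalization used only by the RAM tagging model.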
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load scheduler, tokenizer and models: the base components come from Stable Diffusion
# 2.1-base, the fine-tuned UNet and ControlNet from the SeeSR checkpoint.
pretrained_model_path = 'preset/models/stable-diffusion-2-1-base'
seesr_model_path = 'preset/models/seesr'

scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
unet = UNet2DConditionModel.from_pretrained(seesr_model_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(seesr_model_path, subfolder="controlnet")

# Freeze vae, text_encoder, unet and controlnet; this script runs inference only
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)

if is_xformers_available():
    unet.enable_xformers_memory_efficient_attention()
    controlnet.enable_xformers_memory_efficient_attention()
else:
    raise ValueError("xformers is not available. Make sure it is installed correctly")

# Get the validation pipeline
validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
    unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
)
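
# Tiled VAE encoding/decoding (SeeSR pipeline helper) keeps memory use bounded
# when processing large images.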
validation_pipeline._init_tiled_vae(encoder_tile_size=1024,
                                    decoder_tile_size=224)
weight_dtype = torch.float16
device = "cuda"

# Move the frozen modules to the GPU and cast them to weight_dtype (fp16)
text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)
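
# Tagging model: RAM (Recognize Anything Model) loaded with the DAPE weights;
# it provides the tag prompt and image embeddings for the degraded input.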
tag_model = ram(pretrained='preset/models/ram_swin_large_14m.pth',
                pretrained_condition='preset/models/DAPE.pth',
                image_size=384,
                vit='swin_l')
tag_model.eval()
tag_model.to(device, dtype=weight_dtype)
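
# Inference entry point for the Gradio UI: tag the input with RAM, build the
# prompt, upscale the image, then run the tiled SeeSR diffusion pipeline.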
@torch.no_grad()
def process(
        input_image: Image.Image,
        user_prompt: str,
        positive_prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        scale_factor: int,
        cfg_scale: float,
        seed: int,
        latent_tiled_size: int,
        latent_tiled_overlap: int,
        sample_times: int
) -> List[np.ndarray]:
    process_size = 512
    resize_preproc = transforms.Compose([
        transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # with torch.no_grad():
    seed_everything(seed)
    generator = torch.Generator(device=device)
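
    # Tag the input with RAM and assemble the prompt: RAM tags plus the positive
    # prompt, prefixed by the user prompt when one is given.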
    validation_prompt = ""
    lq = tensor_transforms(input_image).unsqueeze(0).to(device).half()
    lq = ram_transforms(lq)
    res = inference(lq, tag_model)
    ram_encoder_hidden_states = tag_model.generate_image_embeds(lq)
    validation_prompt = f"{res[0]}, {positive_prompt},"
    validation_prompt = validation_prompt if user_prompt == '' else f"{user_prompt}, {validation_prompt}"
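
    # Upscale the input by the SR scale, enforce a minimum short side of
    # process_size, and round both dimensions down to multiples of 8.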
    ori_width, ori_height = input_image.size
    resize_flag = False

    rscale = scale_factor
    input_image = input_image.resize((int(input_image.size[0] * rscale), int(input_image.size[1] * rscale)))

    if min(input_image.size) < process_size:
        input_image = resize_preproc(input_image)

    input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
    width, height = input_image.size
    resize_flag = True
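
    # Run the diffusion pipeline sample_times times; each result gets a wavelet
    # color fix against the upscaled input and is resized to the target output size.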
    images = []
    for _ in range(sample_times):
        try:
            with torch.autocast("cuda"):
                image = validation_pipeline(
                    validation_prompt, input_image, negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps, generator=generator,
                    height=height, width=width,
                    guidance_scale=cfg_scale, conditioning_scale=1,
                    start_point='lr', start_steps=999, ram_encoder_hidden_states=ram_encoder_hidden_states,
                    latent_tiled_size=latent_tiled_size, latent_tiled_overlap=latent_tiled_overlap
                ).images[0]

            if True:  # alpha<1.0:
                image = wavelet_color_fix(image, input_image)

            if resize_flag:
                image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception as e:
            # On failure, log the error and return a black placeholder image
            print(e)
            image = Image.new(mode="RGB", size=(512, 512))
        images.append(np.array(image))
    return images
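
# Gradio UI: page description plus the controls that are passed to process().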
MARKDOWN = \
"""
## SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution

[GitHub](https://github.com/cswry/SeeSR) | [Paper](https://arxiv.org/abs/2311.16518)

If SeeSR is helpful for you, please help star the GitHub Repo. Thanks!
"""

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source="upload", type="pil")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Options", open=True):
                user_prompt = gr.Textbox(label="User Prompt", value="")
                positive_prompt = gr.Textbox(label="Positive Prompt", value="clean, high-resolution, 8k, best quality, masterpiece")
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
                )
                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=0.1, maximum=10.0, value=5.5, step=0.1)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=50, step=1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
                sample_times = gr.Slider(label="Sample Times", minimum=1, maximum=10, step=1, value=1)
                latent_tiled_size = gr.Slider(label="Diffusion Tile Size", minimum=128, maximum=480, value=320, step=1)
                latent_tiled_overlap = gr.Slider(label="Diffusion Tile Overlap", minimum=4, maximum=16, value=4, step=1)
                scale_factor = gr.Number(label="SR Scale", value=4)
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")

    inputs = [
        input_image,
        user_prompt,
        positive_prompt,
        negative_prompt,
        num_inference_steps,
        scale_factor,
        cfg_scale,
        seed,
        latent_tiled_size,
        latent_tiled_overlap,
        sample_times,
    ]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])

block.launch()