alexnasa committed
Commit 584caad · verified · 1 Parent(s): ac7cf4b

VLM lora added
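This commit merges a LoRA adapter (`ckpt/VLM_LoRA/checkpoint-10000`) into the Qwen2.5-VL prompt model via PEFT and moves the VLM and SD3 super-resolution setup out of `recursive_multiscale_sr` to module scope, so the models are built once at import time. Below is a minimal usage sketch of the resulting entry point; the keyword names `input_png_path`, `upscale`, `rec_num`, and `centers` all appear in the diff, but the exact signature, argument order, defaults, and the center-coordinate convention are assumptions.

```python
# Hypothetical usage sketch -- argument order, defaults, and center convention are assumed.
from inference_coz_single import recursive_multiscale_sr

sr_images, prompts = recursive_multiscale_sr(
    input_png_path="input.png",          # source image; center-cropped to 512x512 internally
    upscale=4,                           # zoom factor applied at every recursion step
    rec_num=2,                           # number of recursive zoom-and-SR steps
    centers=[(0.5, 0.5), (0.5, 0.5)],    # one (x, y) crop center per step (convention assumed)
)
for i, (img, prompt) in enumerate(zip(sr_images, prompts)):
    img.save(f"sr_step_{i}.png")         # each step yields a 512x512 PIL image
    print(f"step {i}: {prompt}")         # and the VLM prompt used for that step
```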

Files changed (1)
  1. inference_coz_single.py +63 -86
inference_coz_single.py CHANGED
@@ -7,6 +7,7 @@ from torchvision import transforms
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler
+from peft import PeftModel

 # -------------------------------------------------------------------
 # Helper: Resize & center-crop to a fixed square
@@ -95,6 +96,60 @@ def _generate_vlm_prompt(
     return out_text.strip()


+VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    VLM_NAME,
+    torch_dtype="auto",
+    device_map="auto"  # immediately dispatches layers onto available GPUs
+)
+vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)
+
+vlm_model = PeftModel.from_pretrained(vlm_model, "ckpt/VLM_LoRA/checkpoint-10000")
+vlm_model = vlm_model.merge_and_unload()
+vlm_model.eval()
+
+device = "cuda"
+process_size = 512
+
+LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
+VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
+SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+class _Args:
+    pass
+
+args = _Args()
+args.upscale = 4
+args.lora_path = LORA_PATH
+args.vae_path = VAE_PATH
+args.pretrained_model_name_or_path = SD3_MODEL
+args.merge_and_unload_lora = False
+args.lora_rank = 4
+args.vae_decoder_tiled_size = 224
+args.vae_encoder_tiled_size = 1024
+args.latent_tiled_size = 96
+args.latent_tiled_overlap = 32
+args.mixed_precision = "fp16"
+args.efficient_memory = False
+
+sd3 = SD3Euler()
+sd3.text_enc_1.to(device)
+sd3.text_enc_2.to(device)
+sd3.text_enc_3.to(device)
+sd3.transformer.to(device, dtype=torch.float32)
+sd3.vae.to(device, dtype=torch.float32)
+
+for p in (
+    sd3.text_enc_1,
+    sd3.text_enc_2,
+    sd3.text_enc_3,
+    sd3.transformer,
+    sd3.vae,
+):
+    p.requires_grad_(False)
+
+model_test = OSEDiff_SD3_TEST(args, sd3)

 # -------------------------------------------------------------------
 # Main Function: recursive_multiscale_sr (with multiple centers)
@@ -131,101 +186,23 @@ def recursive_multiscale_sr(
             f"`centers` must be a list of {rec_num} (x,y) tuples, but got length {len(centers)}."
         )

-    ###############################
-    # 1. Fixed hyper-parameters
-    ###############################
-    device = "cuda"
-    process_size = 512  # same as args.process_size
-
-    # model checkpoint paths (hard-coded to your example)
-    LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
-    VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
-    SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"
-    # VLM model name (hard-coded)
-    VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
-
-    ###############################
-    # 2. Build a dummy “args” namespace
-    #    to satisfy OSEDiff_SD3_TEST constructor.
-    ###############################
-    class _Args:
-        pass
-
-    args = _Args()
-    args.upscale = upscale
-    args.lora_path = LORA_PATH
-    args.vae_path = VAE_PATH
-    args.pretrained_model_name_or_path = SD3_MODEL
-    args.merge_and_unload_lora = False
-    args.lora_rank = 4
-    args.vae_decoder_tiled_size = 224
-    args.vae_encoder_tiled_size = 1024
-    args.latent_tiled_size = 96
-    args.latent_tiled_overlap = 32
-    args.mixed_precision = "fp16"
-    args.efficient_memory = False
-    # (other flags are not used by OSEDiff_SD3_TEST, so we skip them)
-
-    ###############################
-    # 3. Load the SD3 SR model (non-efficient)
-    ###############################
-    # 3.1 Instantiate the underlying SD3-Euler UNet/VAE/text encoders
-    sd3 = SD3Euler()
-    # move all text encoders + transformer + VAE to CUDA:
-    sd3.text_enc_1.to(device)
-    sd3.text_enc_2.to(device)
-    sd3.text_enc_3.to(device)
-    sd3.transformer.to(device, dtype=torch.float32)
-    sd3.vae.to(device, dtype=torch.float32)
-    # freeze
-    for p in (
-        sd3.text_enc_1,
-        sd3.text_enc_2,
-        sd3.text_enc_3,
-        sd3.transformer,
-        sd3.vae,
-    ):
-        p.requires_grad_(False)
-
-    # 3.2 Wrap in OSEDiff_SD3_TEST helper:
-    model_test = OSEDiff_SD3_TEST(args, sd3)
-    # (by default, “model_test(...)” takes (lq_tensor, prompt=str) and returns a list[tensor])
-
-    ###############################
-    # 4. Load the VLM (Qwen2.5-VL)
-    ###############################
-    vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        VLM_NAME,
-        torch_dtype="auto",
-        device_map="auto"  # immediately dispatches layers onto available GPUs
-    )
-    vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)

-    ###############################
-    # 5. Pre-allocate a Temporary Directory
-    #    to hold intermediate JPEG/PNG files
-    ###############################
     unique_id = uuid.uuid4().hex
     prefix = f"recms_{unique_id}_"

     with tempfile.TemporaryDirectory(prefix=prefix) as td:
-        # (we’ll write “prev.png” and “zoom.png” at each step)

-        ###############################
-        # 6. Prepare the very first “full” image
-        ###############################
-        # (6.1) Load + center crop → first_image (512×512)
         img0 = Image.open(input_png_path).convert("RGB")
         img0 = resize_and_center_crop(img0, process_size)

-        # Note: we no longer need to write “prev.png” to disk. Just keep it in memory.
+
         prev_pil = img0.copy()

         sr_pil_list: list[Image.Image] = []
         prompt_list: list[str] = []

         for rec in range(rec_num):
-            # (A) Compute low-res crop window on prev_pil
+
             w, h = prev_pil.size  # (512×512)
             new_w, new_h = w // upscale, h // upscale

@@ -240,10 +217,10 @@ def recursive_multiscale_sr(

             cropped = prev_pil.crop((left, top, right, bottom))

-            # (B) Upsample that crop back to (512×512)
+
             zoomed_pil = cropped.resize((w, h), Image.BICUBIC)

-            # (C) Generate VLM prompt by passing PILs directly:
+
             prompt_tag = _generate_vlm_prompt(
                 vlm_model=vlm_model,
                 vlm_processor=vlm_processor,
@@ -253,22 +230,22 @@
                 device=device,
             )

-            # (D) Prepare “zoomed_pil” → tensor in [−1, 1]
+
             to_tensor = transforms.ToTensor()
             lq = to_tensor(zoomed_pil).unsqueeze(0).to(device)  # (1,3,512,512)
             lq = (lq * 2.0) - 1.0

-            # (E) Run SR inference
+
             with torch.no_grad():
                 out_tensor = model_test(lq, prompt=prompt_tag)[0]
                 out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
                 out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

-            # (F) Bookkeeping: set prev_pil = out_pil for next iteration
+
             prev_pil = out_pil

             # (G) Append to results
             sr_pil_list.append(out_pil)
             prompt_list.append(prompt_tag)

-    return sr_pil_list, prompt_list
+    return sr_pil_list, prompt_list
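One part of the loop that the diff leaves unchanged (and therefore does not show) is how `left`, `top`, `right`, `bottom` are derived from the requested center before `prev_pil.crop(...)`. Below is a minimal sketch of one way such a window can be computed, assuming normalized (x, y) centers and clamping to the image bounds; this is an illustration, not the file's actual implementation.

```python
# Illustration only: not taken from inference_coz_single.py.
def crop_window(w: int, h: int, new_w: int, new_h: int, cx: float, cy: float):
    """Return (left, top, right, bottom) for a new_w x new_h box centered at
    (cx * w, cy * h), shifted as needed so it stays inside the w x h image."""
    left = int(cx * w) - new_w // 2
    top = int(cy * h) - new_h // 2
    left = max(0, min(left, w - new_w))   # clamp horizontally
    top = max(0, min(top, h - new_h))     # clamp vertically
    return left, top, left + new_w, top + new_h

# e.g. a 128x128 window centered at (0.75, 0.5) of a 512x512 image:
print(crop_window(512, 512, 128, 128, 0.75, 0.5))  # (320, 192, 448, 320)
```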