Fabrice-TIERCELIN committed
Commit dcdec0b · verified · 1 Parent(s): 3b7822b
Files changed (1):
  1. app.py +47 -0
app.py CHANGED
@@ -425,6 +425,53 @@ def worker(input_image, image_position, prompts, n_prompt, seed, resolution, tot
 
     image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
 
+    # Load transformer model
+    if model_changed:
+        stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading transformer ..."))))
+
+        transformer = None
+        time.sleep(1.0)  # wait for the previous model to be unloaded
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        previous_lora_file = lora_file
+        previous_lora_multiplier = lora_multiplier
+        previous_fp8_optimization = fp8_optimization
+
+        transformer = load_transfomer()  # bfloat16, on CPU
+
+        if lora_file is not None or fp8_optimization:
+            state_dict = transformer.state_dict()
+
+            # LoRA should be merged before fp8 optimization
+            if lora_file is not None:
+                # TODO: It would be better to merge the LoRA into the state dict before creating the transformer instance.
+                # Use from_config() instead of from_pretrained() to create the instance without loading the weights.
+
+                print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
+                state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
+                gc.collect()
+
+            if fp8_optimization:
+                TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
+                EXCLUDE_KEYS = ["norm"]  # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8
+
+                # in-place optimization
+                print("Optimizing for fp8")
+                state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False)
+
+                # apply monkey patching
+                apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
+                gc.collect()
+
+            info = transformer.load_state_dict(state_dict, strict=True, assign=True)
+            print(f"LoRA and/or fp8 optimization applied: {info}")
+
+        if not high_vram:
+            DynamicSwapInstaller.install_model(transformer, device=gpu)
+        else:
+            transformer.to(gpu)
+
     # Sampling
 
     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
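
The diff merges the LoRA into the weights before the fp8 pass, since quantizing first would destroy the low-rank update. A minimal sketch of what such a merge amounts to, assuming diffusers-style "<prefix>.lora_A.weight" / "<prefix>.lora_B.weight" key names (the key layout and helper below are illustrative assumptions, not the actual merge_lora_to_state_dict internals):

import torch

def merge_lora_sketch(state_dict: dict, lora_sd: dict, multiplier: float = 1.0) -> dict:
    # Fold W' = W + multiplier * (B @ A) into each targeted weight.
    for key in list(lora_sd.keys()):
        if not key.endswith(".lora_A.weight"):
            continue
        prefix = key[: -len(".lora_A.weight")]
        lora_a = lora_sd[key]                          # shape (rank, in_features)
        lora_b = lora_sd[prefix + ".lora_B.weight"]    # shape (out_features, rank)
        original = state_dict[prefix + ".weight"]
        delta = (lora_b.float() @ lora_a.float()) * multiplier  # rank-r update
        state_dict[prefix + ".weight"] = (original.float() + delta).to(original.dtype)
    return state_dict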
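For the fp8 step, the general technique is per-tensor quantization to torch.float8_e4m3fn plus a stored scale, with the patched linear layers dequantizing on the fly. The sketch below shows that idea under those assumptions; the names are illustrative, not the internals of optimize_state_dict_with_fp8 or apply_fp8_monkey_patch:

import torch
import torch.nn.functional as F

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # requires PyTorch >= 2.1

def quantize_fp8(weight: torch.Tensor):
    # Per-tensor symmetric quantization: store fp8 values plus one scale.
    scale = weight.abs().max().clamp(min=1e-8) / FP8_MAX
    q = (weight / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale

def fp8_linear(x: torch.Tensor, q_weight: torch.Tensor, scale: torch.Tensor,
               bias: torch.Tensor | None = None) -> torch.Tensor:
    # Dequantize only for the duration of the matmul; storage stays fp8,
    # which roughly halves the weight memory relative to bfloat16.
    w = q_weight.to(x.dtype) * scale.to(x.dtype)
    return F.linear(x, w, bias)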
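In the low-VRAM branch, DynamicSwapInstaller keeps the transformer resident in CPU memory and streams weights to the GPU as needed instead of calling transformer.to(gpu). One way to approximate that behaviour with forward hooks (a sketch only; the real FramePack class works differently, by patching module attribute access):

import torch
import torch.nn as nn

def install_swap_hooks(blocks: nn.ModuleList, device: torch.device) -> None:
    # Move each block to the GPU just before its forward pass and back to
    # the CPU right after, so only one block occupies VRAM at a time.
    def pre_hook(module, args):
        module.to(device)

    def post_hook(module, args, output):
        module.to("cpu")
        return output

    for block in blocks:
        block.register_forward_pre_hook(pre_hook)
        block.register_forward_hook(post_hook)

This trades speed for memory: every block incurs a host-to-device copy per forward pass, which is why the diff only takes this path when high_vram is False.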