Use LORA
app.py
CHANGED
@@ -425,6 +425,53 @@ def worker(input_image, image_position, prompts, n_prompt, seed, resolution, tot
 
     image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
 
+    # Load transformer model
+    if model_changed:
+        stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading transformer ..."))))
+
+        transformer = None
+        time.sleep(1.0)  # wait for the previous model to be unloaded
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        previous_lora_file = lora_file
+        previous_lora_multiplier = lora_multiplier
+        previous_fp8_optimization = fp8_optimization
+
+        transformer = load_transfomer()  # bfloat16, on cpu
+
+        if lora_file is not None or fp8_optimization:
+            state_dict = transformer.state_dict()
+
+            # LoRA should be merged before fp8 optimization
+            if lora_file is not None:
+                # TODO It would be better to merge the LoRA into the state dict before creating the transformer instance.
+                # Use from_config() instead of from_pretrained to make the instance without loading.
+
+                print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
+                state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
+                gc.collect()
+
+            if fp8_optimization:
+                TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
+                EXCLUDE_KEYS = ["norm"]  # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8
+
+                # inplace optimization
+                print("Optimizing for fp8")
+                state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False)
+
+                # apply monkey patching
+                apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
+                gc.collect()
+
+            info = transformer.load_state_dict(state_dict, strict=True, assign=True)
+            print(f"LoRA and/or fp8 optimization applied: {info}")
+
+        if not high_vram:
+            DynamicSwapInstaller.install_model(transformer, device=gpu)
+        else:
+            transformer.to(gpu)
+
     # Sampling
 
     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
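
Note on the LoRA merge: merge_lora_to_state_dict is an external helper and is not shown in this diff. Below is a minimal sketch of the conventional merge it is assumed to perform, folding each LoRA pair into its base weight as W + multiplier * (alpha / rank) * (up @ down); the function name, shapes, and conventions are illustrative, not this repo's actual implementation.

import torch

def merge_lora_weight(base: torch.Tensor, down: torch.Tensor, up: torch.Tensor,
                      alpha: float, multiplier: float) -> torch.Tensor:
    # Assumed shapes (common LoRA convention): down is (rank, in_features), up is (out_features, rank).
    rank = down.shape[0]
    delta = up.float() @ down.float()                        # (out_features, in_features)
    merged = base.float() + multiplier * (alpha / rank) * delta
    return merged.to(base.dtype)                             # back to the transformer's bfloat16

Doing this before the FP8 pass, as the "merged before fp8 optimization" comment requires, keeps the addition in full precision; merging into already-quantized weights would bake FP8 rounding error into the result.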
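Note on the FP8 path: optimize_state_dict_with_fp8 and apply_fp8_monkey_patch are likewise external helpers. Below is a rough sketch of per-tensor FP8 (e4m3) weight quantization, assuming PyTorch 2.1+ for torch.float8_e4m3fn; the real helpers presumably also store the scale and patch the linear layers' forward to dequantize (or use a scaled matmul) at run time.

import torch

def quantize_weight_to_fp8(weight: torch.Tensor):
    # Per-tensor scale so the largest magnitude maps near the e4m3 maximum (448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = weight.abs().amax().float().clamp(min=1e-12) / fp8_max
    q = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale  # q replaces the weight in the state dict; scale is kept for dequantization

# A patched forward would then compute roughly: x @ (q.to(torch.bfloat16) * scale).T + bias

Keeping norm layers out of this pass (the EXCLUDE_KEYS filter) is the usual choice, since their small, sensitive parameters lose accuracy quickly at 8 bits.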