tori29umai commited on
Commit
dfc7eb3
·
verified ·
1 Parent(s): 4b8e0c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -33
app.py CHANGED
@@ -90,7 +90,28 @@ from transformers import SiglipImageProcessor, SiglipVisionModel
90
  from diffusers_helper.clip_vision import hf_clip_vision_encode
91
  from diffusers_helper.bucket_tools import find_nearest_bucket
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
 
94
 
95
  # 追加: 指定された解像度リスト
96
  NEW_RESOLUTIONS = [
@@ -99,7 +120,6 @@ NEW_RESOLUTIONS = [
99
  (768, 512), (832, 480), (864, 448), (960, 416), (640, 640),
100
  ]
101
 
102
-
103
  # Spaces環境では、すべてのCUDA操作を遅延させる
104
  if not IN_HF_SPACE:
105
  # 非Spaces環境でのみCUDAメモリを取得
@@ -175,7 +195,7 @@ def process_image_mask(image_mask_dict):
175
  background = image_mask_dict.get("background")
176
  layers = image_mask_dict.get("layers")
177
 
178
- if background is None:
179
  return None
180
 
181
  # ---- 1) Drop alpha from background ----
@@ -186,36 +206,32 @@ def process_image_mask(image_mask_dict):
186
  if img_array.ndim == 3 and img_array.shape[2] == 4:
187
  img_array = img_array[..., :3]
188
 
189
- # ---- 2) マスクがある場合のみマスク処理 ----
190
- if layers and len(layers) > 0:
191
- layer = layers[0]
192
- if isinstance(layer, Image.Image) and layer.mode == "RGBA":
193
- layer = layer.convert("RGB")
194
- mask_array = np.array(layer)
195
- if mask_array.ndim == 3 and mask_array.shape[2] == 4:
196
- mask_array = mask_array[..., :3]
197
-
198
- # convert to gray + binary
199
- if mask_array.ndim == 3:
200
- mask_gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
201
- else:
202
- mask_gray = mask_array
203
- _, binary_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
204
 
205
- # 市松模様合成ロジック
206
- total_pixels = img_array.shape[0] * img_array.shape[1]
207
- cell_size = max(int(np.sqrt(total_pixels) / 20), 10)
208
- checkerboard = create_checkerboard(img_array.shape[1], img_array.shape[0], cell_size)
209
 
210
- result = img_array.copy()
211
- binary_mask_3ch = np.stack([binary_mask]*3, axis=2) // 255
212
- for c in range(3):
213
- result[..., c] = result[..., c] * (1 - binary_mask_3ch[..., c]) + checkerboard[..., c] * binary_mask_3ch[..., c]
214
 
215
- return result.astype(np.uint8)
216
- else:
217
- # マスクがない場合は元の画像をそのまま返す
218
- return img_array
219
 
220
  # 最も近い解像度を見つける関数
221
  def find_nearest_resolution(width, height):
@@ -337,7 +353,7 @@ def load_models():
337
 
338
  print("Transformerモデルを読み込み中...")
339
  transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
340
- "tori29umai/FramePackI2V_HY_mask_fadeout", torch_dtype=torch.bfloat16
341
  ).cpu()
342
 
343
  transformer.eval()
@@ -432,6 +448,9 @@ def worker_with_temp_files(
432
  gpu_memory_preservation,
433
  use_teacache,
434
  mp4_crf,
 
 
 
435
  ):
436
  global last_update_time
437
  last_update_time = time.time()
@@ -480,6 +499,27 @@ def worker_with_temp_files(
480
  feature_extractor = models['feature_extractor']
481
  image_encoder = models['image_encoder']
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
  # Clean GPU
485
  if not high_vram:
@@ -764,7 +804,7 @@ def worker_with_temp_files(
764
 
765
  # 非GPU環境用の標準process関数を追加
766
  if not IN_HF_SPACE or 'spaces' not in globals():
767
- def process_with_temp(image_mask_dict):
768
  """一時ファイルを使用する処理メインフロー(非GPU環境用)"""
769
  global stream
770
 
@@ -800,6 +840,13 @@ if not IN_HF_SPACE or 'spaces' not in globals():
800
  gpu_memory_preservation = 6.0
801
  use_teacache = False
802
  mp4_crf = 0
 
 
 
 
 
 
 
803
 
804
  # 非同期ワーカー起動
805
  stream = AsyncStream()
@@ -816,6 +863,9 @@ if not IN_HF_SPACE or 'spaces' not in globals():
816
  gpu_memory_preservation,
817
  use_teacache,
818
  mp4_crf,
 
 
 
819
  )
820
 
821
  temp_dir = None
@@ -909,7 +959,7 @@ if not IN_HF_SPACE or 'spaces' not in globals():
909
  # GPU環境用process_with_temp関数の内容を完成
910
  if IN_HF_SPACE and 'spaces' in globals():
911
  @spaces.GPU(duration=180)
912
- def process_with_temp(image_mask_dict):
913
  """一時ファイルを使用する処理メインフロー(GPU対応)"""
914
  global stream
915
 
@@ -945,6 +995,13 @@ if IN_HF_SPACE and 'spaces' in globals():
945
  gpu_memory_preservation = 6.0
946
  use_teacache = False
947
  mp4_crf = 0
 
 
 
 
 
 
 
948
 
949
  # 非同期ワーカー起動
950
  stream = AsyncStream()
@@ -961,6 +1018,9 @@ if IN_HF_SPACE and 'spaces' in globals():
961
  gpu_memory_preservation,
962
  use_teacache,
963
  mp4_crf,
 
 
 
964
  )
965
 
966
  temp_dir = None
@@ -1077,7 +1137,7 @@ if IN_HF_SPACE and 'spaces' in globals():
1077
  css = make_progress_bar_css()
1078
  block = gr.Blocks(css=css).queue()
1079
  with block:
1080
- gr.Markdown("# FramePackI2V_HY_mask_fadeout_frame1 - 画像のマスクした部分を除去")
1081
  with gr.Row():
1082
  with gr.Column():
1083
  # 入力画像をImageMaskで設定
@@ -1098,6 +1158,9 @@ with block:
1098
  start_button = gr.Button(value="生成開始")
1099
  end_button = gr.Button(value="生成中止", interactive=False)
1100
 
 
 
 
1101
  with gr.Column():
1102
  preview_image = gr.Image(label="生成プレビュー", visible=False)
1103
  result_frame = gr.Image(label="生成結果", visible=False, height="60vh", type="filepath")
@@ -1107,6 +1170,7 @@ with block:
1107
 
1108
  ips = [
1109
  image_mask,
 
1110
  ]
1111
  ops = [
1112
  preview_image,
 
90
  from diffusers_helper.clip_vision import hf_clip_vision_encode
91
  from diffusers_helper.bucket_tools import find_nearest_bucket
92
 
93
+ # GPU使用に必要なモジュールのインポートを試みる(可能な場合)
94
+ try:
95
+ from utils.lora_utils import merge_lora_to_state_dict
96
+ from utils.fp8_optimization_utils import optimize_state_dict_with_fp8, apply_fp8_monkey_patch
97
+ print("LoRAとFP8最適化モジュールを正常にインポートしました")
98
+ except ImportError as e:
99
+ print(f"一部のモジュールのインポートに失敗しました: {e}")
100
+ # ダミー関数を定義
101
+ def merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=None):
102
+ print("Warning: LoRA適用機能が利用できません")
103
+ return state_dict
104
+
105
+ def optimize_state_dict_with_fp8(state_dict, device, target_keys, exclude_keys, move_to_device=False):
106
+ print("Warning: FP8最適化機能が利用できません")
107
+ return state_dict
108
+
109
+ def apply_fp8_monkey_patch(model, state_dict, use_scaled_mm=False):
110
+ print("Warning: FP8 monkey patch機能が利用できません")
111
+ pass
112
 
113
+ outputs_folder = './outputs/'
114
+ os.makedirs(outputs_folder, exist_ok=True)
115
 
116
  # 追加: 指定された解像度リスト
117
  NEW_RESOLUTIONS = [
 
120
  (768, 512), (832, 480), (864, 448), (960, 416), (640, 640),
121
  ]
122
 
 
123
  # Spaces環境では、すべてのCUDA操作を遅延させる
124
  if not IN_HF_SPACE:
125
  # 非Spaces環境でのみCUDAメモリを取得
 
195
  background = image_mask_dict.get("background")
196
  layers = image_mask_dict.get("layers")
197
 
198
+ if background is None or not layers:
199
  return None
200
 
201
  # ---- 1) Drop alpha from background ----
 
206
  if img_array.ndim == 3 and img_array.shape[2] == 4:
207
  img_array = img_array[..., :3]
208
 
209
+ # ---- 2) Load mask layer and binarize ----
210
+ layer = layers[0]
211
+ if isinstance(layer, Image.Image) and layer.mode == "RGBA":
212
+ layer = layer.convert("RGB")
213
+ mask_array = np.array(layer)
214
+ if mask_array.ndim == 3 and mask_array.shape[2] == 4:
215
+ mask_array = mask_array[..., :3]
216
+
217
+ # convert to gray + binary
218
+ if mask_array.ndim == 3:
219
+ mask_gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
220
+ else:
221
+ mask_gray = mask_array
222
+ _, binary_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
 
223
 
224
+ # 市松模様合成ロジック
225
+ total_pixels = img_array.shape[0] * img_array.shape[1]
226
+ cell_size = max(int(np.sqrt(total_pixels) / 20), 10)
227
+ checkerboard = create_checkerboard(img_array.shape[1], img_array.shape[0], cell_size)
228
 
229
+ result = img_array.copy()
230
+ binary_mask_3ch = np.stack([binary_mask]*3, axis=2) // 255
231
+ for c in range(3):
232
+ result[..., c] = result[..., c] * (1 - binary_mask_3ch[..., c]) + checkerboard[..., c] * binary_mask_3ch[..., c]
233
 
234
+ return result.astype(np.uint8)
 
 
 
235
 
236
  # 最も近い解像度を見つける関数
237
  def find_nearest_resolution(width, height):
 
353
 
354
  print("Transformerモデルを読み込み中...")
355
  transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
356
+ "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
357
  ).cpu()
358
 
359
  transformer.eval()
 
448
  gpu_memory_preservation,
449
  use_teacache,
450
  mp4_crf,
451
+ lora_file,
452
+ lora_multiplier,
453
+ fp8_optimization,
454
  ):
455
  global last_update_time
456
  last_update_time = time.time()
 
499
  feature_extractor = models['feature_extractor']
500
  image_encoder = models['image_encoder']
501
 
502
+ # LoRAファイルの適用
503
+ if lora_file is not None and os.path.exists(lora_file):
504
+ try:
505
+ print(f"LoRAファイル {os.path.basename(lora_file)} をマージします...")
506
+ state_dict = transformer.state_dict()
507
+ state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
508
+
509
+ if fp8_optimization:
510
+ TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
511
+ EXCLUDE_KEYS = ["norm"] # Exclude norm layers from FP8
512
+
513
+ print("FP8最適化を適用します")
514
+ state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False)
515
+ apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
516
+ gc.collect()
517
+
518
+ info = transformer.load_state_dict(state_dict, strict=True, assign=True)
519
+ print(f"LoRAと/またはFP8最適化を適用しました: {info}")
520
+ except Exception as e:
521
+ print(f"LoRA適用中にエラーが発生しました: {e}")
522
+ # エラー発生時も処理を継続
523
 
524
  # Clean GPU
525
  if not high_vram:
 
804
 
805
  # 非GPU環境用の標準process関数を追加
806
  if not IN_HF_SPACE or 'spaces' not in globals():
807
+ def process_with_temp(image_mask_dict, lora_multiplier=1.0):
808
  """一時ファイルを使用する処理メインフロー(非GPU環境用)"""
809
  global stream
810
 
 
840
  gpu_memory_preservation = 6.0
841
  use_teacache = False
842
  mp4_crf = 0
843
+ lora_file = "./LoRA/mask_fadeout_V1.safetensors"
844
+ fp8_optimization = False
845
+
846
+ # LoRAファイルの存在確認
847
+ if not os.path.exists(lora_file):
848
+ print(f"警告: LoRAファイル {lora_file} が見つかりません。LoRAなしで処理を続行します。")
849
+ lora_file = None
850
 
851
  # 非同期ワーカー起動
852
  stream = AsyncStream()
 
863
  gpu_memory_preservation,
864
  use_teacache,
865
  mp4_crf,
866
+ lora_file,
867
+ lora_multiplier,
868
+ fp8_optimization,
869
  )
870
 
871
  temp_dir = None
 
959
  # GPU環境用process_with_temp関数の内容を完成
960
  if IN_HF_SPACE and 'spaces' in globals():
961
  @spaces.GPU(duration=180)
962
+ def process_with_temp(image_mask_dict, lora_multiplier=1.0):
963
  """一時ファイルを使用する処理メインフロー(GPU対応)"""
964
  global stream
965
 
 
995
  gpu_memory_preservation = 6.0
996
  use_teacache = False
997
  mp4_crf = 0
998
+ lora_file = "./LoRA/mask_fadeout_V1.safetensors"
999
+ fp8_optimization = False
1000
+
1001
+ # LoRAファイルの存在確認
1002
+ if not os.path.exists(lora_file):
1003
+ print(f"警告: LoRAファイル {lora_file} が見つかりません。LoRAなしで処理を続行します。")
1004
+ lora_file = None
1005
 
1006
  # 非同期ワーカー起動
1007
  stream = AsyncStream()
 
1018
  gpu_memory_preservation,
1019
  use_teacache,
1020
  mp4_crf,
1021
+ lora_file,
1022
+ lora_multiplier,
1023
+ fp8_optimization,
1024
  )
1025
 
1026
  temp_dir = None
 
1137
  css = make_progress_bar_css()
1138
  block = gr.Blocks(css=css).queue()
1139
  with block:
1140
+ gr.Markdown("# FramePackI2V_HY_mask_fadeout - 画像のマスクした部分を除去")
1141
  with gr.Row():
1142
  with gr.Column():
1143
  # 入力画像をImageMaskで設定
 
1158
  start_button = gr.Button(value="生成開始")
1159
  end_button = gr.Button(value="生成中止", interactive=False)
1160
 
1161
+ with gr.Group():
1162
+ lora_multiplier = gr.Slider(label="LoRA倍率", minimum=0.0, maximum=2.0, value=1.0, step=0.1)
1163
+
1164
  with gr.Column():
1165
  preview_image = gr.Image(label="生成プレビュー", visible=False)
1166
  result_frame = gr.Image(label="生成結果", visible=False, height="60vh", type="filepath")
 
1170
 
1171
  ips = [
1172
  image_mask,
1173
+ lora_multiplier,
1174
  ]
1175
  ops = [
1176
  preview_image,