fantos committed on
Commit
0d158d0
·
verified ·
1 Parent(s): 02fd843

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -24
app.py CHANGED
@@ -81,10 +81,14 @@ def initialize_model():
81
  lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
82
  pipe.load_lora_weights(lora_path)
83
  pipe.fuse_lora(lora_scale=0.125)
 
 
84
  pipe.to(device="cuda", dtype=torch.bfloat16)
85
 
86
- # μ•ˆμ „ 검사기 μΆ”κ°€
87
  pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
 
 
88
 
89
  print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ")
90
  return True
@@ -270,14 +274,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
270
  def process_image(height, width, steps, scales, prompt, seed):
271
  global pipe
272
 
273
- # λͺ¨λΈ μ΄ˆκΈ°ν™” μƒνƒœ 확인
274
- if pipe is None:
275
- return None, "λͺ¨λΈμ„ λ‘œλ”© μ€‘μž…λ‹ˆλ‹€... 처음 μ‹€ν–‰ μ‹œ μ‹œκ°„μ΄ μ†Œμš”λ  수 μžˆμŠ΅λ‹ˆλ‹€.", True, "", False
276
-
277
- model_loaded = initialize_model()
278
- if not model_loaded:
279
- return None, "", False, "λͺ¨λΈ λ‘œλ”© 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€. νŽ˜μ΄μ§€λ₯Ό μƒˆλ‘œκ³ μΉ¨ν•˜κ³  λ‹€μ‹œ μ‹œλ„ν•΄ μ£Όμ„Έμš”.", True
280
-
281
  # μž…λ ₯κ°’ 검증
282
  if not prompt or prompt.strip() == "":
283
  return None, "", False, "이미지 μ„€λͺ…을 μž…λ ₯ν•΄μ£Όμ„Έμš”.", True
@@ -298,21 +294,25 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
298
  else:
299
  seed = int(seed) # νƒ€μž… λ³€ν™˜ μ•ˆμ „ν•˜κ²Œ 처리
300
 
301
- # 이미지 생성 μƒνƒœ λ©”μ‹œμ§€
302
- loading_message = "이미지λ₯Ό 생성 μ€‘μž…λ‹ˆλ‹€..."
303
-
 
 
 
 
 
304
  # 이미지 생성
305
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
 
306
  generator = torch.Generator(device="cuda").manual_seed(seed)
307
 
308
- # 높이와 λ„ˆλΉ„λ₯Ό 64의 배수둜 μ‘°μ • (FLUX λͺ¨λΈ μš”κ΅¬μ‚¬ν•­)
309
- height = (int(height) // 64) * 64
310
- width = (int(width) // 64) * 64
311
-
312
- # μ•ˆμ „μž₯치 - μ΅œλŒ€κ°’ μ œν•œ
313
- steps = min(int(steps), 25)
314
- scales = max(min(float(scales), 5.0), 0.0)
315
-
316
  generated_image = pipe(
317
  prompt=[filtered_prompt],
318
  generator=generator,
@@ -320,10 +320,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
320
  guidance_scale=scales,
321
  height=height,
322
  width=width,
323
- max_sequence_length=256
 
324
  ).images[0]
325
 
326
- # 성곡 μ‹œ 이미지 λ°˜ν™˜, μƒνƒœ λ©”μ‹œμ§€ μˆ¨κΉ€
327
  return generated_image, "", False, "", False
328
 
329
  except Exception as e:
@@ -342,11 +342,19 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
342
 
343
  # 이미지 생성 μ€€λΉ„ ν•¨μˆ˜
344
  def prepare_generation(height, width, steps, scales, prompt, seed):
 
 
345
  # λͺ¨λΈμ΄ 아직 λ‘œλ“œλ˜μ§€ μ•Šμ•˜λ‹€λ©΄ λ‘œλ“œ
346
  if pipe is None:
 
 
 
347
  is_loaded = initialize_model()
348
  if not is_loaded:
349
- return None, "λͺ¨λΈ λ‘œλ”©μ— μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€. νŽ˜μ΄μ§€λ₯Ό μƒˆλ‘œκ³ μΉ¨ν•˜κ³  λ‹€μ‹œ μ‹œλ„ν•΄ μ£Όμ„Έμš”.", True, "", False
 
 
 
350
 
351
  # 생성 ν”„λ‘œμ„ΈμŠ€ μ‹œμž‘
352
  return process_image(height, width, steps, scales, prompt, seed)
@@ -364,5 +372,5 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
364
  )
365
 
366
  if __name__ == "__main__":
367
- # μ•± μ‹œμž‘ μ‹œ λͺ¨λΈ 미리 λ‘œλ“œν•˜μ§€ μ•ŠμŒ (첫 μš”μ²­ μ‹œ μ§€μ—° λ‘œλ”©)
368
  demo.queue(max_size=10).launch()
 
81
  lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
82
  pipe.load_lora_weights(lora_path)
83
  pipe.fuse_lora(lora_scale=0.125)
84
+
85
+ # 주의: μ—¬κΈ°μ„œ deviceλ₯Ό λͺ…μ‹œμ μœΌλ‘œ μ§€μ • (λͺ¨λ“  μ»΄ν¬λ„ŒνŠΈμ— 적용)
86
  pipe.to(device="cuda", dtype=torch.bfloat16)
87
 
88
+ # μ•ˆμ „ 검사기 μΆ”κ°€ 및 μ˜¬λ°”λ₯Έ μž₯치둜 이동
89
  pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
90
+ if hasattr(pipe, 'safety_checker') and pipe.safety_checker is not None:
91
+ pipe.safety_checker.to("cuda")
92
 
93
  print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ")
94
  return True
 
274
  def process_image(height, width, steps, scales, prompt, seed):
275
  global pipe
276
 
 
 
 
 
 
 
 
 
277
  # μž…λ ₯κ°’ 검증
278
  if not prompt or prompt.strip() == "":
279
  return None, "", False, "이미지 μ„€λͺ…을 μž…λ ₯ν•΄μ£Όμ„Έμš”.", True
 
294
  else:
295
  seed = int(seed) # νƒ€μž… λ³€ν™˜ μ•ˆμ „ν•˜κ²Œ 처리
296
 
297
+ # 높이와 λ„ˆλΉ„λ₯Ό 64의 배수둜 μ‘°μ • (FLUX λͺ¨λΈ μš”κ΅¬μ‚¬ν•­)
298
+ height = (int(height) // 64) * 64
299
+ width = (int(width) // 64) * 64
300
+
301
+ # μ•ˆμ „μž₯치 - μ΅œλŒ€κ°’ μ œν•œ
302
+ steps = min(int(steps), 25)
303
+ scales = max(min(float(scales), 5.0), 0.0)
304
+
305
  # 이미지 생성
306
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
307
+ # μ€‘μš”: generator μ„€μ • μ‹œ deviceλ₯Ό λͺ…μ‹œμ μœΌλ‘œ μ§€μ •
308
  generator = torch.Generator(device="cuda").manual_seed(seed)
309
 
310
+ # λͺ¨λ“  ν…μ„œκ°€ 같은 λ””λ°”μ΄μŠ€μ— μžˆλŠ”μ§€ 확인
311
+ for name, module in pipe.components.items():
312
+ if hasattr(module, 'device') and module.device.type != "cuda":
313
+ module.to("cuda")
314
+
315
+ # 이미지 생성 - λͺ¨λ“  λ§€κ°œλ³€μˆ˜μ— deviceλ₯Ό λͺ…μ‹œμ  μ§€μ •
 
 
316
  generated_image = pipe(
317
  prompt=[filtered_prompt],
318
  generator=generator,
 
320
  guidance_scale=scales,
321
  height=height,
322
  width=width,
323
+ max_sequence_length=256,
324
+ device="cuda" # λͺ…μ‹œμ  device μ§€μ •
325
  ).images[0]
326
 
 
327
  return generated_image, "", False, "", False
328
 
329
  except Exception as e:
 
342
 
343
  # 이미지 생성 μ€€λΉ„ ν•¨μˆ˜
344
  def prepare_generation(height, width, steps, scales, prompt, seed):
345
+ global pipe
346
+
347
  # λͺ¨λΈμ΄ 아직 λ‘œλ“œλ˜μ§€ μ•Šμ•˜λ‹€λ©΄ λ‘œλ“œ
348
  if pipe is None:
349
+ # λ‘œλ”© μƒνƒœ ν‘œμ‹œ
350
+ loading_message = "λͺ¨λΈμ„ λ‘œλ”© μ€‘μž…λ‹ˆλ‹€... 처음 μ‹€ν–‰ μ‹œ μ‹œκ°„μ΄ μ†Œμš”λ  수 μžˆμŠ΅λ‹ˆλ‹€."
351
+
352
  is_loaded = initialize_model()
353
  if not is_loaded:
354
+ return None, "", False, "λͺ¨λΈ λ‘œλ”©μ— μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€. νŽ˜μ΄μ§€λ₯Ό μƒˆλ‘œκ³ μΉ¨ν•˜κ³  λ‹€μ‹œ μ‹œλ„ν•΄ μ£Όμ„Έμš”.", True
355
+
356
+ # 생성 μƒνƒœ ν‘œμ‹œ
357
+ loading_message = "이미지λ₯Ό 생성 μ€‘μž…λ‹ˆλ‹€..."
358
 
359
  # 생성 ν”„λ‘œμ„ΈμŠ€ μ‹œμž‘
360
  return process_image(height, width, steps, scales, prompt, seed)
 
372
  )
373
 
374
  if __name__ == "__main__":
375
+
376
  demo.queue(max_size=10).launch()