yongyeol commited on
Commit
35cddec
·
verified ·
1 Parent(s): a6a8969

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -1,7 +1,7 @@
1
  # ────────────────────────────────────────────────────────────────────────────
2
  # app.py – Text ➜ 2D (FLUX-mini Kontext) ➜ 3D (Hunyuan3D-2)
3
  # • Fits into 16 GB system RAM: 경량 모델 + lazy loading + offload
4
- # • Requires: GPU (A10G 24 GB ideal, T4 16 GB OK with fp16)
5
  # ────────────────────────────────────────────────────────────────────────────
6
  import os
7
  import tempfile
@@ -27,7 +27,6 @@ DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
27
 
28
  # ─────────────────────── Lazy loaders ───────────────────────
29
  from diffusers import FluxKontextPipeline, FluxPipeline
30
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
31
 
32
  # Global caches
33
  kontext_pipe = None # type: FluxKontextPipeline | None
@@ -38,16 +37,17 @@ paint_pipe = None
38
  MINI_KONTEXT_REPO = "black-forest-labs/FLUX.1-Kontext-mini"
39
  MINI_T2I_REPO = "black-forest-labs/FLUX.1-mini"
40
  HUNYUAN_REPO = "tencent/Hunyuan3D-2"
 
41
 
42
 
43
  def load_kontext() -> FluxKontextPipeline:
44
  global kontext_pipe
45
  if kontext_pipe is None:
46
- print("[+] Loading FLUX.1-Kontext-mini … (low_cpu_mem_usage)")
47
  kontext_pipe = FluxKontextPipeline.from_pretrained(
48
  MINI_KONTEXT_REPO,
49
  torch_dtype=DTYPE,
50
- device_map="auto",
51
  low_cpu_mem_usage=True,
52
  )
53
  kontext_pipe.set_progress_bar_config(disable=True)
@@ -55,14 +55,13 @@ def load_kontext() -> FluxKontextPipeline:
55
 
56
 
57
  def load_text2img() -> FluxPipeline:
58
- """Lazy-load light text→image model only when 필요."""
59
  global _text2img_pipe
60
  if _text2img_pipe is None:
61
- print("[+] Loading FLUX.1-mini (textimage)…")
62
  _text2img_pipe = FluxPipeline.from_pretrained(
63
  MINI_T2I_REPO,
64
  torch_dtype=DTYPE,
65
- device_map="auto",
66
  low_cpu_mem_usage=True,
67
  )
68
  _text2img_pipe.set_progress_bar_config(disable=True)
@@ -79,7 +78,7 @@ def load_hunyuan() -> tuple:
79
  shape_pipe = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
80
  HUNYUAN_REPO,
81
  torch_dtype=DTYPE,
82
- device_map="auto",
83
  low_cpu_mem_usage=True,
84
  )
85
  shape_pipe.set_progress_bar_config(disable=True)
@@ -87,7 +86,7 @@ def load_hunyuan() -> tuple:
87
  paint_pipe = Hunyuan3DPaintPipeline.from_pretrained(
88
  HUNYUAN_REPO,
89
  torch_dtype=DTYPE,
90
- device_map="auto",
91
  low_cpu_mem_usage=True,
92
  )
93
  paint_pipe.set_progress_bar_config(disable=True)
@@ -100,7 +99,6 @@ def load_hunyuan() -> tuple:
100
  def generate_single_2d(prompt: str, image: Image.Image | None, guidance_scale: float) -> Image.Image:
101
  kontext = load_kontext()
102
  if image is None:
103
- # 텍스트→이미지 : 경량 text2img 파이프라인 사용
104
  t2i = load_text2img()
105
  result = t2i(prompt=prompt, guidance_scale=guidance_scale).images[0]
106
  else:
@@ -116,7 +114,7 @@ def generate_multiview(prompt: str, base_image: Image.Image, guidance_scale: flo
116
  kontext(image=base_image, prompt=f"{prompt}, right side view", guidance_scale=guidance_scale).images[0],
117
  kontext(image=base_image, prompt=f"{prompt}, back view", guidance_scale=guidance_scale).images[0],
118
  ]
119
- return views # [front, left, right, back]
120
 
121
 
122
  def build_3d_mesh(prompt: str, images: List[Image.Image]) -> str:
@@ -175,4 +173,3 @@ def build_ui():
175
 
176
  if __name__ == "__main__":
177
  build_ui().queue(max_size=3).launch()
178
-
 
1
  # ────────────────────────────────────────────────────────────────────────────
2
  # app.py – Text ➜ 2D (FLUX-mini Kontext) ➜ 3D (Hunyuan3D-2)
3
  # • Fits into 16 GB system RAM: 경량 모델 + lazy loading + offload
4
+ # • Updated: use device_map="balanced" ("auto" not supported by Flux pipelines)
5
  # ────────────────────────────────────────────────────────────────────────────
6
  import os
7
  import tempfile
 
27
 
28
  # ─────────────────────── Lazy loaders ───────────────────────
29
  from diffusers import FluxKontextPipeline, FluxPipeline
 
30
 
31
  # Global caches
32
  kontext_pipe = None # type: FluxKontextPipeline | None
 
37
  MINI_KONTEXT_REPO = "black-forest-labs/FLUX.1-Kontext-mini"
38
  MINI_T2I_REPO = "black-forest-labs/FLUX.1-mini"
39
  HUNYUAN_REPO = "tencent/Hunyuan3D-2"
40
+ DEVICE_MAP_STRATEGY = "balanced" # "auto" unsupported for Flux pipelines
41
 
42
 
43
  def load_kontext() -> FluxKontextPipeline:
44
  global kontext_pipe
45
  if kontext_pipe is None:
46
+ print("[+] Loading FLUX.1-Kontext-mini … (balanced offload)")
47
  kontext_pipe = FluxKontextPipeline.from_pretrained(
48
  MINI_KONTEXT_REPO,
49
  torch_dtype=DTYPE,
50
+ device_map=DEVICE_MAP_STRATEGY,
51
  low_cpu_mem_usage=True,
52
  )
53
  kontext_pipe.set_progress_bar_config(disable=True)
 
55
 
56
 
57
  def load_text2img() -> FluxPipeline:
 
58
  global _text2img_pipe
59
  if _text2img_pipe is None:
60
+ print("[+] Loading FLUX.1-mini (text→image)…")
61
  _text2img_pipe = FluxPipeline.from_pretrained(
62
  MINI_T2I_REPO,
63
  torch_dtype=DTYPE,
64
+ device_map=DEVICE_MAP_STRATEGY,
65
  low_cpu_mem_usage=True,
66
  )
67
  _text2img_pipe.set_progress_bar_config(disable=True)
 
78
  shape_pipe = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
79
  HUNYUAN_REPO,
80
  torch_dtype=DTYPE,
81
+ device_map=DEVICE_MAP_STRATEGY,
82
  low_cpu_mem_usage=True,
83
  )
84
  shape_pipe.set_progress_bar_config(disable=True)
 
86
  paint_pipe = Hunyuan3DPaintPipeline.from_pretrained(
87
  HUNYUAN_REPO,
88
  torch_dtype=DTYPE,
89
+ device_map=DEVICE_MAP_STRATEGY,
90
  low_cpu_mem_usage=True,
91
  )
92
  paint_pipe.set_progress_bar_config(disable=True)
 
99
  def generate_single_2d(prompt: str, image: Image.Image | None, guidance_scale: float) -> Image.Image:
100
  kontext = load_kontext()
101
  if image is None:
 
102
  t2i = load_text2img()
103
  result = t2i(prompt=prompt, guidance_scale=guidance_scale).images[0]
104
  else:
 
114
  kontext(image=base_image, prompt=f"{prompt}, right side view", guidance_scale=guidance_scale).images[0],
115
  kontext(image=base_image, prompt=f"{prompt}, back view", guidance_scale=guidance_scale).images[0],
116
  ]
117
+ return views
118
 
119
 
120
  def build_3d_mesh(prompt: str, images: List[Image.Image]) -> str:
 
173
 
174
  if __name__ == "__main__":
175
  build_ui().queue(max_size=3).launch()