openfree commited on
Commit
74a1bb2
ยท
verified ยท
1 Parent(s): 93718bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -26
app.py CHANGED
@@ -17,7 +17,7 @@ initialization_message = "๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘... ์ž ์‹œ๋งŒ ๊ธฐ๋‹ค๋ ค์ฃผ์„ธ์š”."
17
  # ๊ฐ„๋‹จํ•œ ์ธ์šฉ ์ •๋ณด ์ถ”๊ฐ€
18
  _CITE_ = """PuLID: Person-under-Language Image Diffusion Model"""
19
 
20
- # GPU ์‚ฌ์šฉ ๊ฐ€๋Šฅ ์—ฌ๋ถ€ ํ™•์ธ ๋ฐ ์žฅ์น˜ ์„ค์ •
21
  def get_device():
22
  if torch.cuda.is_available():
23
  return torch.device('cuda')
@@ -44,8 +44,8 @@ def get_models(name: str, device, offload: bool):
44
 
45
  class FluxGenerator:
46
  def __init__(self):
47
- # GPU ์‚ฌ์šฉ ๊ฐ€๋Šฅ ์—ฌ๋ถ€์— ๋”ฐ๋ผ ์žฅ์น˜ ์„ค์ •
48
- self.device = get_device()
49
  self.offload = False
50
  self.model_name = 'flux-dev'
51
  self.initialized = False
@@ -63,6 +63,9 @@ class FluxGenerator:
63
  from pulid.pipeline_flux import PuLIDPipeline
64
  from flux.sampling import prepare
65
 
 
 
 
66
  print("๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ์ž‘...")
67
  self.model, self.ae, self.t5, self.clip_model = get_models(
68
  self.model_name,
@@ -98,11 +101,12 @@ class FluxGenerator:
98
  initialization_message = f"๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {str(e)}"
99
 
100
 
101
- # ์ง€์—ฐ ๋กœ๋”ฉ์„ ์œ„ํ•œ ๋ฐฑ๊ทธ๋ผ์šด๋“œ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜
102
- def initialize_models_in_background():
 
103
  global flux_generator, model_initialized, initialization_message
104
 
105
- print("๋ฐฑ๊ทธ๋ผ์šด๋“œ์—์„œ ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ์ž‘...")
106
 
107
  try:
108
  # ์ง€์—ฐ ์ž„ํฌํŠธ
@@ -112,19 +116,18 @@ def initialize_models_in_background():
112
 
113
  # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
114
  flux_generator = FluxGenerator()
115
-
116
- # 30์ดˆ ํ›„์— ์ดˆ๊ธฐํ™” ์‹œ์ž‘ (UI๊ฐ€ ๋จผ์ € ๋กœ๋“œ๋˜๋„๋ก)
117
- time.sleep(30)
118
  flux_generator.initialize()
119
 
120
  model_initialized = flux_generator.initialized
121
 
122
  except Exception as e:
123
  import traceback
124
- error_msg = f"๋ฐฑ๊ทธ๋ผ์šด๋“œ ์ดˆ๊ธฐํ™” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
125
  print(error_msg)
126
  model_initialized = False
127
  initialization_message = f"๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์˜ค๋ฅ˜: {str(e)}"
 
 
128
 
129
 
130
  # ๋ชจ๋ธ ์ƒํƒœ ํ™•์ธ ํ•จ์ˆ˜
@@ -151,7 +154,7 @@ def generate_image(
151
 
152
  # ๋ชจ๋ธ์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•˜์œผ๋ฉด ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ๋ฐ˜ํ™˜
153
  if not model_initialized:
154
- return None, "๋ชจ๋ธ ์ดˆ๊ธฐํ™”๊ฐ€ ์™„๋ฃŒ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ์ž ์‹œ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."
155
 
156
  # ID ์ด๋ฏธ์ง€๊ฐ€ ์—†์œผ๋ฉด ์‹คํ–‰ ๋ถˆ๊ฐ€
157
  if id_image is None:
@@ -322,6 +325,11 @@ def create_demo():
322
 
323
  # ๋ชจ๋ธ ์ƒํƒœ ํ‘œ์‹œ
324
  status_box = gr.Textbox(label="๋ชจ๋ธ ์ƒํƒœ", value=initialization_message)
 
 
 
 
 
325
  refresh_btn = gr.Button("์ƒํƒœ ์ƒˆ๋กœ๊ณ ์นจ")
326
  refresh_btn.click(fn=check_model_status, inputs=[], outputs=[status_box])
327
 
@@ -369,9 +377,6 @@ def create_demo():
369
  id_weight, neg_prompt, true_cfg, gamma, eta]
370
  )
371
 
372
- # ์ฃผ๊ธฐ์  ์ƒํƒœ ์—…๋ฐ์ดํŠธ ์„ค์ •
373
- demo.load(fn=check_model_status, inputs=[], outputs=[status_box], every=5) # 5์ดˆ๋งˆ๋‹ค ์—…๋ฐ์ดํŠธ
374
-
375
  # Gradio ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
376
  generate_btn.click(
377
  fn=generate_image,
@@ -396,18 +401,9 @@ if __name__ == "__main__":
396
 
397
  print("Hugging Face Spaces ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰ ์ค‘์ž…๋‹ˆ๋‹ค. GPU ํ• ๋‹น์„ ์š”์ฒญํ•ฉ๋‹ˆ๋‹ค.")
398
 
399
- # UI๊ฐ€ ๋จผ์ € ๋กœ๋“œ๋˜๋„๋ก ๋ฐฑ๊ทธ๋ผ์šด๋“œ์—์„œ ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ์ž‘
400
- threading.Thread(target=initialize_models_in_background, daemon=True).start()
401
-
402
- # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ œํ•œ ์„ค์ • (ํ™˜๊ฒฝ์— ๋”ฐ๋ผ ์กฐ์ • ํ•„์š”)
403
- try:
404
- import torch.cuda
405
- if torch.cuda.is_available():
406
- # GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ œํ•œ (ํ•„์š”์‹œ ์กฐ์ •)
407
- torch.cuda.set_per_process_memory_fraction(0.8) # ์ตœ๋Œ€ 80% GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ
408
- except Exception as e:
409
- print(f"๋ฉ”๋ชจ๋ฆฌ ์ œํ•œ ์„ค์ • ์ค‘ ์˜ค๋ฅ˜: {e}")
410
 
411
  demo = create_demo()
412
  # ๋””๋ฒ„๊ทธ ๋ชจ๋“œ ํ™œ์„ฑํ™”
413
- demo.queue().launch(server_name="0.0.0.0", server_port=args.port, debug=True)
 
17
  # ๊ฐ„๋‹จํ•œ ์ธ์šฉ ์ •๋ณด ์ถ”๊ฐ€
18
  _CITE_ = """PuLID: Person-under-Language Image Diffusion Model"""
19
 
20
+ # GPU ์‚ฌ์šฉ ๊ฐ€๋Šฅ ์—ฌ๋ถ€ ํ™•์ธ ๋ฐ ์žฅ์น˜ ์„ค์ • - ๋ฉ”์ธ ํ”„๋กœ์„ธ์Šค์—์„œ๋Š” ํ˜ธ์ถœํ•˜์ง€ ์•Š์Œ
21
  def get_device():
22
  if torch.cuda.is_available():
23
  return torch.device('cuda')
 
44
 
45
  class FluxGenerator:
46
  def __init__(self):
47
+ # GPU ์ดˆ๊ธฐํ™”๋Š” Spaces GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ ์•ˆ์—์„œ๋งŒ ์ˆ˜ํ–‰
48
+ self.device = None # ์ดˆ๊ธฐํ™” ์‹œ์ ์—๋Š” device๋ฅผ ํ• ๋‹นํ•˜์ง€ ์•Š์Œ
49
  self.offload = False
50
  self.model_name = 'flux-dev'
51
  self.initialized = False
 
63
  from pulid.pipeline_flux import PuLIDPipeline
64
  from flux.sampling import prepare
65
 
66
+ # ์ด ์‹œ์ ์—์„œ ์žฅ์น˜ ์„ค์ • (GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ ๋‚ด์—์„œ๋งŒ ํ˜ธ์ถœ๋จ)
67
+ self.device = get_device()
68
+
69
  print("๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ์ž‘...")
70
  self.model, self.ae, self.t5, self.clip_model = get_models(
71
  self.model_name,
 
101
  initialization_message = f"๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {str(e)}"
102
 
103
 
104
+ # ์ง€์—ฐ ๋กœ๋”ฉ์„ ์œ„ํ•œ ๋ฐฑ๊ทธ๋ผ์šด๋“œ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜ - GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ๋กœ ๋ณ€๊ฒฝ
105
+ @spaces.GPU(duration=60)
106
+ def initialize_models():
107
  global flux_generator, model_initialized, initialization_message
108
 
109
+ print("GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ ๋‚ด์—์„œ ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ์ž‘...")
110
 
111
  try:
112
  # ์ง€์—ฐ ์ž„ํฌํŠธ
 
116
 
117
  # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
118
  flux_generator = FluxGenerator()
 
 
 
119
  flux_generator.initialize()
120
 
121
  model_initialized = flux_generator.initialized
122
 
123
  except Exception as e:
124
  import traceback
125
+ error_msg = f"์ดˆ๊ธฐํ™” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
126
  print(error_msg)
127
  model_initialized = False
128
  initialization_message = f"๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์˜ค๋ฅ˜: {str(e)}"
129
+
130
+ return initialization_message
131
 
132
 
133
  # ๋ชจ๋ธ ์ƒํƒœ ํ™•์ธ ํ•จ์ˆ˜
 
154
 
155
  # ๋ชจ๋ธ์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•˜์œผ๋ฉด ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ๋ฐ˜ํ™˜
156
  if not model_initialized:
157
+ return None, "๋ชจ๋ธ ์ดˆ๊ธฐํ™”๊ฐ€ ์™„๋ฃŒ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”."
158
 
159
  # ID ์ด๋ฏธ์ง€๊ฐ€ ์—†์œผ๋ฉด ์‹คํ–‰ ๋ถˆ๊ฐ€
160
  if id_image is None:
 
325
 
326
  # ๋ชจ๋ธ ์ƒํƒœ ํ‘œ์‹œ
327
  status_box = gr.Textbox(label="๋ชจ๋ธ ์ƒํƒœ", value=initialization_message)
328
+
329
+ # ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ ์ถ”๊ฐ€ (๋ฐฑ๊ทธ๋ผ์šด๋“œ ์ดˆ๊ธฐํ™” ๋Œ€์‹  ๋ช…์‹œ์  ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ ์‚ฌ์šฉ)
330
+ init_btn = gr.Button("๋ชจ๋ธ ์ดˆ๊ธฐํ™”")
331
+ init_btn.click(fn=initialize_models, inputs=[], outputs=[status_box])
332
+
333
  refresh_btn = gr.Button("์ƒํƒœ ์ƒˆ๋กœ๊ณ ์นจ")
334
  refresh_btn.click(fn=check_model_status, inputs=[], outputs=[status_box])
335
 
 
377
  id_weight, neg_prompt, true_cfg, gamma, eta]
378
  )
379
 
 
 
 
380
  # Gradio ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
381
  generate_btn.click(
382
  fn=generate_image,
 
401
 
402
  print("Hugging Face Spaces ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰ ์ค‘์ž…๋‹ˆ๋‹ค. GPU ํ• ๋‹น์„ ์š”์ฒญํ•ฉ๋‹ˆ๋‹ค.")
403
 
404
+ # ๋ฉ”์ธ ํ”„๋กœ์„ธ์Šค์—์„œ๋Š” CUDA ์ดˆ๊ธฐํ™”ํ•˜์ง€ ์•Š์Œ
405
+ # ๋ฐฑ๊ทธ๋ผ์šด๋“œ ์Šค๋ ˆ๋“œ ๋Œ€์‹  ๋ช…์‹œ์  ๋ฒ„ํŠผ์œผ๋กœ ์ดˆ๊ธฐํ™”
 
 
 
 
 
 
 
 
 
406
 
407
  demo = create_demo()
408
  # ๋””๋ฒ„๊ทธ ๋ชจ๋“œ ํ™œ์„ฑํ™”
409
+ demo.queue().launch(server_name="0.0.0.0", server_port=args.port)