Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,49 +1,35 @@
|
|
1 |
"""
|
2 |
-
FLUX.1
|
3 |
-
|
4 |
-
Updated:
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
--------------
|
12 |
-
1. **HF_TOKEN 환경변수 지원** – `LocalTokenNotFoundError`를
|
13 |
-
방지하기 위해 `HF_TOKEN`(또는 런타임 로그인) 값을 자동으로
|
14 |
-
감지해 `snapshot_download()`에 전달합니다.
|
15 |
-
2. **모델 캐싱** – `snapshot_download()`로 실행 시작 시 한 번만
|
16 |
-
모델과 LoRA를 캐싱.
|
17 |
-
3. **GPU VRAM 자동 판별** – 24 GB 미만이면 FP16 / CPU offload.
|
18 |
-
4. **단일 로딩 메시지** – Gradio `gr.Info()` 메시지가 최초 1회만
|
19 |
-
표시되도록 유지.
|
20 |
-
5. **버그 픽스** – seed 처리, LoRA 언로드, 이미지 리사이즈.
|
21 |
-
|
22 |
-
------------------------------------------------------------
|
23 |
"""
|
|
|
24 |
import os
|
25 |
import gradio as gr
|
26 |
import spaces
|
27 |
import torch
|
28 |
-
from huggingface_hub import snapshot_download
|
29 |
from huggingface_hub.errors import LocalTokenNotFoundError
|
30 |
from diffusers import FluxKontextPipeline
|
31 |
from diffusers.utils import load_image
|
32 |
from PIL import Image
|
33 |
|
34 |
# ------------------------------------------------------------------
|
35 |
-
# 환경 설정 & 모델
|
36 |
# ------------------------------------------------------------------
|
37 |
-
#
|
38 |
-
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
39 |
|
40 |
MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
|
41 |
LORA_REPO = "Owen777/Kontext-Style-Loras"
|
42 |
CACHE_DIR = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
|
|
|
43 |
|
44 |
-
# --- HF 토큰 처리 --------------------------------------------------
|
45 |
-
HF_TOKEN = os.getenv("HF_TOKEN") # 런타임에 환경변수로 주입하거나
|
46 |
-
# docker / Space 설정의 Secrets → HF_TOKEN 로 등록
|
47 |
|
48 |
def _download_with_token(repo_id: str) -> str:
|
49 |
"""Download repo snapshot with optional token handling."""
|
@@ -52,16 +38,16 @@ def _download_with_token(repo_id: str) -> str:
|
|
52 |
repo_id=repo_id,
|
53 |
cache_dir=CACHE_DIR,
|
54 |
resume_download=True,
|
55 |
-
token=HF_TOKEN if HF_TOKEN else True, # True →
|
56 |
)
|
57 |
except LocalTokenNotFoundError:
|
58 |
-
# 미로그인 + 필수 동의 모델이면 에러 메시지 출력 후 종료
|
59 |
raise RuntimeError(
|
60 |
"Hugging Face 토큰이 필요합니다. 환경변수 HF_TOKEN을 설정하거나\n"
|
61 |
"`huggingface-cli login`으로 로그인해 주세요."
|
62 |
)
|
63 |
|
64 |
-
|
|
|
65 |
MODEL_DIR = _download_with_token(MODEL_ID)
|
66 |
LORA_DIR = _download_with_token(LORA_REPO)
|
67 |
|
@@ -119,21 +105,20 @@ STYLE_DESCRIPTIONS = {
|
|
119 |
}
|
120 |
|
121 |
# ------------------------------------------------------------------
|
122 |
-
# 파이프라인 로더 (
|
123 |
# ------------------------------------------------------------------
|
124 |
-
_pipeline = None
|
|
|
125 |
|
126 |
def load_pipeline():
|
127 |
-
"""Load
|
128 |
global _pipeline
|
129 |
if _pipeline is not None:
|
130 |
return _pipeline
|
131 |
|
132 |
-
# VRAM
|
133 |
-
dtype = torch.bfloat16
|
134 |
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
135 |
-
if vram_gb
|
136 |
-
dtype = torch.float16
|
137 |
|
138 |
gr.Info("FLUX.1‑Kontext 파이프라인 로딩 중… (최초 1회)")
|
139 |
|
@@ -141,9 +126,7 @@ def load_pipeline():
|
|
141 |
MODEL_DIR,
|
142 |
torch_dtype=dtype,
|
143 |
local_files_only=True,
|
144 |
-
)
|
145 |
-
|
146 |
-
pipe.to("cuda")
|
147 |
|
148 |
if vram_gb < 24:
|
149 |
pipe.enable_sequential_cpu_offload()
|
@@ -154,9 +137,10 @@ def load_pipeline():
|
|
154 |
return _pipeline
|
155 |
|
156 |
# ------------------------------------------------------------------
|
157 |
-
# 스타일 변환 함수
|
158 |
# ------------------------------------------------------------------
|
159 |
@spaces.GPU(duration=600)
|
|
|
160 |
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
|
161 |
"""Apply selected style to the uploaded image."""
|
162 |
if input_image is None:
|
@@ -166,28 +150,28 @@ def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps,
|
|
166 |
try:
|
167 |
pipe = load_pipeline()
|
168 |
|
169 |
-
#
|
170 |
generator = None
|
171 |
if seed and int(seed) > 0:
|
172 |
generator = torch.Generator(device="cuda").manual_seed(int(seed))
|
173 |
|
174 |
-
#
|
175 |
img = input_image if isinstance(input_image, Image.Image) else load_image(input_image)
|
176 |
img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)
|
177 |
|
178 |
-
#
|
179 |
lora_file = STYLE_LORA_MAP[style_name]
|
180 |
adapter_name = "style"
|
181 |
pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name)
|
182 |
pipe.set_adapters([adapter_name], [1.0])
|
183 |
|
184 |
-
#
|
185 |
-
|
186 |
-
prompt = f"Turn this image into the {
|
187 |
if prompt_suffix and prompt_suffix.strip():
|
188 |
prompt += f" {prompt_suffix.strip()}"
|
189 |
|
190 |
-
gr.Info("Generating styled image…
|
191 |
|
192 |
result = pipe(
|
193 |
image=img,
|
@@ -199,7 +183,7 @@ def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps,
|
|
199 |
width=1024,
|
200 |
)
|
201 |
|
202 |
-
#
|
203 |
pipe.unload_lora_weights(adapter_name=adapter_name)
|
204 |
torch.cuda.empty_cache()
|
205 |
|
@@ -211,259 +195,23 @@ def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps,
|
|
211 |
return None
|
212 |
|
213 |
# ------------------------------------------------------------------
|
214 |
-
# Gradio UI
|
215 |
# ------------------------------------------------------------------
|
216 |
-
with gr.Blocks(title="FLUX.1 Context Style Transfer", theme=gr.themes.Soft()) as demo:
|
217 |
-
gr.Markdown("""
|
218 |
-
# 🎨 FLUX.1 Kontext Style Transfer
|
219 |
|
220 |
-
|
221 |
-
(
|
222 |
-
""")
|
223 |
-
|
224 |
-
with gr.Row():
|
225 |
-
with gr.Column(scale=1):
|
226 |
-
input_image = gr.Image(label="Upload Image", type="pil", height=400)
|
227 |
-
style_dropdown = gr.Dropdown(
|
228 |
-
choices=list(STYLE_LORA_MAP.keys()),
|
229 |
-
value="Ghibli",
|
230 |
-
label="Select Style",
|
231 |
-
)
|
232 |
-
style_info = gr.Textbox(label="Style Description", value=STYLE_DESCRIPTIONS["Ghibli"], interactive=False, lines=2)
|
233 |
-
prompt_suffix = gr.Textbox(label="Additional Instructions (Optional)", placeholder="e.g. add dramatic lighting", lines=2)
|
234 |
-
|
235 |
-
with gr.Accordion("Advanced Settings", open=False):
|
236 |
-
num_steps = gr.Slider(minimum=10, maximum=50, value=24, step=1, label="Inference Steps")
|
237 |
-
guidance = gr.Slider(minimum=1.0, maximum=7.5, value=2.5, step=0.1, label="Guidance Scale")
|
238 |
-
seed = gr.Number(label="Seed (0 = random)", value=42)
|
239 |
-
|
240 |
-
generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")
|
241 |
-
|
242 |
-
with gr.Column(scale=1):
|
243 |
-
output_image = gr.Image(label="Styled Result", type="pil", height=
|
244 |
-
"""
|
245 |
-
FLUX.1 Kontext Style Transfer
|
246 |
-
==============================
|
247 |
-
Updated: 2025‑07‑12 (HF_TOKEN 지원)
|
248 |
-
---------------------------------
|
249 |
-
이 스크립트는 Hugging Face **FLUX.1‑Kontext‑dev** 모델과
|
250 |
-
22 종의 스타일 LoRA 가중치를 이용해 이미지를 다양한 예술
|
251 |
-
스타일로 변환하는 Gradio 데모입니다.
|
252 |
-
|
253 |
-
주요 개선 사항
|
254 |
-
--------------
|
255 |
-
1. **HF_TOKEN 환경변수 지원** – `LocalTokenNotFoundError`를
|
256 |
-
방지하기 위해 `HF_TOKEN`(또는 런타임 로그인) 값을 자동으로
|
257 |
-
감지해 `snapshot_download()`에 전달합니다.
|
258 |
-
2. **모델 캐싱** – `snapshot_download()`로 실행 시작 시 한 번만
|
259 |
-
모델과 LoRA를 캐싱.
|
260 |
-
3. **GPU VRAM 자동 판별** – 24 GB 미만이면 FP16 / CPU offload.
|
261 |
-
4. **단일 로딩 메시지** – Gradio `gr.Info()` 메시지가 최초 1회만
|
262 |
-
표시되도록 유지.
|
263 |
-
5. **버그 픽스** – seed 처리, LoRA 언로드, 이미지 리사이즈.
|
264 |
-
|
265 |
-
------------------------------------------------------------
|
266 |
-
"""
|
267 |
-
import os
|
268 |
-
import gradio as gr
|
269 |
-
import spaces
|
270 |
-
import torch
|
271 |
-
from huggingface_hub import snapshot_download, login as hf_login
|
272 |
-
from huggingface_hub.errors import LocalTokenNotFoundError
|
273 |
-
from diffusers import FluxKontextPipeline
|
274 |
-
from diffusers.utils import load_image
|
275 |
-
from PIL import Image
|
276 |
-
|
277 |
-
# ------------------------------------------------------------------
|
278 |
-
# 환경 설정 & 모델 / LoRA 사전 다운로드
|
279 |
-
# ------------------------------------------------------------------
|
280 |
-
# 큰 파일을 빠르게 받도록 가속 플래그 활성화
|
281 |
-
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
282 |
-
|
283 |
-
MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
|
284 |
-
LORA_REPO = "Owen777/Kontext-Style-Loras"
|
285 |
-
CACHE_DIR = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
|
286 |
-
|
287 |
-
# --- HF 토큰 처리 --------------------------------------------------
|
288 |
-
HF_TOKEN = os.getenv("HF_TOKEN") # 런타임에 환경변수로 주입하거나
|
289 |
-
# docker / Space 설정의 Secrets → HF_TOKEN 로 등록
|
290 |
-
|
291 |
-
def _download_with_token(repo_id: str) -> str:
|
292 |
-
"""Download repo snapshot with optional token handling."""
|
293 |
-
try:
|
294 |
-
return snapshot_download(
|
295 |
-
repo_id=repo_id,
|
296 |
-
cache_dir=CACHE_DIR,
|
297 |
-
resume_download=True,
|
298 |
-
token=HF_TOKEN if HF_TOKEN else True, # True → HF_CACHED_TOKEN(or login)
|
299 |
-
)
|
300 |
-
except LocalTokenNotFoundError:
|
301 |
-
# 미로그인 + 필수 동의 모델이면 에러 메시지 출력 후 종료
|
302 |
-
raise RuntimeError(
|
303 |
-
"Hugging Face 토큰이 필요합니다. 환경변수 HF_TOKEN을 설정하거나\n"
|
304 |
-
"`huggingface-cli login`으로 로그인해 주세요."
|
305 |
-
)
|
306 |
|
307 |
-
# --- 최초 실행 시에만 다운로드(이미 캐시에 있으면 건너뜀) ---
|
308 |
-
MODEL_DIR = _download_with_token(MODEL_ID)
|
309 |
-
LORA_DIR = _download_with_token(LORA_REPO)
|
310 |
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
"3D_Chibi": "3D_Chibi_lora_weights.safetensors",
|
316 |
-
"American_Cartoon": "American_Cartoon_lora_weights.safetensors",
|
317 |
-
"Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
|
318 |
-
"Clay_Toy": "Clay_Toy_lora_weights.safetensors",
|
319 |
-
"Fabric": "Fabric_lora_weights.safetensors",
|
320 |
-
"Ghibli": "Ghibli_lora_weights.safetensors",
|
321 |
-
"Irasutoya": "Irasutoya_lora_weights.safetensors",
|
322 |
-
"Jojo": "Jojo_lora_weights.safetensors",
|
323 |
-
"Oil_Painting": "Oil_Painting_lora_weights.safetensors",
|
324 |
-
"Pixel": "Pixel_lora_weights.safetensors",
|
325 |
-
"Snoopy": "Snoopy_lora_weights.safetensors",
|
326 |
-
"Poly": "Poly_lora_weights.safetensors",
|
327 |
-
"LEGO": "LEGO_lora_weights.safetensors",
|
328 |
-
"Origami": "Origami_lora_weights.safetensors",
|
329 |
-
"Pop_Art": "Pop_Art_lora_weights.safetensors",
|
330 |
-
"Van_Gogh": "Van_Gogh_lora_weights.safetensors",
|
331 |
-
"Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
|
332 |
-
"Line": "Line_lora_weights.safetensors",
|
333 |
-
"Vector": "Vector_lora_weights.safetensors",
|
334 |
-
"Picasso": "Picasso_lora_weights.safetensors",
|
335 |
-
"Macaron": "Macaron_lora_weights.safetensors",
|
336 |
-
"Rick_Morty": "Rick_Morty_lora_weights.safetensors",
|
337 |
-
}
|
338 |
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
"Chinese_Ink": "Traditional Chinese ink painting aesthetic",
|
343 |
-
"Clay_Toy": "Playful clay/plasticine toy appearance",
|
344 |
-
"Fabric": "Soft, textile-like rendering",
|
345 |
-
"Ghibli": "Studio Ghibli's distinctive anime style",
|
346 |
-
"Irasutoya": "Simple, flat Japanese illustration style",
|
347 |
-
"Jojo": "JoJo's Bizarre Adventure manga style",
|
348 |
-
"Oil_Painting": "Classic oil painting texture and strokes",
|
349 |
-
"Pixel": "Retro pixel art style",
|
350 |
-
"Snoopy": "Peanuts comic strip style",
|
351 |
-
"Poly": "Low-poly 3D geometric style",
|
352 |
-
"LEGO": "LEGO brick construction style",
|
353 |
-
"Origami": "Paper folding art style",
|
354 |
-
"Pop_Art": "Bold, colorful pop art style",
|
355 |
-
"Van_Gogh": "Van Gogh's expressive brushstroke style",
|
356 |
-
"Paper_Cutting": "Paper cut-out art style",
|
357 |
-
"Line": "Clean line art/sketch style",
|
358 |
-
"Vector": "Clean vector graphics style",
|
359 |
-
"Picasso": "Cubist art style inspired by Picasso",
|
360 |
-
"Macaron": "Soft, pastel macaron-like style",
|
361 |
-
"Rick_Morty": "Rick and Morty cartoon style",
|
362 |
-
}
|
363 |
-
|
364 |
-
# ------------------------------------------------------------------
|
365 |
-
# 파이프라인 로더 (단일 인스턴스)
|
366 |
-
# ------------------------------------------------------------------
|
367 |
-
_pipeline = None # 내부 글로벌 캐시
|
368 |
-
|
369 |
-
def load_pipeline():
|
370 |
-
"""Load (or return cached) FluxKontextPipeline."""
|
371 |
-
global _pipeline
|
372 |
-
if _pipeline is not None:
|
373 |
-
return _pipeline
|
374 |
-
|
375 |
-
# VRAM이 24 GB 미만이면 FP16 사용 + CPU 오프로딩
|
376 |
-
dtype = torch.bfloat16
|
377 |
-
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
378 |
-
if vram_gb < 24:
|
379 |
-
dtype = torch.float16
|
380 |
-
|
381 |
-
gr.Info("FLUX.1‑Kontext 파이프라인 로딩 중… (최초 1회)")
|
382 |
-
|
383 |
-
pipe = FluxKontextPipeline.from_pretrained(
|
384 |
-
MODEL_DIR,
|
385 |
-
torch_dtype=dtype,
|
386 |
-
local_files_only=True,
|
387 |
)
|
388 |
|
389 |
-
pipe.to("cuda")
|
390 |
-
|
391 |
-
if vram_gb < 24:
|
392 |
-
pipe.enable_sequential_cpu_offload()
|
393 |
-
else:
|
394 |
-
pipe.enable_model_cpu_offload()
|
395 |
-
|
396 |
-
_pipeline = pipe
|
397 |
-
return _pipeline
|
398 |
-
|
399 |
-
# ------------------------------------------------------------------
|
400 |
-
# 스타일 변환 함수 (Spaces GPU 잡)
|
401 |
-
# ------------------------------------------------------------------
|
402 |
-
@spaces.GPU(duration=600)
|
403 |
-
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
|
404 |
-
"""Apply selected style to the uploaded image."""
|
405 |
-
if input_image is None:
|
406 |
-
gr.Warning("Please upload an image first!")
|
407 |
-
return None
|
408 |
-
|
409 |
-
try:
|
410 |
-
pipe = load_pipeline()
|
411 |
-
|
412 |
-
# --- Torch Generator 설정 ---
|
413 |
-
generator = None
|
414 |
-
if seed and int(seed) > 0:
|
415 |
-
generator = torch.Generator(device="cuda").manual_seed(int(seed))
|
416 |
-
|
417 |
-
# --- 입력 이미지 전처리 ---
|
418 |
-
img = input_image if isinstance(input_image, Image.Image) else load_image(input_image)
|
419 |
-
img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)
|
420 |
-
|
421 |
-
# --- LoRA 로드 ---
|
422 |
-
lora_file = STYLE_LORA_MAP[style_name]
|
423 |
-
adapter_name = "style"
|
424 |
-
pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name)
|
425 |
-
pipe.set_adapters([adapter_name], [1.0])
|
426 |
-
|
427 |
-
# --- 프롬프트 빌드 ---
|
428 |
-
human_readable_style = style_name.replace("_", " ")
|
429 |
-
prompt = f"Turn this image into the {human_readable_style} style."
|
430 |
-
if prompt_suffix and prompt_suffix.strip():
|
431 |
-
prompt += f" {prompt_suffix.strip()}"
|
432 |
-
|
433 |
-
gr.Info("Generating styled image… (20‑60 s)")
|
434 |
-
|
435 |
-
result = pipe(
|
436 |
-
image=img,
|
437 |
-
prompt=prompt,
|
438 |
-
guidance_scale=float(guidance_scale),
|
439 |
-
num_inference_steps=int(num_inference_steps),
|
440 |
-
generator=generator,
|
441 |
-
height=1024,
|
442 |
-
width=1024,
|
443 |
-
)
|
444 |
-
|
445 |
-
# --- LoRA 언로드 & GPU 메모리 해제 ---
|
446 |
-
pipe.unload_lora_weights(adapter_name=adapter_name)
|
447 |
-
torch.cuda.empty_cache()
|
448 |
-
|
449 |
-
return result.images[0]
|
450 |
-
|
451 |
-
except Exception as e:
|
452 |
-
torch.cuda.empty_cache()
|
453 |
-
gr.Error(f"Error during style transfer: {e}")
|
454 |
-
return None
|
455 |
-
|
456 |
-
# ------------------------------------------------------------------
|
457 |
-
# Gradio UI 정의
|
458 |
-
# ------------------------------------------------------------------
|
459 |
-
with gr.Blocks(title="FLUX.1 Context Style Transfer", theme=gr.themes.Soft()) as demo:
|
460 |
-
gr.Markdown("""
|
461 |
-
# 🎨 FLUX.1 Kontext Style Transfer
|
462 |
-
|
463 |
-
업로드한 이미지를 22 종의 예술 스타일로 변환하세요!
|
464 |
-
(모델 / LoRA는 최초 실행 시에만 다운로드되며, 이후 실행은 빠릅니다.)
|
465 |
-
""")
|
466 |
-
|
467 |
with gr.Row():
|
468 |
with gr.Column(scale=1):
|
469 |
input_image = gr.Image(label="Upload Image", type="pil", height=400)
|
@@ -472,15 +220,62 @@ with gr.Blocks(title="FLUX.1 Context Style Transfer", theme=gr.themes.Soft()) as
|
|
472 |
value="Ghibli",
|
473 |
label="Select Style",
|
474 |
)
|
475 |
-
style_info = gr.Textbox(
|
476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
|
478 |
with gr.Accordion("Advanced Settings", open=False):
|
479 |
-
num_steps = gr.Slider(
|
480 |
-
guidance = gr.Slider(
|
481 |
seed = gr.Number(label="Seed (0 = random)", value=42)
|
482 |
|
483 |
generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")
|
484 |
|
485 |
with gr.Column(scale=1):
|
486 |
-
output_image = gr.Image(label="Styled Result", type="pil", height=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
+
FLUX.1 Kontext Style Transfer
|
3 |
+
=============================
|
4 |
+
Updated : 2025‑07‑12 (HF_TOKEN 지원 + 전체 코드 완성)
|
5 |
+
---------------------------------------------------
|
6 |
+
Gradio 데모로 이미지를 22 종 예술 스타일로 변환합니다.
|
7 |
+
- **HF_TOKEN** 환경변수를 자동 인식해 라이선스 모델 다운로드 오류를 방지합니다.
|
8 |
+
- 최초 실행 시 모델·LoRA를 캐시에 받아 두고, 이후에는 재다운로드가 없습니다.
|
9 |
+
- GPU VRAM을 감지하여 24 GB 미만에서는 FP16 + CPU offload를 사용합니다.
|
10 |
+
- 파이프라인·LoRA 로딩 메시지는 최초 1회만 표시됩니다.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
"""
|
12 |
+
|
13 |
import os
|
14 |
import gradio as gr
|
15 |
import spaces
|
16 |
import torch
|
17 |
+
from huggingface_hub import snapshot_download
|
18 |
from huggingface_hub.errors import LocalTokenNotFoundError
|
19 |
from diffusers import FluxKontextPipeline
|
20 |
from diffusers.utils import load_image
|
21 |
from PIL import Image
|
22 |
|
23 |
# ------------------------------------------------------------------
|
24 |
+
# 환경 설정 & 모델 / LoRA 사전 다운로드
|
25 |
# ------------------------------------------------------------------
|
26 |
+
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # 빠른 다운로드
|
|
|
27 |
|
28 |
MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
|
29 |
LORA_REPO = "Owen777/Kontext-Style-Loras"
|
30 |
CACHE_DIR = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
|
31 |
+
HF_TOKEN = os.getenv("HF_TOKEN") # 런타임에 주입하거나 Secrets 사용
|
32 |
|
|
|
|
|
|
|
33 |
|
34 |
def _download_with_token(repo_id: str) -> str:
|
35 |
"""Download repo snapshot with optional token handling."""
|
|
|
38 |
repo_id=repo_id,
|
39 |
cache_dir=CACHE_DIR,
|
40 |
resume_download=True,
|
41 |
+
token=HF_TOKEN if HF_TOKEN else True, # True → 로그인 세션 사용
|
42 |
)
|
43 |
except LocalTokenNotFoundError:
|
|
|
44 |
raise RuntimeError(
|
45 |
"Hugging Face 토큰이 필요합니다. 환경변수 HF_TOKEN을 설정하거나\n"
|
46 |
"`huggingface-cli login`으로 로그인해 주세요."
|
47 |
)
|
48 |
|
49 |
+
|
50 |
+
# 최초 실행 시 캐시에만 다운로드 (이미 있으면 즉시 반환)
|
51 |
MODEL_DIR = _download_with_token(MODEL_ID)
|
52 |
LORA_DIR = _download_with_token(LORA_REPO)
|
53 |
|
|
|
105 |
}
|
106 |
|
107 |
# ------------------------------------------------------------------
|
108 |
+
# 파이프라인 로더 (싱글턴)
|
109 |
# ------------------------------------------------------------------
|
110 |
+
_pipeline = None
|
111 |
+
|
112 |
|
113 |
def load_pipeline():
|
114 |
+
"""Load or return cached FluxKontextPipeline."""
|
115 |
global _pipeline
|
116 |
if _pipeline is not None:
|
117 |
return _pipeline
|
118 |
|
119 |
+
# VRAM 판별 → dtype & offload 설정
|
|
|
120 |
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
121 |
+
dtype = torch.bfloat16 if vram_gb >= 24 else torch.float16
|
|
|
122 |
|
123 |
gr.Info("FLUX.1‑Kontext 파이프라인 로딩 중… (최초 1회)")
|
124 |
|
|
|
126 |
MODEL_DIR,
|
127 |
torch_dtype=dtype,
|
128 |
local_files_only=True,
|
129 |
+
).to("cuda")
|
|
|
|
|
130 |
|
131 |
if vram_gb < 24:
|
132 |
pipe.enable_sequential_cpu_offload()
|
|
|
137 |
return _pipeline
|
138 |
|
139 |
# ------------------------------------------------------------------
|
140 |
+
# 스타일 변환 함수
|
141 |
# ------------------------------------------------------------------
|
142 |
@spaces.GPU(duration=600)
|
143 |
+
|
144 |
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
|
145 |
"""Apply selected style to the uploaded image."""
|
146 |
if input_image is None:
|
|
|
150 |
try:
|
151 |
pipe = load_pipeline()
|
152 |
|
153 |
+
# Torch Generator (seed 고정 시 재현 가능)
|
154 |
generator = None
|
155 |
if seed and int(seed) > 0:
|
156 |
generator = torch.Generator(device="cuda").manual_seed(int(seed))
|
157 |
|
158 |
+
# 입력 이미지 전처리
|
159 |
img = input_image if isinstance(input_image, Image.Image) else load_image(input_image)
|
160 |
img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)
|
161 |
|
162 |
+
# LoRA 로드
|
163 |
lora_file = STYLE_LORA_MAP[style_name]
|
164 |
adapter_name = "style"
|
165 |
pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name)
|
166 |
pipe.set_adapters([adapter_name], [1.0])
|
167 |
|
168 |
+
# 프롬프트 구성
|
169 |
+
readable_style = style_name.replace("_", " ")
|
170 |
+
prompt = f"Turn this image into the {readable_style} style."
|
171 |
if prompt_suffix and prompt_suffix.strip():
|
172 |
prompt += f" {prompt_suffix.strip()}"
|
173 |
|
174 |
+
gr.Info("Generating styled image… (20‑60 s)")
|
175 |
|
176 |
result = pipe(
|
177 |
image=img,
|
|
|
183 |
width=1024,
|
184 |
)
|
185 |
|
186 |
+
# LoRA 언로드 & 메모리 정리
|
187 |
pipe.unload_lora_weights(adapter_name=adapter_name)
|
188 |
torch.cuda.empty_cache()
|
189 |
|
|
|
195 |
return None
|
196 |
|
197 |
# ------------------------------------------------------------------
|
198 |
+
# Gradio UI
|
199 |
# ------------------------------------------------------------------
|
|
|
|
|
|
|
200 |
|
201 |
+
def update_description(style):
|
202 |
+
return STYLE_DESCRIPTIONS.get(style, "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
|
|
|
|
|
|
204 |
|
205 |
+
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
|
206 |
+
gr.Markdown(
|
207 |
+
"""
|
208 |
+
# 🎨 FLUX.1 Kontext Style Transfer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
+
업로드한 이미지를 22 종 예술 스타일로 변환하세요!
|
211 |
+
(모델 / LoRA는 최초 실행 시에만 다운로드되며, 이후 실행은 빠릅니다.)
|
212 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
)
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
with gr.Row():
|
216 |
with gr.Column(scale=1):
|
217 |
input_image = gr.Image(label="Upload Image", type="pil", height=400)
|
|
|
220 |
value="Ghibli",
|
221 |
label="Select Style",
|
222 |
)
|
223 |
+
style_info = gr.Textbox(
|
224 |
+
label="Style Description",
|
225 |
+
value=STYLE_DESCRIPTIONS["Ghibli"],
|
226 |
+
interactive=False,
|
227 |
+
lines=2,
|
228 |
+
)
|
229 |
+
prompt_suffix = gr.Textbox(
|
230 |
+
label="Additional Instructions (Optional)",
|
231 |
+
placeholder="e.g. add dramatic lighting",
|
232 |
+
lines=2,
|
233 |
+
)
|
234 |
|
235 |
with gr.Accordion("Advanced Settings", open=False):
|
236 |
+
num_steps = gr.Slider(10, 50, value=24, step=1, label="Inference Steps")
|
237 |
+
guidance = gr.Slider(1.0, 7.5, value=2.5, step=0.1, label="Guidance Scale")
|
238 |
seed = gr.Number(label="Seed (0 = random)", value=42)
|
239 |
|
240 |
generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")
|
241 |
|
242 |
with gr.Column(scale=1):
|
243 |
+
output_image = gr.Image(label="Styled Result", type="pil", height=400)
|
244 |
+
gr.Markdown(
|
245 |
+
"""
|
246 |
+
### 💡 Tips
|
247 |
+
- 모델(7 GB)·LoRA는 최초 실행 시에만 다운로드됩니다.
|
248 |
+
- 이미지는 1024×1024로 리사이즈 후 처리됩니다.
|
249 |
+
- VRAM < 24 GB인 경우 자동으로 FP16 + CPU offload가 적용됩니다.
|
250 |
+
- seed 값을 변경해 다양한 결과를 얻어 보세요!
|
251 |
+
"""
|
252 |
+
)
|
253 |
+
|
254 |
+
# 스타일 설명 자동 업데이트
|
255 |
+
style_dropdown.change(update_description, inputs=[style_dropdown], outputs=[style_info])
|
256 |
+
|
257 |
+
# 예제 샘플
|
258 |
+
gr.Examples(
|
259 |
+
examples=[
|
260 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli", ""],
|
261 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi", "make it extra cute"],
|
262 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh", "with swirling sky"],
|
263 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Pixel", "8-bit retro game style"],
|
264 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Chinese_Ink", "mountain landscape"],
|
265 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "LEGO", "colorful blocks"],
|
266 |
+
],
|
267 |
+
inputs=[input_image, style_dropdown, prompt_suffix],
|
268 |
+
outputs=output_image,
|
269 |
+
fn=lambda img, style, prompt: style_transfer(img, style, prompt, 24, 2.5, 42),
|
270 |
+
cache_examples=False,
|
271 |
+
)
|
272 |
+
|
273 |
+
# 버튼 클릭 연결
|
274 |
+
generate_btn.click(
|
275 |
+
fn=style_transfer,
|
276 |
+
inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
|
277 |
+
outputs=output_image,
|
278 |
+
)
|
279 |
+
|
280 |
+
if __name__ == "__main__":
|
281 |
+
demo.launch()
|