"""
FLUX.1 Kontext Style Transfer
==============================
Updated: 2025โ€‘07โ€‘12
---------------------------------
์ด ์Šคํฌ๋ฆฝํŠธ๋Š” Huggingโ€ฏFace **FLUX.1โ€‘Kontextโ€‘dev** ๋ชจ๋ธ๊ณผ
22โ€ฏ์ข…์˜ ์Šคํƒ€์ผ LoRA ๊ฐ€์ค‘์น˜๋ฅผ ์ด์šฉํ•ด ์ด๋ฏธ์ง€๋ฅผ ๋‹ค์–‘ํ•œ ์˜ˆ์ˆ 
์Šคํƒ€์ผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” Gradio ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.
์ฃผ์š” ๊ฐœ์„  ์‚ฌํ•ญ
--------------
1. **๋ชจ๋ธ ์บ์‹ฑ**โ€†โ€“ย `snapshot_download()`๋กœ ์‹คํ–‰ ์‹œ์ž‘ ์‹œ ํ•œ ๋ฒˆ๋งŒ
๋ชจ๋ธ๊ณผ LoRA๋ฅผ ์บ์‹ฑํ•ด ์ดํ›„ GPU ์žก์—์„œ๋„ ์žฌ๋‹ค์šด๋กœ๋“œ๊ฐ€ ์—†๋„๋ก
ํ•จ.
2. **GPUโ€ฏVRAM ์ž๋™ ํŒ๋ณ„**โ€†โ€“ย GPUโ€ฏVRAM์ด 24โ€ฏGBโ€ฏ๋ฏธ๋งŒ์ด๋ฉด
`torch.float16`ย / `enable_sequential_cpu_offload()`๋ฅผ ์ž๋™ ์ ์šฉ.
3. **๋‹จ์ผ ๋กœ๋”ฉ ๋ฉ”์‹œ์ง€**โ€†โ€“ย Gradioย `gr.Info()` ๋ฉ”์‹œ์ง€๊ฐ€ ์ตœ์ดˆ 1ํšŒ๋งŒ
ํ‘œ์‹œ๋˜๋„๋ก ์ˆ˜์ •.
4. **๋ฒ„๊ทธ ํ”ฝ์Šค**โ€†โ€“ย seed ์ฒ˜๋ฆฌ, LoRA ์–ธ๋กœ๋“œ, ์ด๋ฏธ์ง€ ๋ฆฌ์‚ฌ์ด์ฆˆ ๋กœ์ง
๋“ฑ ์„ธ๋ถ€ ์˜ค๋ฅ˜ ์ˆ˜์ •.
------------------------------------------------------------
"""
import os
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
# ------------------------------------------------------------------
# Environment setup & model / LoRA pre-download
# ------------------------------------------------------------------
# Enable the accelerated download flag so large files are fetched quickly
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
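# NOTE: fast transfer only works if the optional hf_transfer package is installed.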
MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
LORA_REPO = "Owen777/Kontext-Style-Loras"
CACHE_DIR = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
# --- Download only on the first run (skipped if already in the cache) ---
MODEL_DIR = snapshot_download(
repo_id=MODEL_ID,
cache_dir=CACHE_DIR,
resume_download=True,
    token=True  # HF token (set the HF_TOKEN environment variable if needed)
)
LORA_DIR = snapshot_download(
repo_id=LORA_REPO,
cache_dir=CACHE_DIR,
resume_download=True,
token=True
)
# ------------------------------------------------------------------
# Style → LoRA file mapping & descriptions
# ------------------------------------------------------------------
STYLE_LORA_MAP = {
"3D_Chibi": "3D_Chibi_lora_weights.safetensors",
"American_Cartoon": "American_Cartoon_lora_weights.safetensors",
"Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
"Clay_Toy": "Clay_Toy_lora_weights.safetensors",
"Fabric": "Fabric_lora_weights.safetensors",
"Ghibli": "Ghibli_lora_weights.safetensors",
"Irasutoya": "Irasutoya_lora_weights.safetensors",
"Jojo": "Jojo_lora_weights.safetensors",
"Oil_Painting": "Oil_Painting_lora_weights.safetensors",
"Pixel": "Pixel_lora_weights.safetensors",
"Snoopy": "Snoopy_lora_weights.safetensors",
"Poly": "Poly_lora_weights.safetensors",
"LEGO": "LEGO_lora_weights.safetensors",
"Origami": "Origami_lora_weights.safetensors",
"Pop_Art": "Pop_Art_lora_weights.safetensors",
"Van_Gogh": "Van_Gogh_lora_weights.safetensors",
"Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
"Line": "Line_lora_weights.safetensors",
"Vector": "Vector_lora_weights.safetensors",
"Picasso": "Picasso_lora_weights.safetensors",
"Macaron": "Macaron_lora_weights.safetensors",
"Rick_Morty": "Rick_Morty_lora_weights.safetensors",
}
STYLE_DESCRIPTIONS = {
"3D_Chibi": "Cute, miniature 3D character style with big heads",
"American_Cartoon": "Classic American animation style",
"Chinese_Ink": "Traditional Chinese ink painting aesthetic",
"Clay_Toy": "Playful clay/plasticine toy appearance",
"Fabric": "Soft, textile-like rendering",
"Ghibli": "Studio Ghibli's distinctive anime style",
"Irasutoya": "Simple, flat Japanese illustration style",
"Jojo": "JoJo's Bizarre Adventure manga style",
"Oil_Painting": "Classic oil painting texture and strokes",
"Pixel": "Retro pixel art style",
"Snoopy": "Peanuts comic strip style",
"Poly": "Low-poly 3D geometric style",
"LEGO": "LEGO brick construction style",
"Origami": "Paper folding art style",
"Pop_Art": "Bold, colorful pop art style",
"Van_Gogh": "Van Gogh's expressive brushstroke style",
"Paper_Cutting": "Paper cut-out art style",
"Line": "Clean line art/sketch style",
"Vector": "Clean vector graphics style",
"Picasso": "Cubist art style inspired by Picasso",
"Macaron": "Soft, pastel macaron-like style",
"Rick_Morty": "Rick and Morty cartoon style",
}
# ------------------------------------------------------------------
# ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋” (๋‹จ์ผ ์ธ์Šคํ„ด์Šค)
# ------------------------------------------------------------------
_pipeline = None  # internal global cache
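# The pipeline is built lazily inside the first GPU call and reused by
# every subsequent call in the same worker process.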
def load_pipeline():
"""Load (or return cached) FluxKontextPipeline."""
global _pipeline
if _pipeline is not None:
return _pipeline
    # If the GPU has less than 24 GB of VRAM, use FP16 + CPU offloading
dtype = torch.bfloat16
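    # bfloat16 is the default; the FLUX.1 checkpoints are distributed in bf16.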
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
if vram_gb < 24:
dtype = torch.float16
    gr.Info("Loading the FLUX.1-Kontext pipeline... (first run only)")
pipe = FluxKontextPipeline.from_pretrained(
MODEL_DIR,
torch_dtype=dtype,
local_files_only=True,
)
    if vram_gb < 24:
        # Keep peak VRAM low: offload sub-modules sequentially instead of
        # moving the whole pipeline onto the GPU first.
        pipe.enable_sequential_cpu_offload()
    else:
        # Enough VRAM: model-level offload is sufficient and faster.
        pipe.enable_model_cpu_offload()
_pipeline = pipe
return _pipeline
# ------------------------------------------------------------------
# Style transfer function (runs as a Spaces GPU job)
# ------------------------------------------------------------------
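# duration=600 requests up to 600 seconds of GPU time per call on Spaces.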
@spaces.GPU(duration=600)
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
"""Apply selected style to the uploaded image."""
if input_image is None:
gr.Warning("Please upload an image first!")
return None
try:
pipe = load_pipeline()
        # --- Set up the torch.Generator ---
if seed > 0:
generator = torch.Generator(device="cuda").manual_seed(int(seed))
else:
generator = None # random
        # --- Preprocess the input image ---
img = input_image if isinstance(input_image, Image.Image) else load_image(input_image)
img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)
        # --- Load the style LoRA ---
lora_file = STYLE_LORA_MAP[style_name]
adapter_name = "style"
pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name)
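        # Activate only the freshly loaded adapter, at full strength (1.0).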
pipe.set_adapters([adapter_name], [1.0])
        # --- Build the prompt ---
human_readable_style = style_name.replace("_", " ")
prompt = f"Turn this image into the {human_readable_style} style."
if prompt_suffix and prompt_suffix.strip():
prompt += f" {prompt_suffix.strip()}"
        gr.Info("Generating styled image... (24-60 s)")
result = pipe(
image=img,
prompt=prompt,
guidance_scale=float(guidance_scale),
num_inference_steps=int(num_inference_steps),
generator=generator,
height=1024,
width=1024,
)
        # --- Unload the LoRA & free GPU memory ---
        pipe.unload_lora_weights()
torch.cuda.empty_cache()
return result.images[0]
except Exception as e:
torch.cuda.empty_cache()
        raise gr.Error(f"Error during style transfer: {e}") from e
# ------------------------------------------------------------------
# Gradio UI definition
# ------------------------------------------------------------------
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎨 FLUX.1 Kontext Style Transfer
Transform an uploaded image into one of 22 art styles!
(The model / LoRAs are downloaded only on the first run; later runs are fast.)
""")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Upload Image", type="pil", height=400)
style_dropdown = gr.Dropdown(
choices=list(STYLE_LORA_MAP.keys()),
value="Ghibli",
label="Select Style",
)
style_info = gr.Textbox(label="Style Description", value=STYLE_DESCRIPTIONS["Ghibli"], interactive=False, lines=2)
prompt_suffix = gr.Textbox(label="Additional Instructions (Optional)", placeholder="e.g. add dramatic lighting", lines=2)
with gr.Accordion("Advanced Settings", open=False):
num_steps = gr.Slider(minimum=10, maximum=50, value=24, step=1, label="Inference Steps")
guidance = gr.Slider(minimum=1.0, maximum=7.5, value=2.5, step=0.1, label="Guidance Scale")
seed = gr.Number(label="Seed (0 = random)", value=42)
            generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")
with gr.Column(scale=1):
output_image = gr.Image(label="Styled Result", type="pil", height=400)
            gr.Markdown("""
### 💡 Tips
* The image is resized to 1024×1024.
* After the first model + LoRA download the **cache** is reused, so a run finishes in about 10-20 s.
* Add brief English notes on colour, lighting or effects in "Additional Instructions" to fine-tune the result.
""")
    # --- Auto-update the style description when the selection changes ---
def _update_desc(style):
return STYLE_DESCRIPTIONS.get(style, "")
style_dropdown.change(fn=_update_desc, inputs=[style_dropdown], outputs=[style_info])
    # --- Examples ---
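    # cache_examples=False avoids running the GPU job at startup just to
    # pre-render the example outputs.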
gr.Examples(
examples=[
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli", ""],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi", "make it extra cute"],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh", "with swirling sky"],
],
inputs=[input_image, style_dropdown, prompt_suffix],
outputs=output_image,
fn=lambda img, style, prompt: style_transfer(img, style, prompt, 24, 2.5, 42),
cache_examples=False,
)
    # --- Wire up the generate button ---
generate_btn.click(
fn=style_transfer,
inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
outputs=output_image,
)
    gr.Markdown("""
---
**Created with ❤️ by GiniGEN (2025)**
""")
if __name__ == "__main__":
demo.launch(inline=False)