Spaces:
Running
on
Zero
Running
on
Zero
""" | |
FLUX.1 Kontext Style Transfer
==============================
Updated: 2025-07-12
---------------------------------
This script is a Gradio demo that uses the Hugging Face
**FLUX.1-Kontext-dev** model together with 22 style LoRA weights to
convert images into a variety of art styles.

Key improvements
----------------
1. **Model caching** - `snapshot_download()` caches the model and the
   LoRA weights once at startup so later GPU jobs do not re-download them.
2. **Automatic GPU VRAM detection** - when the GPU has less than 24 GB of
   VRAM, `torch.float16` / `enable_sequential_cpu_offload()` are applied
   automatically.
3. **Single loading message** - the Gradio `gr.Info()` loading message is
   shown only the first time.
4. **Bug fixes** - seed handling, LoRA unloading, image resize logic, and
   other minor errors were fixed.
------------------------------------------------------------
""" | |
import os | |
import gradio as gr | |
import spaces | |
import torch | |
from huggingface_hub import snapshot_download | |
from diffusers import FluxKontextPipeline | |
from diffusers.utils import load_image | |
from PIL import Image | |
# ------------------------------------------------------------------
# Environment setup & model / LoRA pre-download
# ------------------------------------------------------------------
# Opt in to the accelerated transfer backend for large files unless the
# environment already chose a value.
# NOTE(review): huggingface_hub appears to read this flag when it is first
# imported, and the import above runs before this line - confirm the flag
# still takes effect, or set it in the Space's environment instead.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
LORA_REPO = "Owen777/Kontext-Style-Loras"
CACHE_DIR = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

# Downloaded on first run only; files already in the cache are reused.
# `resume_download` is intentionally not passed: it is deprecated in
# huggingface_hub, where interrupted downloads are always resumed.
MODEL_DIR = snapshot_download(
    repo_id=MODEL_ID,
    cache_dir=CACHE_DIR,
    token=True,  # require an HF token (e.g. via the HF_TOKEN env variable)
)
LORA_DIR = snapshot_download(
    repo_id=LORA_REPO,
    cache_dir=CACHE_DIR,
    token=True,
)
# ------------------------------------------------------------------
# Style descriptions & style -> LoRA file mapping
# ------------------------------------------------------------------
# Human-readable description of every supported style. The key order here
# is also the order in which styles are offered to the user.
STYLE_DESCRIPTIONS = {
    "3D_Chibi": "Cute, miniature 3D character style with big heads",
    "American_Cartoon": "Classic American animation style",
    "Chinese_Ink": "Traditional Chinese ink painting aesthetic",
    "Clay_Toy": "Playful clay/plasticine toy appearance",
    "Fabric": "Soft, textile-like rendering",
    "Ghibli": "Studio Ghibli's distinctive anime style",
    "Irasutoya": "Simple, flat Japanese illustration style",
    "Jojo": "JoJo's Bizarre Adventure manga style",
    "Oil_Painting": "Classic oil painting texture and strokes",
    "Pixel": "Retro pixel art style",
    "Snoopy": "Peanuts comic strip style",
    "Poly": "Low-poly 3D geometric style",
    "LEGO": "LEGO brick construction style",
    "Origami": "Paper folding art style",
    "Pop_Art": "Bold, colorful pop art style",
    "Van_Gogh": "Van Gogh's expressive brushstroke style",
    "Paper_Cutting": "Paper cut-out art style",
    "Line": "Clean line art/sketch style",
    "Vector": "Clean vector graphics style",
    "Picasso": "Cubist art style inspired by Picasso",
    "Macaron": "Soft, pastel macaron-like style",
    "Rick_Morty": "Rick and Morty cartoon style",
}

# Every style's weight file is named "<style>_lora_weights.safetensors", so
# the mapping is derived from the description keys (same keys, same order).
STYLE_LORA_MAP = {
    style: f"{style}_lora_weights.safetensors" for style in STYLE_DESCRIPTIONS
}
# ------------------------------------------------------------------
# Pipeline loading (single shared instance)
# ------------------------------------------------------------------
_pipeline = None  # module-level cache of the loaded pipeline


def load_pipeline():
    """Return the shared FluxKontextPipeline, loading it on first use.

    The pipeline is built from the pre-downloaded ``MODEL_DIR`` snapshot and
    stored in the module-level ``_pipeline`` so later calls return the same
    object without reloading.

    GPUs with less than 24 GB of VRAM get ``torch.float16`` weights and
    sequential CPU offload; larger GPUs get ``torch.bfloat16`` weights and
    model-level CPU offload.

    Returns:
        The loaded (or previously cached) ``FluxKontextPipeline``.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    if not torch.cuda.is_available():
        raise RuntimeError(
            "A CUDA-capable GPU is required to load the FLUX.1 Kontext pipeline."
        )

    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    low_vram = vram_gb < 24
    dtype = torch.float16 if low_vram else torch.bfloat16

    gr.Info("FLUX.1โKontext ํ์ดํ๋ผ์ธ ๋ก๋ฉ ์คโฆย (์ต์ด 1ํ)")
    pipe = FluxKontextPipeline.from_pretrained(
        MODEL_DIR,
        torch_dtype=dtype,
        local_files_only=True,
    )

    # Do not call pipe.to("cuda") first: the offload hooks move weights to
    # the GPU on demand. Moving the whole pipeline beforehand defeats the
    # memory savings and can run out of memory on exactly the small GPUs
    # that need offloading.
    if low_vram:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()

    _pipeline = pipe
    return _pipeline
# ------------------------------------------------------------------
# Style transfer function (Spaces GPU job)
# ------------------------------------------------------------------
# The decorator requests a GPU for the duration of the call on ZeroGPU
# hardware; `spaces` is already imported at the top of the file.
@spaces.GPU
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
    """Apply the selected style LoRA to the uploaded image.

    Args:
        input_image: A PIL image, or anything ``load_image`` accepts (such
            as a URL or file path). ``None`` shows a warning and returns.
        style_name: Key of ``STYLE_LORA_MAP`` selecting the LoRA weights.
        prompt_suffix: Optional extra instructions appended to the prompt.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        guidance_scale: Guidance scale (cast to ``float``).
        seed: Positive value for a reproducible result; ``0``, a negative
            value or ``None`` (cleared field) selects a random seed.

    Returns:
        The styled image, or ``None`` when no input image was given.

    Raises:
        gr.Error: If the style is unknown or generation fails, so that
            Gradio shows the failure to the user instead of an empty output.
    """
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None

    if style_name not in STYLE_LORA_MAP:
        raise gr.Error(f"Unknown style: {style_name}")

    adapter_name = "style"
    pipe = None
    lora_attempted = False
    try:
        pipe = load_pipeline()

        # --- Torch generator setup ---
        # A cleared Number field arrives as None; treat it like 0 (random).
        seed_value = int(seed) if seed is not None else 0
        if seed_value > 0:
            generator = torch.Generator(device="cuda").manual_seed(seed_value)
        else:
            generator = None  # random

        # --- Input image preprocessing ---
        img = input_image if isinstance(input_image, Image.Image) else load_image(input_image)
        img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)

        # --- Load the LoRA ---
        lora_file = STYLE_LORA_MAP[style_name]
        lora_attempted = True
        pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name)
        pipe.set_adapters([adapter_name], [1.0])

        # --- Build the prompt ---
        human_readable_style = style_name.replace("_", " ")
        prompt = f"Turn this image into the {human_readable_style} style."
        if prompt_suffix and prompt_suffix.strip():
            prompt += f" {prompt_suffix.strip()}"

        gr.Info("Generating styled imageโฆ (24โ60โฏs)")
        result = pipe(
            image=img,
            prompt=prompt,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            height=1024,
            width=1024,
        )
        return result.images[0]
    except Exception as e:
        # gr.Error only reaches the UI when it is raised; merely constructing
        # it (as before) silently discarded the failure.
        raise gr.Error(f"Error during style transfer: {e}") from e
    finally:
        # --- Unload the LoRA & release GPU memory ---
        # Runs on failure too, so the cached pipeline never keeps a stale
        # "style" adapter that would clash with the next request.
        # unload_lora_weights() removes all LoRA weights and is called
        # without arguments because it has no adapter_name parameter.
        if pipe is not None and lora_attempted:
            pipe.unload_lora_weights()
        torch.cuda.empty_cache()
# ------------------------------------------------------------------
# Gradio UI definition
# ------------------------------------------------------------------
_DEFAULT_STYLE = "Ghibli"
_EXAMPLE_IMAGE_URL = (
    "https://huggingface.co/datasets/black-forest-labs/kontext-bench"
    "/resolve/main/test/images/0003.jpg"
)
# (style, additional instructions) pairs shown as clickable examples; every
# example uses the same benchmark image.
_EXAMPLE_ROWS = [
    [_EXAMPLE_IMAGE_URL, example_style, example_suffix]
    for example_style, example_suffix in (
        ("Ghibli", ""),
        ("3D_Chibi", "make it extra cute"),
        ("Van_Gogh", "with swirling sky"),
    )
]

with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ๐จ FLUX.1 Kontext Style Transfer
    ์ ๋ก๋ํ ์ด๋ฏธ์ง๋ฅผ 22โฏ์ข ์ ์์ ์คํ์ผ๋ก ๋ณํํ์ธ์!
    (๋ชจ๋ธโฏ/โฏLoRA๋ ์ต์ด ์คํ ์์๋ง ๋ค์ด๋ก๋๋๋ฉฐ, ์ดํ ์คํ์ ๋น ๋ฆ ๋๋ค.)
    """)

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="pil", height=400)
            style_dropdown = gr.Dropdown(
                choices=list(STYLE_LORA_MAP),
                value=_DEFAULT_STYLE,
                label="Select Style",
            )
            style_info = gr.Textbox(
                label="Style Description",
                value=STYLE_DESCRIPTIONS[_DEFAULT_STYLE],
                interactive=False,
                lines=2,
            )
            prompt_suffix = gr.Textbox(
                label="Additional Instructions (Optional)",
                placeholder="e.g. add dramatic lighting",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10, maximum=50, value=24, step=1, label="Inference Steps"
                )
                guidance = gr.Slider(
                    minimum=1.0, maximum=7.5, value=2.5, step=0.1, label="Guidance Scale"
                )
                seed = gr.Number(label="Seed (0 = random)", value=42)
            generate_btn = gr.Button("๐จ Transform Image", variant="primary", size="lg")

        # Right column: result and usage tips.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Styled Result", type="pil", height=400)
            gr.Markdown("""
            ### ๐ก Tips
            * ์ด๋ฏธ์ง ํฌ๊ธฐ๋ 1024ร1024๋ก ๋ฆฌ์ฌ์ด์ฆ๋ฉ๋๋ค.
            * ์ต์ด 1ํย ๋ชจ๋ธย +ย LoRA ๋ค์ด๋ก๋ ํ์๋ **์บ์**๋ฅผ ์ฌ์ฉํ๋ฏ๋ก 10โ20โฏs ๋ด ์๋ฃ๋ฉ๋๋ค.
            * "Additional Instructions"์ ์๊ฐยท์กฐ๋ช ยทํจ๊ณผ ๋ฑ์ ์์ด๋ก ๊ฐ๋จํ ์ ์ผ๋ฉด ๊ฒฐ๊ณผ๋ฅผ ์ธ๋ฐํ๊ฒ ์ ์ดํ ์ ์์ต๋๋ค.
            """)

    # --- Keep the description box in sync with the selected style ---
    def _update_desc(style):
        """Return the description for ``style`` ("" when it is unknown)."""
        return STYLE_DESCRIPTIONS.get(style, "")

    style_dropdown.change(fn=_update_desc, inputs=[style_dropdown], outputs=[style_info])

    # --- Examples ---
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        inputs=[input_image, style_dropdown, prompt_suffix],
        outputs=output_image,
        # Examples use fixed settings: 24 steps, guidance 2.5, seed 42.
        fn=lambda img, style, prompt: style_transfer(img, style, prompt, 24, 2.5, 42),
        cache_examples=False,
    )

    # --- Button wiring ---
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
        outputs=output_image,
    )

    gr.Markdown("""
    ---
    **Created with โค๏ธ by GiniGEN (2025)**
    """)

if __name__ == "__main__":
    demo.launch(inline=False)