# Hugging Face Spaces page banner (scrape residue): "Spaces: Running on Zero"
""" | |
FLUX.1 Kontext Style Transfer
=============================
Updated: 2025-07-12 (HF_TOKEN support + complete code)
---------------------------------------------------
Gradio demo that converts an image into one of 22 art styles.
- Automatically picks up the **HF_TOKEN** environment variable to avoid download errors for licensed models.
- On first run the model and LoRA weights are fetched into the cache; later runs do not download them again.
- Detects GPU VRAM and uses FP16 + CPU offload when less than 24 GB is available.
- The pipeline / LoRA loading message is shown only once, on first load.
""" | |
import os | |
import gradio as gr | |
import spaces | |
import torch | |
from huggingface_hub import snapshot_download | |
from huggingface_hub.errors import LocalTokenNotFoundError | |
from diffusers import FluxKontextPipeline | |
from diffusers.utils import load_image | |
from PIL import Image | |
# ------------------------------------------------------------------
# Environment setup & model / LoRA pre-download
# ------------------------------------------------------------------
# Opt in to fast downloads unless the environment already set a value.
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Hub repositories for the base model and the style LoRA weights.
MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
LORA_REPO = "Owen777/Kontext-Style-Loras"

# Cache root: $HF_HOME when set, otherwise the per-user default directory.
CACHE_DIR = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

# Access token (None when unset); inject it at runtime or through Secrets.
HF_TOKEN = os.environ.get("HF_TOKEN")
def _download_with_token(repo_id: str) -> str: | |
"""Download repo snapshot with optional token handling.""" | |
try: | |
return snapshot_download( | |
repo_id=repo_id, | |
cache_dir=CACHE_DIR, | |
resume_download=True, | |
token=HF_TOKEN if HF_TOKEN else True, # True โ ๋ก๊ทธ์ธ ์ธ์ ์ฌ์ฉ | |
) | |
except LocalTokenNotFoundError: | |
raise RuntimeError( | |
"Huggingย Face ํ ํฐ์ด ํ์ํฉ๋๋ค. ํ๊ฒฝ๋ณ์ HF_TOKEN์ ์ค์ ํ๊ฑฐ๋\n" | |
"`huggingface-cli login`์ผ๋ก ๋ก๊ทธ์ธํด ์ฃผ์ธ์." | |
) | |
# Download into the cache at startup (original note: only on first run;
# returns immediately when the files are already present).
MODEL_DIR, LORA_DIR = [_download_with_token(repo) for repo in (MODEL_ID, LORA_REPO)]
# ------------------------------------------------------------------
# Style -> LoRA file mapping & descriptions
# ------------------------------------------------------------------
# Style keys, in the order they should appear when the mapping is iterated.
_STYLE_KEYS = (
    "3D_Chibi",
    "American_Cartoon",
    "Chinese_Ink",
    "Clay_Toy",
    "Fabric",
    "Ghibli",
    "Irasutoya",
    "Jojo",
    "Oil_Painting",
    "Pixel",
    "Snoopy",
    "Poly",
    "LEGO",
    "Origami",
    "Pop_Art",
    "Van_Gogh",
    "Paper_Cutting",
    "Line",
    "Vector",
    "Picasso",
    "Macaron",
    "Rick_Morty",
)

# Maps each style key to its LoRA weight file; every file is named
# "<style key>_lora_weights.safetensors".
STYLE_LORA_MAP = {key: f"{key}_lora_weights.safetensors" for key in _STYLE_KEYS}
# Short human-readable description for each style key, shown in the UI.
STYLE_DESCRIPTIONS = dict(
    (
        ("3D_Chibi", "Cute, miniature 3D character style with big heads"),
        ("American_Cartoon", "Classic American animation style"),
        ("Chinese_Ink", "Traditional Chinese ink painting aesthetic"),
        ("Clay_Toy", "Playful clay/plasticine toy appearance"),
        ("Fabric", "Soft, textile-like rendering"),
        ("Ghibli", "Studio Ghibli's distinctive anime style"),
        ("Irasutoya", "Simple, flat Japanese illustration style"),
        ("Jojo", "JoJo's Bizarre Adventure manga style"),
        ("Oil_Painting", "Classic oil painting texture and strokes"),
        ("Pixel", "Retro pixel art style"),
        ("Snoopy", "Peanuts comic strip style"),
        ("Poly", "Low-poly 3D geometric style"),
        ("LEGO", "LEGO brick construction style"),
        ("Origami", "Paper folding art style"),
        ("Pop_Art", "Bold, colorful pop art style"),
        ("Van_Gogh", "Van Gogh's expressive brushstroke style"),
        ("Paper_Cutting", "Paper cut-out art style"),
        ("Line", "Clean line art/sketch style"),
        ("Vector", "Clean vector graphics style"),
        ("Picasso", "Cubist art style inspired by Picasso"),
        ("Macaron", "Soft, pastel macaron-like style"),
        ("Rick_Morty", "Rick and Morty cartoon style"),
    )
)
# ------------------------------------------------------------------
# Pipeline loader (singleton)
# ------------------------------------------------------------------
# Module-level cache holding the pipeline once load_pipeline() has built it.
_pipeline = None


def load_pipeline():
    """Return the shared FluxKontextPipeline, building it on the first call.

    The first call reads the total memory of CUDA device 0 and picks the
    precision and offload strategy from it:

    * 24 GiB or more: bfloat16 with model-level CPU offload.
    * below 24 GiB: float16 with sequential CPU offload.

    Later calls return the cached instance without reloading.

    Returns:
        The cached ``FluxKontextPipeline`` instance.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    # Decide dtype and offload mode from the VRAM of device 0.
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    low_vram = vram_gb < 24
    dtype = torch.float16 if low_vram else torch.bfloat16

    # Message text restored from a mis-encoded original. It reads:
    # "Loading the FLUX.1-Kontext pipeline... (first time only)".
    gr.Info("FLUX.1-Kontext 파이프라인 로딩 중… (최초 1회)")
    pipe = FluxKontextPipeline.from_pretrained(
        MODEL_DIR,
        torch_dtype=dtype,
        local_files_only=True,
    )

    # Deliberately no `.to("cuda")` here. Both branches enable CPU offload,
    # which manages device placement itself. Moving the full pipeline to the
    # GPU first would need all weights to fit in VRAM, which is exactly what
    # the low-VRAM path is meant to avoid.
    if low_vram:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()

    _pipeline = pipe
    return _pipeline
# ------------------------------------------------------------------
# Style transfer function
# ------------------------------------------------------------------
# NOTE(review): `spaces` is imported at the top of the file but no
# `@spaces.GPU` decorator is applied anywhere. If this runs on ZeroGPU
# hardware, GPU work presumably has to happen inside a decorated function.
# Confirm against the deployment target.
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
    """Apply the selected style LoRA to the uploaded image.

    Args:
        input_image: A PIL image, or anything ``load_image`` accepts
            (e.g. a URL or path). ``None`` triggers a warning.
        style_name: Key into ``STYLE_LORA_MAP`` selecting the LoRA file.
        prompt_suffix: Optional extra instructions appended to the prompt.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        guidance_scale: Guidance scale (cast to ``float``).
        seed: Values greater than 0 seed the generator; 0, ``None`` or
            negative values leave the generator unset.

    Returns:
        The first generated image, or ``None`` when no input image was given.

    Raises:
        gr.Error: If anything fails during loading or generation. The
            original exception is chained as the cause.
    """
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None

    try:
        pipe = load_pipeline()

        # A fixed seed makes the result reproducible.
        generator = None
        if seed and int(seed) > 0:
            generator = torch.Generator(device="cuda").manual_seed(int(seed))

        # Normalise the input to a 1024x1024 RGB image.
        img = input_image if isinstance(input_image, Image.Image) else load_image(input_image)
        img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)

        # Build the prompt.
        readable_style = style_name.replace("_", " ")
        prompt = f"Turn this image into the {readable_style} style."
        if prompt_suffix and prompt_suffix.strip():
            prompt += f" {prompt_suffix.strip()}"

        lora_file = STYLE_LORA_MAP[style_name]
        adapter_name = "style"
        try:
            pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name)
            pipe.set_adapters([adapter_name], [1.0])

            # Message text restored from a mis-encoded original.
            gr.Info("Generating styled image… (20-60 s)")
            result = pipe(
                image=img,
                prompt=prompt,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                height=1024,
                width=1024,
            )
        finally:
            # Always remove the LoRA, even when generation fails. Otherwise
            # the adapter name "style" stays registered on the shared
            # pipeline and the next request cannot load its own weights.
            # `unload_lora_weights` is called without an `adapter_name`
            # keyword because the method does not accept one.
            pipe.unload_lora_weights()

        return result.images[0]
    except Exception as e:
        # gr.Error is an exception class: it only reaches the user when raised.
        raise gr.Error(f"Error during style transfer: {e}") from e
    finally:
        torch.cuda.empty_cache()
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
def update_description(style):
    """Return the description for *style*, or an empty string if it is unknown."""
    if style in STYLE_DESCRIPTIONS:
        return STYLE_DESCRIPTIONS[style]
    return ""
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo: | |
gr.Markdown( | |
""" | |
# ๐จ FLUX.1 Kontext Style Transfer | |
์ ๋ก๋ํ ์ด๋ฏธ์ง๋ฅผ 22โฏ์ข ์์ ์คํ์ผ๋ก ๋ณํํ์ธ์! | |
(๋ชจ๋ธโฏ/โฏLoRA๋ ์ต์ด ์คํ ์์๋ง ๋ค์ด๋ก๋๋๋ฉฐ, ์ดํ ์คํ์ ๋น ๋ฆ ๋๋ค.) | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
input_image = gr.Image(label="Upload Image", type="pil", height=400) | |
style_dropdown = gr.Dropdown( | |
choices=list(STYLE_LORA_MAP.keys()), | |
value="Ghibli", | |
label="Select Style", | |
) | |
style_info = gr.Textbox( | |
label="Style Description", | |
value=STYLE_DESCRIPTIONS["Ghibli"], | |
interactive=False, | |
lines=2, | |
) | |
prompt_suffix = gr.Textbox( | |
label="Additional Instructions (Optional)", | |
placeholder="e.g. add dramatic lighting", | |
lines=2, | |
) | |
with gr.Accordion("Advanced Settings", open=False): | |
num_steps = gr.Slider(10, 50, value=24, step=1, label="Inference Steps") | |
guidance = gr.Slider(1.0, 7.5, value=2.5, step=0.1, label="Guidance Scale") | |
seed = gr.Number(label="Seed (0 = random)", value=42) | |
generate_btn = gr.Button("๐จ Transform Image", variant="primary", size="lg") | |
with gr.Column(scale=1): | |
output_image = gr.Image(label="Styled Result", type="pil", height=400) | |
gr.Markdown( | |
""" | |
### ๐ก Tips | |
- ๋ชจ๋ธ(7โฏGB)ยทLoRA๋ ์ต์ด ์คํ ์์๋ง ๋ค์ด๋ก๋๋ฉ๋๋ค. | |
- ์ด๋ฏธ์ง๋ 1024ร1024๋ก ๋ฆฌ์ฌ์ด์ฆ ํ ์ฒ๋ฆฌ๋ฉ๋๋ค. | |
- VRAMย <ย 24โฏGB์ธ ๊ฒฝ์ฐ ์๋์ผ๋ก FP16โฏ+โฏCPU offload๊ฐ ์ ์ฉ๋ฉ๋๋ค. | |
- seedย ๊ฐ์ ๋ณ๊ฒฝํด ๋ค์ํ ๊ฒฐ๊ณผ๋ฅผ ์ป์ด ๋ณด์ธ์! | |
""" | |
) | |
# ์คํ์ผ ์ค๋ช ์๋ ์ ๋ฐ์ดํธ | |
style_dropdown.change(update_description, inputs=[style_dropdown], outputs=[style_info]) | |
# ์์ ์ํ | |
gr.Examples( | |
examples=[ | |
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli", ""], | |
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi", "make it extra cute"], | |
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh", "with swirling sky"], | |
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Pixel", "8-bit retro game style"], | |
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Chinese_Ink", "mountain landscape"], | |
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "LEGO", "colorful blocks"], | |
], | |
inputs=[input_image, style_dropdown, prompt_suffix], | |
outputs=output_image, | |
fn=lambda img, style, prompt: style_transfer(img, style, prompt, 24, 2.5, 42), | |
cache_examples=False, | |
) | |
# ๋ฒํผ ํด๋ฆญ ์ฐ๊ฒฐ | |
generate_btn.click( | |
fn=style_transfer, | |
inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed], | |
outputs=output_image, | |
) | |
if __name__ == "__main__": | |
demo.launch() | |