"""
FLUX.1โ€ฏKontext Style Transfer
=============================
Updatedย :ย 2025โ€‘07โ€‘12ย (HF_TOKENย ์ง€์›ย +ย ์ „์ฒด ์ฝ”๋“œ ์™„์„ฑ)
---------------------------------------------------
Gradio ๋ฐ๋ชจ๋กœ ์ด๋ฏธ์ง€๋ฅผ 22โ€ฏ์ข… ์˜ˆ์ˆ  ์Šคํƒ€์ผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
- **HF_TOKEN**ย ํ™˜๊ฒฝ๋ณ€์ˆ˜๋ฅผ ์ž๋™ ์ธ์‹ํ•ด ๋ผ์ด์„ ์Šค ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์˜ค๋ฅ˜๋ฅผ ๋ฐฉ์ง€ํ•ฉ๋‹ˆ๋‹ค.
- ์ตœ์ดˆ ์‹คํ–‰ ์‹œ ๋ชจ๋ธยทLoRA๋ฅผ ์บ์‹œ์— ๋ฐ›์•„ ๋‘๊ณ , ์ดํ›„์—๋Š” ์žฌ๋‹ค์šด๋กœ๋“œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.
- GPUโ€ฏVRAM์„ ๊ฐ์ง€ํ•˜์—ฌ 24โ€ฏGBโ€ฏ๋ฏธ๋งŒ์—์„œ๋Š” FP16โ€ฏ+โ€ฏCPUย offload๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
- ํŒŒ์ดํ”„๋ผ์ธยทLoRA ๋กœ๋”ฉ ๋ฉ”์‹œ์ง€๋Š” ์ตœ์ดˆ 1ํšŒ๋งŒ ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.
"""
import os
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalTokenNotFoundError
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
# ------------------------------------------------------------------
# Environment setup & model / LoRA pre-download
# ------------------------------------------------------------------
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # faster downloads via hf_transfer
MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"
LORA_REPO = "Owen777/Kontext-Style-Loras"
CACHE_DIR = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
HF_TOKEN = os.getenv("HF_TOKEN")  # injected at runtime or via the Space's Secrets
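# Example (outside of Spaces, assuming a shell): run `export HF_TOKEN=hf_...` before
# launching, or `huggingface-cli login` once so the stored token is picked up below.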
def _download_with_token(repo_id: str) -> str:
"""Download repo snapshot with optional token handling."""
try:
return snapshot_download(
repo_id=repo_id,
cache_dir=CACHE_DIR,
resume_download=True,
            token=HF_TOKEN if HF_TOKEN else True,  # True → fall back to the locally saved login token
)
except LocalTokenNotFoundError:
raise RuntimeError(
"Huggingย Face ํ† ํฐ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ํ™˜๊ฒฝ๋ณ€์ˆ˜ HF_TOKEN์„ ์„ค์ •ํ•˜๊ฑฐ๋‚˜\n"
"`huggingface-cli login`์œผ๋กœ ๋กœ๊ทธ์ธํ•ด ์ฃผ์„ธ์š”."
)
# Download into the cache on first run only (returns immediately if already cached)
MODEL_DIR = _download_with_token(MODEL_ID)
LORA_DIR = _download_with_token(LORA_REPO)
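# snapshot_download() returns the local snapshot directory, so the pipeline can later
# be instantiated with local_files_only=True without touching the network again.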
# ------------------------------------------------------------------
# ์Šคํƒ€์ผย โ†’ย LoRA ํŒŒ์ผ ๋งคํ•‘ & ์„ค๋ช…
# ------------------------------------------------------------------
STYLE_LORA_MAP = {
"3D_Chibi": "3D_Chibi_lora_weights.safetensors",
"American_Cartoon": "American_Cartoon_lora_weights.safetensors",
"Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
"Clay_Toy": "Clay_Toy_lora_weights.safetensors",
"Fabric": "Fabric_lora_weights.safetensors",
"Ghibli": "Ghibli_lora_weights.safetensors",
"Irasutoya": "Irasutoya_lora_weights.safetensors",
"Jojo": "Jojo_lora_weights.safetensors",
"Oil_Painting": "Oil_Painting_lora_weights.safetensors",
"Pixel": "Pixel_lora_weights.safetensors",
"Snoopy": "Snoopy_lora_weights.safetensors",
"Poly": "Poly_lora_weights.safetensors",
"LEGO": "LEGO_lora_weights.safetensors",
"Origami": "Origami_lora_weights.safetensors",
"Pop_Art": "Pop_Art_lora_weights.safetensors",
"Van_Gogh": "Van_Gogh_lora_weights.safetensors",
"Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
"Line": "Line_lora_weights.safetensors",
"Vector": "Vector_lora_weights.safetensors",
"Picasso": "Picasso_lora_weights.safetensors",
"Macaron": "Macaron_lora_weights.safetensors",
"Rick_Morty": "Rick_Morty_lora_weights.safetensors",
}
STYLE_DESCRIPTIONS = {
"3D_Chibi": "Cute, miniature 3D character style with big heads",
"American_Cartoon": "Classic American animation style",
"Chinese_Ink": "Traditional Chinese ink painting aesthetic",
"Clay_Toy": "Playful clay/plasticine toy appearance",
"Fabric": "Soft, textile-like rendering",
"Ghibli": "Studio Ghibli's distinctive anime style",
"Irasutoya": "Simple, flat Japanese illustration style",
"Jojo": "JoJo's Bizarre Adventure manga style",
"Oil_Painting": "Classic oil painting texture and strokes",
"Pixel": "Retro pixel art style",
"Snoopy": "Peanuts comic strip style",
"Poly": "Low-poly 3D geometric style",
"LEGO": "LEGO brick construction style",
"Origami": "Paper folding art style",
"Pop_Art": "Bold, colorful pop art style",
"Van_Gogh": "Van Gogh's expressive brushstroke style",
"Paper_Cutting": "Paper cut-out art style",
"Line": "Clean line art/sketch style",
"Vector": "Clean vector graphics style",
"Picasso": "Cubist art style inspired by Picasso",
"Macaron": "Soft, pastel macaron-like style",
"Rick_Morty": "Rick and Morty cartoon style",
}
# ------------------------------------------------------------------
# ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋” (์‹ฑ๊ธ€ํ„ด)
# ------------------------------------------------------------------
_pipeline = None
def load_pipeline():
"""Load or return cached FluxKontextPipeline."""
global _pipeline
if _pipeline is not None:
return _pipeline
    # Pick dtype & offload strategy based on available VRAM
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
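    # Assumes a CUDA device is visible; load_pipeline() is only reached from the
    # @spaces.GPU-decorated handler, where a GPU is attached.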
dtype = torch.bfloat16 if vram_gb >= 24 else torch.float16
gr.Info("FLUX.1โ€‘Kontext ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋”ฉ ์ค‘โ€ฆย (์ตœ์ดˆ 1ํšŒ)")
pipe = FluxKontextPipeline.from_pretrained(
MODEL_DIR,
torch_dtype=dtype,
local_files_only=True,
).to("cuda")
if vram_gb < 24:
pipe.enable_sequential_cpu_offload()
else:
pipe.enable_model_cpu_offload()
_pipeline = pipe
return _pipeline
# ------------------------------------------------------------------
# ์Šคํƒ€์ผ ๋ณ€ํ™˜ ํ•จ์ˆ˜
# ------------------------------------------------------------------
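# On ZeroGPU Spaces, @spaces.GPU attaches a GPU for each call; duration=600 raises
# the per-call time limit to 600 seconds. The decorator should be a no-op on
# regular, always-on GPU hardware.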
@spaces.GPU(duration=600)
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
"""Apply selected style to the uploaded image."""
if input_image is None:
gr.Warning("Please upload an image first!")
return None
try:
pipe = load_pipeline()
        # Torch generator (a fixed seed makes results reproducible)
generator = None
if seed and int(seed) > 0:
generator = torch.Generator(device="cuda").manual_seed(int(seed))
        # Preprocess the input image
img = input_image if isinstance(input_image, Image.Image) else load_image(input_image)
img = img.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)
        # Load the LoRA for the selected style
lora_file = STYLE_LORA_MAP[style_name]
adapter_name = "style"
pipe.load_lora_weights(LORA_DIR, weight_name=lora_file, adapter_name=adapter_name)
pipe.set_adapters([adapter_name], [1.0])
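        # Assumption: a weight of 1.0 applies the style LoRA at full strength;
        # a smaller value (e.g. 0.8) should yield a subtler stylization.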
        # Build the prompt
readable_style = style_name.replace("_", " ")
prompt = f"Turn this image into the {readable_style} style."
if prompt_suffix and prompt_suffix.strip():
prompt += f" {prompt_suffix.strip()}"
gr.Info("Generating styled imageโ€ฆย (20โ€‘60โ€ฏs)")
result = pipe(
image=img,
prompt=prompt,
guidance_scale=float(guidance_scale),
num_inference_steps=int(num_inference_steps),
generator=generator,
height=1024,
width=1024,
)
        # Unload the LoRA and free GPU memory
        pipe.unload_lora_weights()
        torch.cuda.empty_cache()
        return result.images[0]
    except Exception as e:
        torch.cuda.empty_cache()
        raise gr.Error(f"Error during style transfer: {e}")
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
def update_description(style):
return STYLE_DESCRIPTIONS.get(style, "")
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
        # 🎨 FLUX.1 Kontext Style Transfer
        Transform your uploaded image into one of 22 art styles!
        (The model / LoRAs are downloaded only on the first run; later runs are fast.)
"""
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Upload Image", type="pil", height=400)
style_dropdown = gr.Dropdown(
choices=list(STYLE_LORA_MAP.keys()),
value="Ghibli",
label="Select Style",
)
style_info = gr.Textbox(
label="Style Description",
value=STYLE_DESCRIPTIONS["Ghibli"],
interactive=False,
lines=2,
)
prompt_suffix = gr.Textbox(
label="Additional Instructions (Optional)",
placeholder="e.g. add dramatic lighting",
lines=2,
)
with gr.Accordion("Advanced Settings", open=False):
num_steps = gr.Slider(10, 50, value=24, step=1, label="Inference Steps")
guidance = gr.Slider(1.0, 7.5, value=2.5, step=0.1, label="Guidance Scale")
seed = gr.Number(label="Seed (0 = random)", value=42)
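                # A seed of 0 (or an empty value) leaves the generator unseeded, so results vary per run.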
            generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")
with gr.Column(scale=1):
output_image = gr.Image(label="Styled Result", type="pil", height=400)
gr.Markdown(
"""
                ### 💡 Tips
                - The model (~7 GB) and LoRAs are downloaded only on the first run.
                - Images are resized to 1024×1024 before processing.
                - With VRAM < 24 GB, FP16 + CPU offload is applied automatically.
                - Change the seed value to explore different results!
"""
)
# ์Šคํƒ€์ผ ์„ค๋ช… ์ž๋™ ์—…๋ฐ์ดํŠธ
style_dropdown.change(update_description, inputs=[style_dropdown], outputs=[style_info])
    # Example inputs
gr.Examples(
examples=[
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli", ""],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi", "make it extra cute"],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh", "with swirling sky"],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Pixel", "8-bit retro game style"],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Chinese_Ink", "mountain landscape"],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "LEGO", "colorful blocks"],
],
inputs=[input_image, style_dropdown, prompt_suffix],
outputs=output_image,
fn=lambda img, style, prompt: style_transfer(img, style, prompt, 24, 2.5, 42),
cache_examples=False,
)
    # Wire the button to the style-transfer function
generate_btn.click(
fn=style_transfer,
inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
outputs=output_image,
)
if __name__ == "__main__":
demo.launch()