# Hugging Face Space — running on ZeroGPU ("Running on Zero").
import gradio as gr | |
import spaces | |
import torch | |
from huggingface_hub import hf_hub_download | |
from diffusers import FluxKontextPipeline | |
from diffusers.utils import load_image | |
from PIL import Image | |
import os | |
# Supported style names, in the order they are offered in the UI.
_STYLE_NAMES = (
    "3D_Chibi",
    "American_Cartoon",
    "Chinese_Ink",
    "Clay_Toy",
    "Fabric",
    "Ghibli",
    "Irasutoya",
    "Jojo",
    "Oil_Painting",
    "Pixel",
    "Snoopy",
    "Poly",
    "LEGO",
    "Origami",
    "Pop_Art",
    "Van_Gogh",
    "Paper_Cutting",
    "Line",
    "Vector",
    "Picasso",
    "Macaron",
    "Rick_Morty",
)

# Style dictionary: maps each style name to the file name of its LoRA weights.
# Every file follows the "<style>_lora_weights.safetensors" naming scheme.
style_type_lora_dict = {
    name: f"{name}_lora_weights.safetensors" for name in _STYLE_NAMES
}
# Make sure the local LoRA directory exists before anything is fetched into it
os.makedirs("./LoRAs", exist_ok=True)

# Fetch every LoRA weights file at startup; files already on disk are skipped
print("Downloading LoRA weights...")
for lora_file in style_type_lora_dict.values():
    if os.path.exists(f"./LoRAs/{lora_file}"):
        continue  # already downloaded on a previous start
    hf_hub_download(
        repo_id="Owen777/Kontext-Style-Loras",
        filename=lora_file,
        local_dir="./LoRAs",
    )
print("All LoRA weights downloaded!")
# Module-level pipeline handle; stays None until load_pipeline() is first called
pipeline = None


def load_pipeline():
    """Return the shared FluxKontextPipeline, creating it on first use.

    The pipeline is built once with bfloat16 weights and stored in the
    module-level ``pipeline`` variable; later calls return that same object.
    """
    global pipeline
    if pipeline is not None:
        return pipeline

    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
    )
    return pipeline
# Request GPU for 120 seconds
@spaces.GPU(duration=120)
def style_transfer(input_image, style_name, num_inference_steps, guidance_scale, seed):
    """
    Apply style transfer to the input image using the selected style.

    Args:
        input_image: A PIL image, or a string path/URL that is loaded with
            ``load_image``.
        style_name: Key of ``style_type_lora_dict`` selecting the LoRA to apply.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA generator; ``None`` or a value <= 0 means no
            fixed seed is used.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no input image was provided or the style is unknown.
    """
    # Fail with a readable UI message instead of an AttributeError/KeyError
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if style_name not in style_type_lora_dict:
        raise gr.Error(f"Unknown style: {style_name}")

    # Load pipeline and move to GPU
    pipe = load_pipeline()
    pipe = pipe.to("cuda")

    # Set seed for reproducibility (UI widgets may deliver the number as a float)
    if seed is not None and seed > 0:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    else:
        generator = None

    # Resize input image to 1024x1024
    if isinstance(input_image, str):
        image = load_image(input_image)
    else:
        image = input_image
    image = image.resize((1024, 1024), Image.Resampling.LANCZOS)

    lora_path = f"./LoRAs/{style_type_lora_dict[style_name]}"
    prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."

    try:
        # Load the selected LoRA
        pipe.load_lora_weights(lora_path, adapter_name="style_lora")
        pipe.set_adapters(["style_lora"], adapter_weights=[1.0])

        # Generate the styled image
        result = pipe(
            image=image,
            prompt=prompt,
            height=1024,
            width=1024,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator,
        )
    finally:
        # The pipeline object is shared between requests: drop the adapter so
        # the fixed "style_lora" name is free for the next call, and release
        # cached GPU memory even when inference failed.
        pipe.unload_lora_weights()
        torch.cuda.empty_cache()

    return result.images[0]
# Create Gradio interface: inputs and settings on the left, result on the right
with gr.Blocks(title="Flux Kontext Style Transfer") as demo:
    gr.Markdown("""
    # 🎨 Flux Kontext Style Transfer
    Transform your images into various artistic styles using FLUX.1-Kontext and style-specific LoRAs.
    Upload an image and select a style to apply the transformation!
    """)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400
            )
            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict),
                value="3D_Chibi",
                label="Select Style",
                info="Choose the artistic style to apply"
            )
            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=24,
                    step=1,
                    label="Number of Inference Steps",
                    info="More steps = better quality but slower"
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5,
                    label="Guidance Scale",
                    info="Higher values = stronger style adherence"
                )
                seed = gr.Number(
                    label="Seed",
                    value=0,
                    precision=0,
                    info="Set to 0 for random, or use specific seed for reproducibility"
                )
            generate_btn = gr.Button("🎨 Apply Style Transfer", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Styled Output",
                type="pil",
                height=400
            )
    # Examples: only image and style are exposed, so the wrapper below pins the
    # advanced settings to the same defaults as the sliders (24 steps, 3.5
    # guidance, seed 0).
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh"],
        ],
        inputs=[input_image, style_dropdown],
        outputs=output_image,
        fn=lambda img, style: style_transfer(img, style, 24, 3.5, 0),
        cache_examples=True
    )
    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, num_steps, guidance, seed],
        outputs=output_image
    )
    gr.Markdown("""
    ## 📝 Notes:
    - Processing takes about 30-60 seconds depending on the number of steps
    - All images are resized to 1024x1024 for optimal results
    - Different styles work better with different types of images
    - Try adjusting the advanced settings for better results
    ## 🎨 Available Styles:
    3D Chibi, American Cartoon, Chinese Ink, Clay Toy, Fabric, Ghibli, Irasutoya,
    Jojo, Oil Painting, Pixel, Snoopy, Poly, LEGO, Origami, Pop Art, Van Gogh,
    Paper Cutting, Line, Vector, Picasso, Macaron, Rick & Morty
    """)

if __name__ == "__main__":
    demo.launch()