File size: 7,033 Bytes
a02323a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
import os

# Style dictionary: maps each style name to the file name of its LoRA weights.
# Every file name is the style name followed by "_lora_weights.safetensors",
# so the mapping is derived from the ordered list of style names.
style_type_lora_dict = {
    style: f"{style}_lora_weights.safetensors"
    for style in (
        "3D_Chibi",
        "American_Cartoon",
        "Chinese_Ink",
        "Clay_Toy",
        "Fabric",
        "Ghibli",
        "Irasutoya",
        "Jojo",
        "Oil_Painting",
        "Pixel",
        "Snoopy",
        "Poly",
        "LEGO",
        "Origami",
        "Pop_Art",
        "Van_Gogh",
        "Paper_Cutting",
        "Line",
        "Vector",
        "Picasso",
        "Macaron",
        "Rick_Morty",
    )
}

# Create LoRAs directory if it doesn't exist
os.makedirs("./LoRAs", exist_ok=True)

# Download all LoRA weights at startup, skipping files already on disk.
# Only the file names are needed here, so iterate the dict's values.
print("Downloading LoRA weights...")
for lora_file in style_type_lora_dict.values():
    if os.path.exists(f"./LoRAs/{lora_file}"):
        continue
    hf_hub_download(
        repo_id="Owen777/Kontext-Style-Loras",
        filename=lora_file,
        local_dir="./LoRAs",
    )
print("All LoRA weights downloaded!")

# Shared pipeline instance; stays None until load_pipeline() first builds it.
pipeline = None


def load_pipeline():
    """Return the module-level pipeline, constructing it on the first call.

    The pipeline is created with ``FluxKontextPipeline.from_pretrained`` in
    bfloat16 and stored in the global ``pipeline`` so later calls reuse it.
    """
    global pipeline
    if pipeline is not None:
        return pipeline

    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
    )
    return pipeline

@spaces.GPU(duration=120)  # Request GPU for 120 seconds
def style_transfer(input_image, style_name, num_inference_steps, guidance_scale, seed):
    """
    Apply style transfer to the input image using the selected style LoRA.

    Args:
        input_image: A PIL image, or a string that ``load_image`` can open.
        style_name: Key of ``style_type_lora_dict`` selecting the LoRA file.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA generator. ``None`` or a value <= 0 means no
            generator is created and ``generator=None`` is passed through.

    Returns:
        The first image of the pipeline result.

    Raises:
        gr.Error: If no input image was provided or the style is unknown.
    """
    # Fail with a readable UI message instead of an AttributeError/KeyError.
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if style_name not in style_type_lora_dict:
        raise gr.Error(f"Unknown style: {style_name}")

    # Load pipeline and move to GPU
    pipe = load_pipeline()
    pipe = pipe.to('cuda')

    # Set seed for reproducibility; UI number widgets may deliver floats,
    # while manual_seed requires an int.
    if seed is not None and seed > 0:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    else:
        generator = None

    # Resize input image to 1024x1024
    if isinstance(input_image, str):
        image = load_image(input_image)
    else:
        image = input_image

    image = image.resize((1024, 1024), Image.Resampling.LANCZOS)

    lora_path = f"./LoRAs/{style_type_lora_dict[style_name]}"
    prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."

    try:
        # Load the selected LoRA
        pipe.load_lora_weights(lora_path, adapter_name="style_lora")
        pipe.set_adapters(["style_lora"], adapter_weights=[1.0])

        # Generate the styled image
        result = pipe(
            image=image,
            prompt=prompt,
            height=1024,
            width=1024,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator
        )
    finally:
        # The pipeline object is shared across calls: remove the adapter so the
        # next request can load its own LoRA under the same adapter name, and
        # clear GPU memory even when inference fails.
        pipe.unload_lora_weights()
        torch.cuda.empty_cache()

    return result.images[0]

# Create Gradio interface: inputs and settings in the left column, the styled
# result in the right column, followed by examples and usage notes.
with gr.Blocks(title="Flux Kontext Style Transfer") as demo:
    gr.Markdown("""
    # 🎨 Flux Kontext Style Transfer
    
    Transform your images into various artistic styles using FLUX.1-Kontext and style-specific LoRAs.
    
    Upload an image and select a style to apply the transformation!
    """)
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400
            )
            
            # Choices come straight from the style dictionary keys.
            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value="3D_Chibi",
                label="Select Style",
                info="Choose the artistic style to apply"
            )
            
            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=24,
                    step=1,
                    label="Number of Inference Steps",
                    info="More steps = better quality but slower"
                )
                
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5,
                    label="Guidance Scale",
                    info="Higher values = stronger style adherence"
                )
                
                # style_transfer only seeds the generator when seed > 0.
                seed = gr.Number(
                    label="Seed",
                    value=0,
                    precision=0,
                    info="Set to 0 for random, or use specific seed for reproducibility"
                )
            
            generate_btn = gr.Button("🎨 Apply Style Transfer", variant="primary")
        
        with gr.Column():
            output_image = gr.Image(
                label="Styled Output",
                type="pil",
                height=400
            )
    
    # Examples: each row supplies only an image and a style, so the wrapper
    # fills in the slider defaults (24 steps, guidance 3.5) and seed 0.
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh"],
        ],
        inputs=[input_image, style_dropdown],
        outputs=output_image,
        fn=lambda img, style: style_transfer(img, style, 24, 3.5, 0),
        cache_examples=True
    )
    
    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, num_steps, guidance, seed],
        outputs=output_image
    )
    
    gr.Markdown("""
    ## 📝 Notes:
    - Processing takes about 30-60 seconds depending on the number of steps
    - All images are resized to 1024x1024 for optimal results
    - Different styles work better with different types of images
    - Try adjusting the advanced settings for better results
    
    ## 🎨 Available Styles:
    3D Chibi, American Cartoon, Chinese Ink, Clay Toy, Fabric, Ghibli, Irasutoya, 
    Jojo, Oil Painting, Pixel, Snoopy, Poly, LEGO, Origami, Pop Art, Van Gogh, 
    Paper Cutting, Line, Vector, Picasso, Macaron, Rick & Morty
    """)

# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()