seawolf2357 commited on
Commit
a02323a
Β·
verified Β·
1 Parent(s): e53c0b9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from huggingface_hub import hf_hub_download
5
+ from diffusers import FluxKontextPipeline
6
+ from diffusers.utils import load_image
7
+ from PIL import Image
8
+ import os
9
+
10
+ # Style dictionary
11
+ style_type_lora_dict = {
12
+ "3D_Chibi": "3D_Chibi_lora_weights.safetensors",
13
+ "American_Cartoon": "American_Cartoon_lora_weights.safetensors",
14
+ "Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
15
+ "Clay_Toy": "Clay_Toy_lora_weights.safetensors",
16
+ "Fabric": "Fabric_lora_weights.safetensors",
17
+ "Ghibli": "Ghibli_lora_weights.safetensors",
18
+ "Irasutoya": "Irasutoya_lora_weights.safetensors",
19
+ "Jojo": "Jojo_lora_weights.safetensors",
20
+ "Oil_Painting": "Oil_Painting_lora_weights.safetensors",
21
+ "Pixel": "Pixel_lora_weights.safetensors",
22
+ "Snoopy": "Snoopy_lora_weights.safetensors",
23
+ "Poly": "Poly_lora_weights.safetensors",
24
+ "LEGO": "LEGO_lora_weights.safetensors",
25
+ "Origami": "Origami_lora_weights.safetensors",
26
+ "Pop_Art": "Pop_Art_lora_weights.safetensors",
27
+ "Van_Gogh": "Van_Gogh_lora_weights.safetensors",
28
+ "Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
29
+ "Line": "Line_lora_weights.safetensors",
30
+ "Vector": "Vector_lora_weights.safetensors",
31
+ "Picasso": "Picasso_lora_weights.safetensors",
32
+ "Macaron": "Macaron_lora_weights.safetensors",
33
+ "Rick_Morty": "Rick_Morty_lora_weights.safetensors"
34
+ }
35
+
36
+ # Create LoRAs directory if it doesn't exist
37
+ os.makedirs("./LoRAs", exist_ok=True)
38
+
39
+ # Download all LoRA weights at startup
40
+ print("Downloading LoRA weights...")
41
+ for style_name, lora_file in style_type_lora_dict.items():
42
+ if not os.path.exists(f"./LoRAs/{lora_file}"):
43
+ hf_hub_download(
44
+ repo_id="Owen777/Kontext-Style-Loras",
45
+ filename=lora_file,
46
+ local_dir="./LoRAs"
47
+ )
48
+ print("All LoRA weights downloaded!")
49
+
50
+ # Initialize pipeline globally (will be loaded to GPU when needed)
51
+ pipeline = None
52
+
53
+ def load_pipeline():
54
+ global pipeline
55
+ if pipeline is None:
56
+ pipeline = FluxKontextPipeline.from_pretrained(
57
+ "black-forest-labs/FLUX.1-Kontext-dev",
58
+ torch_dtype=torch.bfloat16
59
+ )
60
+ return pipeline
61
+
62
+ @spaces.GPU(duration=120) # Request GPU for 120 seconds
63
+ def style_transfer(input_image, style_name, num_inference_steps, guidance_scale, seed):
64
+ """
65
+ Apply style transfer to the input image using selected style
66
+ """
67
+ # Load pipeline and move to GPU
68
+ pipe = load_pipeline()
69
+ pipe = pipe.to('cuda')
70
+
71
+ # Set seed for reproducibility
72
+ if seed is not None and seed > 0:
73
+ generator = torch.Generator(device="cuda").manual_seed(seed)
74
+ else:
75
+ generator = None
76
+
77
+ # Resize input image to 1024x1024
78
+ if isinstance(input_image, str):
79
+ image = load_image(input_image)
80
+ else:
81
+ image = input_image
82
+
83
+ image = image.resize((1024, 1024), Image.Resampling.LANCZOS)
84
+
85
+ # Load the selected LoRA
86
+ lora_path = f"./LoRAs/{style_type_lora_dict[style_name]}"
87
+ pipe.load_lora_weights(lora_path, adapter_name="style_lora")
88
+ pipe.set_adapters(["style_lora"], adapter_weights=[1.0])
89
+
90
+ # Generate the styled image
91
+ prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."
92
+
93
+ result = pipe(
94
+ image=image,
95
+ prompt=prompt,
96
+ height=1024,
97
+ width=1024,
98
+ num_inference_steps=num_inference_steps,
99
+ guidance_scale=guidance_scale,
100
+ generator=generator
101
+ )
102
+
103
+ # Clear GPU memory
104
+ torch.cuda.empty_cache()
105
+
106
+ return result.images[0]
107
+
108
+ # Create Gradio interface
109
+ with gr.Blocks(title="Flux Kontext Style Transfer") as demo:
110
+ gr.Markdown("""
111
+ # 🎨 Flux Kontext Style Transfer
112
+
113
+ Transform your images into various artistic styles using FLUX.1-Kontext and style-specific LoRAs.
114
+
115
+ Upload an image and select a style to apply the transformation!
116
+ """)
117
+
118
+ with gr.Row():
119
+ with gr.Column():
120
+ input_image = gr.Image(
121
+ label="Input Image",
122
+ type="pil",
123
+ height=400
124
+ )
125
+
126
+ style_dropdown = gr.Dropdown(
127
+ choices=list(style_type_lora_dict.keys()),
128
+ value="3D_Chibi",
129
+ label="Select Style",
130
+ info="Choose the artistic style to apply"
131
+ )
132
+
133
+ with gr.Accordion("Advanced Settings", open=False):
134
+ num_steps = gr.Slider(
135
+ minimum=10,
136
+ maximum=50,
137
+ value=24,
138
+ step=1,
139
+ label="Number of Inference Steps",
140
+ info="More steps = better quality but slower"
141
+ )
142
+
143
+ guidance = gr.Slider(
144
+ minimum=1.0,
145
+ maximum=10.0,
146
+ value=3.5,
147
+ step=0.5,
148
+ label="Guidance Scale",
149
+ info="Higher values = stronger style adherence"
150
+ )
151
+
152
+ seed = gr.Number(
153
+ label="Seed",
154
+ value=0,
155
+ precision=0,
156
+ info="Set to 0 for random, or use specific seed for reproducibility"
157
+ )
158
+
159
+ generate_btn = gr.Button("🎨 Apply Style Transfer", variant="primary")
160
+
161
+ with gr.Column():
162
+ output_image = gr.Image(
163
+ label="Styled Output",
164
+ type="pil",
165
+ height=400
166
+ )
167
+
168
+ # Examples
169
+ gr.Examples(
170
+ examples=[
171
+ ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi"],
172
+ ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli"],
173
+ ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh"],
174
+ ],
175
+ inputs=[input_image, style_dropdown],
176
+ outputs=output_image,
177
+ fn=lambda img, style: style_transfer(img, style, 24, 3.5, 0),
178
+ cache_examples=True
179
+ )
180
+
181
+ # Connect the generate button
182
+ generate_btn.click(
183
+ fn=style_transfer,
184
+ inputs=[input_image, style_dropdown, num_steps, guidance, seed],
185
+ outputs=output_image
186
+ )
187
+
188
+ gr.Markdown("""
189
+ ## πŸ“ Notes:
190
+ - Processing takes about 30-60 seconds depending on the number of steps
191
+ - All images are resized to 1024x1024 for optimal results
192
+ - Different styles work better with different types of images
193
+ - Try adjusting the advanced settings for better results
194
+
195
+ ## 🎨 Available Styles:
196
+ 3D Chibi, American Cartoon, Chinese Ink, Clay Toy, Fabric, Ghibli, Irasutoya,
197
+ Jojo, Oil Painting, Pixel, Snoopy, Poly, LEGO, Origami, Pop Art, Van Gogh,
198
+ Paper Cutting, Line, Vector, Picasso, Macaron, Rick & Morty
199
+ """)
200
+
201
+ if __name__ == "__main__":
202
+ demo.launch()