merterbak commited on
Commit
aa01795
·
verified ·
1 Parent(s): c004962

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -0
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import AuraFlowPipeline
4
+ import spaces
5
+ import numpy as np
6
+
7
+ pipeline = AuraFlowPipeline.from_pretrained(
8
+ "fal/AuraFlow-v0.3",
9
+ torch_dtype=torch.float16,
10
+ variant="fp16",
11
+ use_safetensors=True,
12
+ ).to("cuda")
13
+
14
+ STYLE_PRESETS = {
15
+ "None": "",
16
+ "Comic": ", in comic book style, bold outlines, vibrant colors, dynamic shading",
17
+ "Watercolor": ", in watercolor style, soft edges, translucent colors, delicate brushstrokes",
18
+ "Oil Painting": ", in oil painting style, rich textures, bold brushstrokes, deep colors",
19
+ "Cyberpunk": ", in cyberpunk style, neon lights, dark atmosphere, futuristic elements",
20
+ "Photorealistic": ", in photorealistic style, highly detailed, lifelike textures, realistic lighting"
21
+ }
22
+
23
+ examples = [
24
+ {"prompt": "A rustic village nestled in a golden autumn valley, with rolling hills and a winding river bathed in warm light", "style": "Oil Painting"},
25
+ {"prompt": "A majestic dragon soaring high above a range of snow-capped mountains under a golden sunset sky", "style": "Comic"},
26
+ {"prompt": "A shiba inu on a rocky cliff overlooking a vibrant sunset ocean view", "style": "Photorealistic"},
27
+ {"prompt": "A futuristic city skyline glowing with neon lights, towering skyscrapers, and flying cars under a stormy night", "style": "Cyberpunk"},
28
+
29
+ ]
30
+
31
+ @spaces.GPU(duration=120)
32
+ def generate_images(
33
+ prompt,
34
+ negative_prompt,
35
+ style,
36
+ width=1024,
37
+ height=1024,
38
+ steps=20,
39
+ guidance=5.0,
40
+ seed=1,
41
+ num_images=1,
42
+ ):
43
+ generator = torch.Generator(device="cuda").manual_seed(seed)
44
+ styled_prompt = f"{prompt}{STYLE_PRESETS[style]}"
45
+ gallery = []
46
+
47
+ for i in range(num_images):
48
+ image = pipeline(
49
+ prompt=styled_prompt,
50
+ negative_prompt=negative_prompt,
51
+ width=width,
52
+ height=height,
53
+ num_inference_steps=steps,
54
+ guidance_scale=guidance,
55
+ generator=generator,
56
+ output_type="pil",
57
+ ).images[0]
58
+ gallery.append((image, ""))
59
+
60
+ torch.cuda.empty_cache()
61
+ return gallery
62
+
63
+ def interface_fn(
64
+ prompt,
65
+ negative_prompt,
66
+ style,
67
+ width,
68
+ height,
69
+ steps,
70
+ guidance,
71
+ seed,
72
+ num_images,
73
+ randomize_seed,
74
+ history,
75
+ progress: gr.Progress = gr.Progress(track_tqdm=True)
76
+ ):
77
+ if not prompt:
78
+ raise gr.Error("Please enter a prompt!")
79
+ if randomize_seed:
80
+ seed = np.random.randint(0, 1000000)
81
+
82
+ gallery = generate_images(
83
+ prompt=prompt,
84
+ negative_prompt=negative_prompt,
85
+ style=style,
86
+ width=width,
87
+ height=height,
88
+ steps=steps,
89
+ guidance=guidance,
90
+ seed=seed,
91
+ num_images=num_images
92
+ )
93
+
94
+ updated_history = update_history(gallery, history)
95
+ return gallery, seed, updated_history
96
+
97
+ def update_history(new_images, history):
98
+ if history is None:
99
+ history = []
100
+ for img in reversed(new_images):
101
+ history.insert(0, img[0])
102
+ return history
103
+
104
+ def clear_result():
105
+ return gr.update(value=[]), gr.update(value=None)
106
+
107
+ custom_css = """
108
+ .gr-button {margin: 5px;}
109
+ .output-image {border-radius: 8px;}
110
+ #advanced_options {margin-top: 20px;}
111
+ .style-dropdown {width: 100%; max-width: 800px;}
112
+ .gr-textbox {width: 100%;}
113
+ .example-row {margin-top: 20px;}
114
+ .example-button {white-space: normal; height: auto; min-height: 60px;}
115
+ /* Center the Generated Images gallery */
116
+ #output-gallery {
117
+ display: block;
118
+ width: 100%;
119
+ text-align: center;
120
+ }
121
+ #output-gallery .gallery {
122
+ display: inline-flex;
123
+ justify-content: center;
124
+ align-items: center;
125
+ flex-wrap: wrap;
126
+ margin: 0 auto;
127
+ }
128
+ #output-gallery .gallery > div {
129
+ display: flex;
130
+ justify-content: center;
131
+ align-items: center;
132
+ margin: 5px;
133
+ }
134
+ #output-gallery img {
135
+ display: block;
136
+ margin: 0 auto;
137
+ }
138
+ """
139
+
140
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as interface:
141
+ gr.Markdown("# AuraFlow v0.3 Image Generator")
142
+ gr.Markdown("Enter a prompt and select a style to generate images. Use the advanced settings for more control or try an example below.")
143
+ with gr.Row():
144
+ with gr.Column(scale=1):
145
+ prompt_input = gr.Textbox(
146
+ label="Prompt",
147
+ placeholder="Enter your creative prompt here",
148
+ lines=3
149
+ )
150
+ neg_prompt_input = gr.Textbox(
151
+ label="Negative Prompt",
152
+ placeholder="What you don’t want in the image",
153
+ lines=2
154
+ )
155
+ style_input = gr.Dropdown(
156
+ choices=list(STYLE_PRESETS.keys()),
157
+ value="None",
158
+ label="Art Style",
159
+ elem_classes=["style-dropdown"]
160
+ )
161
+ with gr.Accordion("Advanced Settings", open=False, elem_id="advanced_options"):
162
+ width_input = gr.Slider(256, 1536, step=256, value=1024, label="Width")
163
+ height_input = gr.Slider(256, 1536, step=256, value=1024, label="Height")
164
+ steps_input = gr.Slider(1, 50, step=1, value=20, label="Inference Steps")
165
+ guidance_input = gr.Slider(0, 10, step=0.5, value=5.0, label="Guidance Scale")
166
+ with gr.Row():
167
+ seed_input = gr.Number(value=1, label="Seed", visible=False)
168
+ randomize_seed_input = gr.Checkbox(value=True, label="Randomize Seed")
169
+ num_images_input = gr.Slider(1, 4, step=1, value=1, label="Number of Images")
170
+
171
+ with gr.Column(scale=2):
172
+ image_output = gr.Gallery(
173
+ label="Generated Images",
174
+ show_label=True,
175
+ preview=True,
176
+ elem_id="output-gallery"
177
+ )
178
+ with gr.Row():
179
+ clear_btn = gr.Button("Clear", variant="secondary")
180
+ generate_btn = gr.Button("Generate", variant="primary")
181
+
182
+ history_gallery = gr.Gallery(
183
+ label="History",
184
+ columns=6,
185
+ object_fit="contain",
186
+ interactive=False
187
+ )
188
+
189
+ with gr.Row(equal_height=True, elem_classes=["example-row"]):
190
+ gr.Markdown("### Try these examples:")
191
+ with gr.Row(equal_height=True, elem_classes=["example-row"]):
192
+ for ex in examples:
193
+ with gr.Column(scale=1, min_width=200):
194
+ btn = gr.Button(ex["prompt"], variant="secondary", elem_classes=["example-button"])
195
+ btn.click(
196
+ fn=lambda p=ex["prompt"], s=ex["style"]: (gr.update(value=p), gr.update(value=s)),
197
+ inputs=[],
198
+ outputs=[prompt_input, style_input]
199
+ ).then(
200
+ fn=interface_fn,
201
+ inputs=[prompt_input, neg_prompt_input, style_input, width_input, height_input,
202
+ steps_input, guidance_input, seed_input, num_images_input, randomize_seed_input, history_gallery],
203
+ outputs=[image_output, seed_input, history_gallery]
204
+ )
205
+ generate_btn.click(
206
+ fn=lambda: clear_result(),
207
+ inputs=[],
208
+ outputs=[image_output, seed_input]
209
+ ).then(
210
+ fn=interface_fn,
211
+ inputs=[prompt_input, neg_prompt_input, style_input, width_input, height_input,
212
+ steps_input, guidance_input, seed_input, num_images_input, randomize_seed_input, history_gallery],
213
+ outputs=[image_output, seed_input, history_gallery]
214
+ )
215
+
216
+ clear_btn.click(
217
+ fn=clear_result,
218
+ inputs=[],
219
+ outputs=[image_output, seed_input]
220
+ ).then(
221
+ fn=lambda x: (gr.update(value=[]), gr.update(value=None), x),
222
+ inputs=[history_gallery],
223
+ outputs=[image_output, seed_input, history_gallery]
224
+ )
225
+
226
+ randomize_seed_input.change(
227
+ fn=lambda randomize: gr.update(visible=not randomize),
228
+ inputs=randomize_seed_input,
229
+ outputs=seed_input
230
+ )
231
+
232
+ interface.launch(
233
+ show_error=True,
234
+ server_name="0.0.0.0",
235
+ server_port=7860,
236
+ share=True
237
+ )