artificialguybr commited on
Commit
5e508ca
Β·
verified Β·
1 Parent(s): 71e527f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ import torch.nn as nn
4
+ from torch.nn import Conv2d
5
+ from torch.nn import functional as F
6
+ from torch.nn.modules.utils import _pair
7
+ from typing import Optional
8
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
9
+ import diffusers
10
+ from PIL import Image
11
+ import gradio as gr
12
+ import spaces
13
+ import gc
14
+
15
+ def asymmetricConv2DConvForward_circular(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
16
+ self.paddingX = (
17
+ self._reversed_padding_repeated_twice[0],
18
+ self._reversed_padding_repeated_twice[1],
19
+ 0,
20
+ 0
21
+ )
22
+ self.paddingY = (
23
+ 0,
24
+ 0,
25
+ self._reversed_padding_repeated_twice[2],
26
+ self._reversed_padding_repeated_twice[3]
27
+ )
28
+ working = F.pad(input, self.paddingX, mode="circular")
29
+ working = F.pad(working, self.paddingY, mode="circular")
30
+ return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
31
+
32
+ def make_seamless(model):
33
+ for module in model.modules():
34
+ if isinstance(module, torch.nn.Conv2d):
35
+ if isinstance(module, diffusers.models.lora.LoRACompatibleConv) and module.lora_layer is None:
36
+ module.lora_layer = lambda *x: 0
37
+ module._conv_forward = asymmetricConv2DConvForward_circular.__get__(module, Conv2d)
38
+
39
+ def disable_seamless(model):
40
+ for module in model.modules():
41
+ if isinstance(module, torch.nn.Conv2d):
42
+ if isinstance(module, diffusers.models.lora.LoRACompatibleConv) and module.lora_layer is None:
43
+ module.lora_layer = lambda *x: 0
44
+ module._conv_forward = nn.Conv2d._conv_forward.__get__(module, Conv2d)
45
+
46
+ def diffusion_callback(pipe, step_index, timestep, callback_kwargs):
47
+ if step_index == int(pipe.num_timesteps * 0.8):
48
+ make_seamless(pipe.unet)
49
+ make_seamless(pipe.vae)
50
+ if step_index < int(pipe.num_timesteps * 0.8):
51
+ callback_kwargs["latents"] = torch.roll(callback_kwargs["latents"], shifts=(64, 64), dims=(2, 3))
52
+ return callback_kwargs
53
+
54
+ print("Loading Pattern Diffusion model...")
55
+ pipe = StableDiffusionPipeline.from_pretrained(
56
+ "Arrexel/pattern-diffusion",
57
+ torch_dtype=torch.float16,
58
+ safety_checker=None,
59
+ requires_safety_checker=False
60
+ )
61
+ pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
62
+
63
+ if torch.cuda.is_available():
64
+ pipe = pipe.to("cuda")
65
+ pipe.enable_attention_slicing()
66
+ pipe.enable_model_cpu_offload()
67
+ print("Model loaded successfully on GPU with optimizations!")
68
+ else:
69
+ print("GPU not available, using CPU")
70
+
71
+ @spaces.GPU(duration=40)
72
+ def generate_pattern(prompt, width=1024, height=1024, num_inference_steps=50, guidance_scale=7.5, seed=None):
73
+ try:
74
+ if torch.cuda.is_available():
75
+ pipe.to("cuda")
76
+
77
+ if seed is not None and seed != "":
78
+ generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
79
+ else:
80
+ generator = None
81
+
82
+ disable_seamless(pipe.unet)
83
+ disable_seamless(pipe.vae)
84
+
85
+ with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
86
+ output = pipe(
87
+ prompt=prompt,
88
+ width=int(width),
89
+ height=int(height),
90
+ num_inference_steps=int(num_inference_steps),
91
+ guidance_scale=guidance_scale,
92
+ generator=generator,
93
+ callback_on_step_end=diffusion_callback
94
+ ).images[0]
95
+
96
+ return output
97
+
98
+ except Exception as e:
99
+ print(f"Error during generation: {str(e)}")
100
+ return None
101
+ finally:
102
+ if torch.cuda.is_available():
103
+ torch.cuda.empty_cache()
104
+ gc.collect()
105
+
106
+ def create_interface():
107
+ with gr.Blocks(title="Pattern Diffusion - Seamless Pattern Generator") as demo:
108
+ gr.Markdown("""
109
+ # 🎨 Pattern Diffusion - Seamless Pattern Generator
110
+
111
+ **Model:** [Arrexel/pattern-diffusion](https://huggingface.co/Arrexel/pattern-diffusion)
112
+
113
+ This model specializes in generating patterns that can be repeated without visible seams,
114
+ ideal for prints, wallpapers, textiles, and surfaces.
115
+
116
+ **Strengths:**
117
+ - Excellent for floral and abstract patterns
118
+ - Understands foreground and background colors well
119
+ - Fast and efficient on VRAM
120
+
121
+ **Limitations:**
122
+ - Does not generate coherent text
123
+ - Difficulty with anatomy of living creatures
124
+ - Inconsistent geometry in simple geometric patterns
125
+ """)
126
+
127
+ with gr.Row():
128
+ with gr.Column():
129
+ prompt = gr.Textbox(
130
+ label="Prompt",
131
+ placeholder="Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background.",
132
+ lines=3,
133
+ value="Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background."
134
+ )
135
+
136
+ with gr.Row():
137
+ width = gr.Slider(
138
+ label="Width",
139
+ minimum=256,
140
+ maximum=1024,
141
+ step=256,
142
+ value=1024
143
+ )
144
+ height = gr.Slider(
145
+ label="Height",
146
+ minimum=256,
147
+ maximum=1024,
148
+ step=256,
149
+ value=1024
150
+ )
151
+
152
+ with gr.Row():
153
+ steps = gr.Slider(
154
+ label="Inference Steps",
155
+ minimum=20,
156
+ maximum=100,
157
+ step=5,
158
+ value=50
159
+ )
160
+ guidance_scale = gr.Slider(
161
+ label="Guidance Scale",
162
+ minimum=1.0,
163
+ maximum=20.0,
164
+ step=0.5,
165
+ value=7.5
166
+ )
167
+
168
+ seed = gr.Number(
169
+ label="Seed (optional, leave empty for random)",
170
+ precision=0
171
+ )
172
+
173
+ generate_btn = gr.Button("🎨 Generate Pattern", variant="primary", size="lg")
174
+
175
+ with gr.Column():
176
+ output_image = gr.Image(
177
+ label="Generated Pattern",
178
+ type="pil",
179
+ height=400
180
+ )
181
+
182
+ gr.Markdown("## πŸ“‹ Example Prompts")
183
+ examples = [
184
+ ["Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background."],
185
+ ["Abstract geometric pattern with gold and navy blue triangles on cream background"],
186
+ ["Delicate cherry blossom pattern with soft pink petals on light gray background"],
187
+ ["Art deco pattern with emerald green and gold lines on black background"],
188
+ ["Tropical leaves pattern with various shades of green on white background"],
189
+ ["Vintage damask pattern in burgundy and cream colors"],
190
+ ["Modern minimalist dots pattern in pastel colors"],
191
+ ["Mandala-inspired pattern with intricate details in blue and white"]
192
+ ]
193
+
194
+ gr.Examples(
195
+ examples=examples,
196
+ inputs=[prompt],
197
+ label="Click an example to use"
198
+ )
199
+
200
+ generate_btn.click(
201
+ fn=generate_pattern,
202
+ inputs=[prompt, width, height, steps, guidance_scale, seed],
203
+ outputs=[output_image]
204
+ )
205
+
206
+ return demo
207
+
208
+ if __name__ == "__main__":
209
+ demo = create_interface()
210
+ demo.queue(max_size=20).launch()