|
import torch |
|
from torch import Tensor |
|
import torch.nn as nn |
|
from torch.nn import Conv2d |
|
from torch.nn import functional as F |
|
from torch.nn.modules.utils import _pair |
|
from typing import Optional |
|
from diffusers import StableDiffusionPipeline, DDPMScheduler |
|
import diffusers |
|
from PIL import Image |
|
import gradio as gr |
|
import spaces |
|
import gc |
|
|
|
def asymmetricConv2DConvForward_circular(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    """Replacement for ``Conv2d._conv_forward`` that pads with wrap-around.

    Intended to be bound onto a ``Conv2d`` instance, so ``self`` is the
    convolution module. The module's own padding amounts are applied
    explicitly with ``mode="circular"``: first along the last (width)
    dimension, then along the second-to-last (height) dimension. The
    convolution itself then runs with zero padding. The two padding tuples
    are also stored on the module as ``paddingX`` and ``paddingY``.

    Args:
        input: Feature map to convolve; presumably (N, C, H, W), to be
            confirmed for every caller.
        weight: Convolution kernel.
        bias: Optional bias, or None.

    Returns:
        The convolution output.
    """
    pads = self._reversed_padding_repeated_twice
    # F.pad reads its pad tuple starting from the LAST dimension, so the first
    # pair is (left, right) for width and the second pair is (top, bottom).
    self.paddingX = (pads[0], pads[1], 0, 0)
    self.paddingY = (0, 0, pads[2], pads[3])

    wrapped = F.pad(input, self.paddingX, mode="circular")
    wrapped = F.pad(wrapped, self.paddingY, mode="circular")

    # Padding is already baked into `wrapped`, hence _pair(0) for the conv.
    return F.conv2d(wrapped, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
|
|
|
def make_seamless(model):
    """Switch every ``Conv2d`` inside *model* to circular (wrap-around) padding.

    Each convolution's ``_conv_forward`` is replaced by
    ``asymmetricConv2DConvForward_circular`` bound to that module. For diffusers
    ``LoRACompatibleConv`` layers whose ``lora_layer`` is None, a stub returning
    0 is installed first. This is presumably because that class's ``forward``
    bypasses ``_conv_forward`` when no LoRA layer is set; verify against the
    installed diffusers version.

    Args:
        model: Any ``torch.nn.Module``; its convolutions are modified in place.
    """
    conv_layers = (m for m in model.modules() if isinstance(m, torch.nn.Conv2d))
    for conv in conv_layers:
        lora_unset = isinstance(conv, diffusers.models.lora.LoRACompatibleConv) and conv.lora_layer is None
        if lora_unset:
            conv.lora_layer = lambda *x: 0
        conv._conv_forward = asymmetricConv2DConvForward_circular.__get__(conv, Conv2d)
|
|
|
def disable_seamless(model):
    """Restore the stock ``nn.Conv2d._conv_forward`` on every ``Conv2d`` in *model*.

    This reverses the ``_conv_forward`` patch applied by ``make_seamless``.
    A zero-returning ``lora_layer`` stub on diffusers ``LoRACompatibleConv``
    layers is left in place, and one is installed if ``lora_layer`` is still
    None (mirroring ``make_seamless``).

    Args:
        model: Any ``torch.nn.Module``; its convolutions are modified in place.
    """
    for layer in model.modules():
        if not isinstance(layer, torch.nn.Conv2d):
            continue
        if isinstance(layer, diffusers.models.lora.LoRACompatibleConv) and layer.lora_layer is None:
            layer.lora_layer = lambda *x: 0
        layer._conv_forward = nn.Conv2d._conv_forward.__get__(layer, Conv2d)
|
|
|
def diffusion_callback(pipe, step_index, timestep, callback_kwargs):
    """``callback_on_step_end`` hook that nudges the pipeline toward tileable output.

    For every step before the 80% mark, the latents are rolled by half of
    their spatial size along dims 2 and 3. This moves the image's edges
    toward the middle of the canvas for the next step. At exactly the 80%
    step, the UNet and VAE are switched to circular padding via
    ``make_seamless``.

    Args:
        pipe: The running pipeline; ``num_timesteps``, ``unet`` and ``vae``
            are read from it.
        step_index: Index of the step that just finished.
        timestep: Current timestep (unused).
        callback_kwargs: Tensors exposed by the pipeline; ``"latents"`` is
            read and possibly replaced.

    Returns:
        ``callback_kwargs``, with ``"latents"`` rolled while before the switch step.
    """
    switch_step = int(pipe.num_timesteps * 0.8)

    if step_index == switch_step:
        make_seamless(pipe.unet)
        make_seamless(pipe.vae)

    if step_index < switch_step:
        latents = callback_kwargs["latents"]
        # Roll by half of each spatial dimension. torch.roll wraps the shift
        # modulo the dimension size, so the previous fixed shift of 64 was
        # half only for a 128-wide latent and a no-op for 64- or 32-wide ones.
        shifts = (latents.shape[2] // 2, latents.shape[3] // 2)
        callback_kwargs["latents"] = torch.roll(latents, shifts=shifts, dims=(2, 3))

    return callback_kwargs
|
|
|
print("Loading Pattern Diffusion model...")

# Module-level pipeline shared by every request.
pipe = StableDiffusionPipeline.from_pretrained(
    "Arrexel/pattern-diffusion",
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
)

pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

if torch.cuda.is_available():
    # Keep the whole pipeline resident on the GPU. enable_model_cpu_offload()
    # used to be called right after .to("cuda"). Per the diffusers docs, model
    # offloading expects the pipeline to start on the CPU and manages device
    # placement itself, so combining the two contradicts itself. The request
    # handler also moves the pipeline to CUDA on each call, which defeats
    # offloading anyway.
    pipe = pipe.to("cuda")
    pipe.enable_attention_slicing()
    print("Model loaded successfully on GPU with optimizations!")
else:
    print("GPU not available, using CPU")
|
|
|
@spaces.GPU(duration=40)
def generate_pattern(prompt, width=1024, height=1024, num_inference_steps=50, guidance_scale=7.5, seed=None):
    """Generate one tileable pattern image with the module-level ``pipe``.

    ``spaces.GPU(duration=40)`` requests a GPU for the duration of this call
    on Hugging Face Spaces. Circular padding is reset to stock convolutions at
    the start of every call; ``diffusion_callback`` re-enables it partway
    through denoising.

    Args:
        prompt: Text prompt describing the pattern.
        width: Output width in pixels (cast to int).
        height: Output height in pixels (cast to int).
        num_inference_steps: Number of denoising steps (cast to int).
        guidance_scale: Classifier-free guidance scale.
        seed: Optional integer seed; ``None`` or ``""`` means no fixed seed.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If anything during generation raises. The original
            exception is chained as ``__cause__``.
    """
    try:
        if torch.cuda.is_available():
            pipe.to("cuda")

        # No seed -> let the pipeline draw its own noise.
        generator = None
        if seed is not None and seed != "":
            generator = torch.Generator(device=pipe.device).manual_seed(int(seed))

        # The previous request may have left the UNet/VAE in circular-padding
        # mode (diffusion_callback enables it), so start from stock convs.
        disable_seamless(pipe.unet)
        disable_seamless(pipe.vae)

        autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.autocast(autocast_device):
            result = pipe(
                prompt=prompt,
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                generator=generator,
                callback_on_step_end=diffusion_callback,
            )
        return result.images[0]

    except Exception as e:
        print(f"Error during generation: {str(e)}")
        # Previously this returned None, leaving the UI with a silently empty
        # image. gr.Error shows the failure to the user instead.
        raise gr.Error(f"Generation failed: {e}") from e

    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
|
|
|
def create_interface():
    """Build the Gradio UI for the seamless pattern generator.

    Left column: prompt, width/height sliders, steps/guidance sliders, an
    optional seed and a generate button. Right column: the generated image.
    Below them is a list of clickable example prompts. The button calls
    ``generate_pattern`` with ``[prompt, width, height, steps,
    guidance_scale, seed]`` in that order.

    The emoji in the title and on the button were previously stored as
    mojibake ("π¨", i.e. the UTF-8 bytes of 🎨 decoded with a Greek code
    page). The example heading's garbled glyph could not be recovered and
    was removed.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    default_prompt = (
        "Vibrant watercolor floral pattern with pink, purple, and blue flowers "
        "against a white background."
    )

    with gr.Blocks(title="Pattern Diffusion - Seamless Pattern Generator") as demo:
        gr.Markdown("""
        # 🎨 Pattern Diffusion - Seamless Pattern Generator

        **Model:** [Arrexel/pattern-diffusion](https://huggingface.co/Arrexel/pattern-diffusion)

        This model specializes in generating patterns that can be repeated without visible seams,
        ideal for prints, wallpapers, textiles, and surfaces.

        **Strengths:**
        - Excellent for floral and abstract patterns
        - Understands foreground and background colors well
        - Fast and efficient on VRAM

        **Limitations:**
        - Does not generate coherent text
        - Difficulty with anatomy of living creatures
        - Inconsistent geometry in simple geometric patterns
        """)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder=default_prompt,
                    lines=3,
                    value=default_prompt,
                )

                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1024, step=256, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1024, step=256, value=1024)

                with gr.Row():
                    steps = gr.Slider(label="Inference Steps", minimum=20, maximum=100, step=5, value=50)
                    guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.5, value=7.5
                    )

                # precision=0 asks Gradio for an integer seed.
                seed = gr.Number(label="Seed (optional, leave empty for random)", precision=0)

                generate_btn = gr.Button("🎨 Generate Pattern", variant="primary", size="lg")

            with gr.Column():
                output_image = gr.Image(label="Generated Pattern", type="pil", height=400)

        gr.Markdown("## Example Prompts")

        examples = [
            [default_prompt],
            ["Abstract geometric pattern with gold and navy blue triangles on cream background"],
            ["Delicate cherry blossom pattern with soft pink petals on light gray background"],
            ["Art deco pattern with emerald green and gold lines on black background"],
            ["Tropical leaves pattern with various shades of green on white background"],
            ["Vintage damask pattern in burgundy and cream colors"],
            ["Modern minimalist dots pattern in pastel colors"],
            ["Mandala-inspired pattern with intricate details in blue and white"],
        ]

        gr.Examples(examples=examples, inputs=[prompt], label="Click an example to use")

        generate_btn.click(
            fn=generate_pattern,
            inputs=[prompt, width, height, steps, guidance_scale, seed],
            outputs=[output_image],
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Build the UI, bound the request queue to 20 pending events, then serve.
    demo = create_interface()
    demo.queue(max_size=20)
    demo.launch()