artificialguybr's picture
Create app.py
5e508ca verified
raw
history blame
7.87 kB
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Conv2d
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from typing import Optional
from diffusers import StableDiffusionPipeline, DDPMScheduler
import diffusers
from PIL import Image
import gradio as gr
import spaces
import gc
def asymmetricConv2DConvForward_circular(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
    """Run a 2-D convolution with circular (wrap-around) padding on both spatial axes.

    Written to be bound onto an ``nn.Conv2d`` instance in place of its
    ``_conv_forward`` method, so ``self`` is that convolution module. The module's
    own padding amounts are reused, but the padded values are wrapped in from the
    opposite edge instead of being zero-filled.

    The padding tuples are plain locals: nothing is stored on the module, so a
    forward call has no side effects on ``self``.

    Args:
        self: the ``nn.Conv2d`` whose padding, stride, dilation and groups are used.
        input: tensor to convolve (any layout accepted by ``F.conv2d``).
        weight: convolution kernel.
        bias: optional bias, forwarded to ``F.conv2d``.

    Returns:
        The convolution output.
    """
    # For Conv2d this holds four entries in F.pad order (last dimension first):
    # (width_left, width_right, height_top, height_bottom).
    pad_left, pad_right, pad_top, pad_bottom = self._reversed_padding_repeated_twice
    # Wrap the width axis first, then the height axis. The second pass sees the
    # already-widened tensor, so the corners wrap too; the result equals one
    # circular pad over both axes.
    padded = F.pad(input, (pad_left, pad_right, 0, 0), mode="circular")
    padded = F.pad(padded, (0, 0, pad_top, pad_bottom), mode="circular")
    # Padding has already been applied explicitly, so the convolution itself uses none.
    return F.conv2d(padded, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
def make_seamless(model):
    """Patch every ``torch.nn.Conv2d`` inside *model* to convolve with circular padding.

    Each matching module gets ``asymmetricConv2DConvForward_circular`` bound as its
    ``_conv_forward``. Later forward passes therefore wrap values around the edges
    instead of zero-padding them. Modules are modified in place; nothing is returned.
    """
    convolutions = (layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d))
    for conv in convolutions:
        # diffusers' LoRACompatibleConv without an attached LoRA layer receives a stub
        # layer that returns 0. NOTE(review): presumably this routes its forward()
        # through the patched _conv_forward — confirm against the installed diffusers.
        lora_free = (
            isinstance(conv, diffusers.models.lora.LoRACompatibleConv)
            and conv.lora_layer is None
        )
        if lora_free:
            conv.lora_layer = lambda *x: 0
        circular_forward = asymmetricConv2DConvForward_circular.__get__(conv, Conv2d)
        conv._conv_forward = circular_forward
def disable_seamless(model):
    """Give every ``torch.nn.Conv2d`` in *model* the stock ``nn.Conv2d._conv_forward`` again.

    The original method is re-bound onto each convolution as an instance attribute.
    This replaces any circular-padding forward installed earlier, so the module's
    configured padding mode applies once more. Modules are modified in place;
    nothing is returned.
    """
    stock_conv_forward = nn.Conv2d._conv_forward
    for conv in filter(lambda layer: isinstance(layer, torch.nn.Conv2d), model.modules()):
        # LoRA-less diffusers LoRACompatibleConv modules get a stub layer returning 0.
        # NOTE(review): presumably its output is added to the conv result and so
        # contributes nothing — confirm against the installed diffusers.
        if isinstance(conv, diffusers.models.lora.LoRACompatibleConv) and conv.lora_layer is None:
            conv.lora_layer = lambda *x: 0
        conv._conv_forward = stock_conv_forward.__get__(conv, Conv2d)
def diffusion_callback(pipe, step_index, timestep, callback_kwargs):
    """Per-step ``callback_on_step_end`` hook that hides tiling seams.

    While ``step_index`` is below 80% of ``pipe.num_timesteps``, the latents are
    rolled by half their spatial size along dims 2 and 3. This presumably moves the
    wrap-around edges into the interior so the model keeps denoising across them.
    At exactly the 80% step, every Conv2d in ``pipe.unet`` and ``pipe.vae`` is
    switched to circular padding via ``make_seamless`` for the rest of the run.

    Args:
        pipe: the running pipeline; ``num_timesteps``, ``unet`` and ``vae`` are read.
        step_index: index of the step that just finished.
        timestep: current scheduler timestep (unused).
        callback_kwargs: dict holding at least ``"latents"`` (a 4-D tensor here,
            since dims 2 and 3 are rolled).

    Returns:
        ``callback_kwargs``, with ``"latents"`` replaced by the rolled tensor while
        rolling is active.
    """
    switch_step = int(pipe.num_timesteps * 0.8)
    if step_index == switch_step:
        make_seamless(pipe.unet)
        make_seamless(pipe.vae)
    if step_index < switch_step:
        latents = callback_kwargs["latents"]
        # Shift by half of each latent dimension. A fixed shift of 64 is only right for
        # 128-wide latents; for a 64-wide latent (512 px, assuming the usual 8x VAE
        # downsampling) rolling by 64 is the identity, so the seam never moved.
        shift = (latents.shape[2] // 2, latents.shape[3] // 2)
        callback_kwargs["latents"] = torch.roll(latents, shifts=shift, dims=(2, 3))
    return callback_kwargs
print("Loading Pattern Diffusion model...")
# Half precision is meant for GPU execution; CPU-only hosts load full-precision weights.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    "Arrexel/pattern-diffusion",
    torch_dtype=_model_dtype,
    safety_checker=None,
    requires_safety_checker=False,
)
# Swap in a DDPM scheduler built from the checkpoint's scheduler config.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
if torch.cuda.is_available():
    # Keep the whole pipeline resident on the GPU. enable_model_cpu_offload() is
    # deliberately not used: offloading parks components on the CPU, and
    # generate_pattern moves the pipeline to "cuda" on every request, which would
    # undo the offload anyway (and a .to("cuda") right before it is wasted work).
    pipe = pipe.to("cuda")
    pipe.enable_attention_slicing()
    print("Model loaded successfully on GPU with optimizations!")
else:
    print("GPU not available, using CPU")
@spaces.GPU(duration=40)
def generate_pattern(prompt, width=1024, height=1024, num_inference_steps=50, guidance_scale=7.5, seed=None):
    """Generate one seamless pattern image from *prompt* with the global ``pipe``.

    Args:
        prompt: text prompt passed to the pipeline.
        width: output width in pixels (coerced with ``int``).
        height: output height in pixels (coerced with ``int``).
        num_inference_steps: number of denoising steps (coerced with ``int``).
        guidance_scale: guidance scale forwarded to the pipeline unchanged.
        seed: optional seed; ``None`` or ``""`` leaves the generator unset (random).

    Returns:
        The first entry of the pipeline output's ``.images``.

    Raises:
        gr.Error: when generation fails. The error is also printed to the log;
            raising ``gr.Error`` makes Gradio show the message instead of an empty
            output.
    """
    try:
        if torch.cuda.is_available():
            pipe.to("cuda")
        if seed is not None and seed != "":
            generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
        else:
            generator = None
        # Every request starts from ordinary convolutions; diffusion_callback switches
        # the UNet and VAE to circular padding late in the denoising loop.
        disable_seamless(pipe.unet)
        disable_seamless(pipe.vae)
        with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
            output = pipe(
                prompt=prompt,
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                generator=generator,
                callback_on_step_end=diffusion_callback,
            ).images[0]
        return output
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        # Surface the failure in the UI rather than silently returning nothing.
        raise gr.Error(f"Generation failed: {e}") from e
    finally:
        # Release cached GPU memory and collect garbage between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
def create_interface():
    """Build the Gradio Blocks UI for the seamless pattern generator.

    Layout: a header Markdown block, then a row with an input column and an output
    column. The input column holds the prompt, width/height sliders, steps/guidance
    sliders, an optional seed and the generate button. The output column holds the
    image. Clickable example prompts follow. The button calls ``generate_pattern``
    with the input values and shows the result in the image component.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Pattern Diffusion - Seamless Pattern Generator") as demo:
        gr.Markdown("""
# 🎨 Pattern Diffusion - Seamless Pattern Generator
**Model:** [Arrexel/pattern-diffusion](https://huggingface.co/Arrexel/pattern-diffusion)
This model specializes in generating patterns that can be repeated without visible seams,
ideal for prints, wallpapers, textiles, and surfaces.
**Strengths:**
- Excellent for floral and abstract patterns
- Understands foreground and background colors well
- Fast and efficient on VRAM
**Limitations:**
- Does not generate coherent text
- Difficulty with anatomy of living creatures
- Inconsistent geometry in simple geometric patterns
""")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background.",
                    lines=3,
                    value="Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background."
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=1024,
                        step=256,
                        value=1024
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=1024,
                        step=256,
                        value=1024
                    )
                with gr.Row():
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=20,
                        maximum=100,
                        step=5,
                        value=50
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=20.0,
                        step=0.5,
                        value=7.5
                    )
                # precision=0 rounds the seed to an integer; an empty field yields None.
                seed = gr.Number(
                    label="Seed (optional, leave empty for random)",
                    precision=0
                )
                generate_btn = gr.Button("🎨 Generate Pattern", variant="primary", size="lg")
            with gr.Column():
                output_image = gr.Image(
                    label="Generated Pattern",
                    type="pil",
                    height=400
                )
        # The heading emoji was mis-encoded ("πŸ“‹", UTF-8 bytes decoded as a legacy
        # code page); restored to the intended clipboard emoji.
        gr.Markdown("## 📋 Example Prompts")
        examples = [
            ["Vibrant watercolor floral pattern with pink, purple, and blue flowers against a white background."],
            ["Abstract geometric pattern with gold and navy blue triangles on cream background"],
            ["Delicate cherry blossom pattern with soft pink petals on light gray background"],
            ["Art deco pattern with emerald green and gold lines on black background"],
            ["Tropical leaves pattern with various shades of green on white background"],
            ["Vintage damask pattern in burgundy and cream colors"],
            ["Modern minimalist dots pattern in pastel colors"],
            ["Mandala-inspired pattern with intricate details in blue and white"]
        ]
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            label="Click an example to use"
        )
        generate_btn.click(
            fn=generate_pattern,
            inputs=[prompt, width, height, steps, guidance_scale, seed],
            outputs=[output_image]
        )
    return demo
if __name__ == "__main__":
    # Build the UI, enable request queueing (queue capped at 20), then serve it.
    demo = create_interface()
    demo.queue(max_size=20)
    demo.launch()