Spaces: Runtime error

Update app.py

app.py CHANGED
@@ -1,97 +1,121 @@
 import os
 import torch
+import gc
 import gradio as gr
-import numpy as np
-import matplotlib.pyplot as plt
-from PIL import Image
 import torch.nn.functional as F
-
+import torchvision.transforms as T
+from PIL import Image
 from diffusers import StableDiffusionPipeline
-
-
+from huggingface_hub import login
+
+# Initialize model once and reuse
+def init_pipeline():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if device == "cuda" else torch.float32
+
+    # Authenticate with HF token
+    login(token=os.environ.get("HF_AUTH_TOKEN"))
+
+    pipeline = StableDiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4",
+        torch_dtype=torch_dtype,
+        use_auth_token=True
+    ).to(device)
+
+    # Optimize for performance
+    if device == "cuda":
+        pipeline.enable_attention_slicing()
+        pipeline.enable_xformers_memory_efficient_attention()
+
+    return pipeline, device
+
+# Initialize pipeline at startup
+pipe, device = init_pipeline()
+
+# Define your original loss functions
 def edge_loss(image_tensor):
     grayscale = image_tensor.mean(dim=0, keepdim=True)
+    sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]],
+                           device=device).float().unsqueeze(0).unsqueeze(0)
     sobel_y = sobel_x.transpose(2, 3)
     gx = F.conv2d(grayscale, sobel_x, padding=1)
     gy = F.conv2d(grayscale, sobel_y, padding=1)
-    return -torch.mean(torch.sqrt(gx
+    return -torch.mean(torch.sqrt(gx**2 + gy**2 + 1e-6))
 
 def texture_loss(image_tensor):
-    return F.mse_loss(image_tensor, torch.rand_like(image_tensor
+    return F.mse_loss(image_tensor, torch.rand_like(image_tensor))
 
 def entropy_loss(image_tensor):
-    hist = torch.histc(image_tensor, bins=256, min=0, max=
+    hist = torch.histc(image_tensor, bins=256, min=0, max=1)
     hist = hist / hist.sum()
     return -torch.sum(hist * torch.log(hist + 1e-7))
 
 def symmetry_loss(image_tensor):
     width = image_tensor.shape[-1]
-    return F.mse_loss(
+    left = image_tensor[..., :width//2]
+    right = torch.flip(image_tensor[..., width//2:], dims=[-1])
+    return F.mse_loss(left, right)
 
 def contrast_loss(image_tensor):
-    min_val = image_tensor.min()
-    max_val = image_tensor.max()
-    return -torch.mean((image_tensor - min_val) / (max_val - min_val + 1e-7))
-
-# Setup Stable Diffusion Pipeline
-device = "cuda" if torch.cuda.is_available() else "cpu"
-pipe = StableDiffusionPipeline.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
-    torch_dtype=torch.float16
-).to(device)
+    return -torch.std(image_tensor)
 
-#
-transform = transforms.ToTensor()
-
-# Loss functions dictionary
-losses = {
-    "edge": edge_loss,
-    "texture": texture_loss,
-    "entropy": entropy_loss,
-    "symmetry": symmetry_loss,
-    "contrast": contrast_loss
-}
-
-# Define function to generate images for a given seed
+# Modified generation function
 def generate_images(seed):
+    # Clear memory before generation
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    # Generate base image
+    generator = torch.Generator(device).manual_seed(int(seed))
+    image = pipe(
+        "A futuristic city skyline at sunset",
+        generator=generator,
+        num_inference_steps=50,
+        guidance_scale=7.5
+    ).images[0]
+
+    # Process image
+    transform = T.Compose([
+        T.ToTensor(),
+        T.Normalize([0.5], [0.5])
+    ])
+    image_tensor = transform(image).unsqueeze(0).to(device)
+
+    # Calculate losses
+    losses = {
+        "Edge": edge_loss(image_tensor),
+        "Texture": texture_loss(image_tensor),
+        "Entropy": entropy_loss(image_tensor),
+        "Symmetry": symmetry_loss(image_tensor),
+        "Contrast": contrast_loss(image_tensor)
+    }
+
+    # Create thumbnail grid
+    thumb = image.copy()
+    thumb.thumbnail((256, 256))
+    grid = Image.new('RGB', (256 * 2, 256))
+    grid.paste(thumb, (0, 0))
+
+    # Add visualization placeholder
+    vis = Image.new('RGB', (256, 256), color='white')
+    grid.paste(vis, (256, 0))
+
+    return grid, "\n".join([f"{k}: {v.item():.4f}" for k,v in losses.items()])
+
+# Gradio interface
+def create_interface():
+    return gr.Interface(
+        fn=generate_images,
+        inputs=gr.Number(label="Seed", value=42),
+        outputs=[
+            gr.Image(label="Generated Image & Visualizations"),
+            gr.Textbox(label="Loss Values")
+        ],
+        title="Stable Diffusion Loss Analysis",
+        description="Generate images and visualize different loss metrics"
+    )
+
+if __name__ == "__main__":
+    interface = create_interface()
+    interface.launch(server_port=7860, share=False)
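
For reference, a minimal smoke test of the updated script outside the Gradio UI. This is only a sketch under assumptions, not part of the commit: it presumes the file above is saved as app.py, that HF_AUTH_TOKEN is set in the environment, and that the model weights can be downloaded at import time; the test file name and its contents are illustrative.

# smoke_test.py (hypothetical, not in the repository)
# Importing app runs init_pipeline() once at module load; generate_images()
# then returns the 512x256 preview grid and a newline-separated loss report.
from app import generate_images

grid, report = generate_images(42)
grid.save("preview.png")
print(report)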