Shriti09 commited on
Commit
e27d04f
·
verified ·
1 Parent(s): 990e8d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -74
app.py CHANGED
@@ -1,97 +1,121 @@
1
  import os
2
  import torch
 
3
  import gradio as gr
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from PIL import Image
7
  import torch.nn.functional as F
8
- from torchvision import transforms
 
9
  from diffusers import StableDiffusionPipeline
10
-
11
- # Define Loss Functions (same as in your code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def edge_loss(image_tensor):
13
  grayscale = image_tensor.mean(dim=0, keepdim=True)
14
- grayscale = grayscale.unsqueeze(0)
15
- sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], device=image_tensor.device).float().unsqueeze(0).unsqueeze(0)
16
  sobel_y = sobel_x.transpose(2, 3)
17
  gx = F.conv2d(grayscale, sobel_x, padding=1)
18
  gy = F.conv2d(grayscale, sobel_y, padding=1)
19
- return -torch.mean(torch.sqrt(gx ** 2 + gy ** 2))
20
 
21
  def texture_loss(image_tensor):
22
- return F.mse_loss(image_tensor, torch.rand_like(image_tensor, device=image_tensor.device))
23
 
24
  def entropy_loss(image_tensor):
25
- hist = torch.histc(image_tensor, bins=256, min=0, max=255)
26
  hist = hist / hist.sum()
27
  return -torch.sum(hist * torch.log(hist + 1e-7))
28
 
29
  def symmetry_loss(image_tensor):
30
  width = image_tensor.shape[-1]
31
- left_half = image_tensor[:, :, :width // 2]
32
- right_half = torch.flip(image_tensor[:, :, width // 2:], dims=[-1])
33
- return F.mse_loss(left_half, right_half)
34
 
35
  def contrast_loss(image_tensor):
36
- min_val = image_tensor.min()
37
- max_val = image_tensor.max()
38
- return -torch.mean((image_tensor - min_val) / (max_val - min_val + 1e-7))
39
-
40
- # Setup Stable Diffusion Pipeline
41
- device = "cuda" if torch.cuda.is_available() else "cpu"
42
- pipe = StableDiffusionPipeline.from_pretrained(
43
- "CompVis/stable-diffusion-v1-4",
44
- torch_dtype=torch.float16
45
- ).to(device)
46
 
47
- # Image transform to tensor
48
- transform = transforms.ToTensor()
49
-
50
- # Loss functions dictionary
51
- losses = {
52
- "edge": edge_loss,
53
- "texture": texture_loss,
54
- "entropy": entropy_loss,
55
- "symmetry": symmetry_loss,
56
- "contrast": contrast_loss
57
- }
58
-
59
- # Define function to generate images for a given seed
60
  def generate_images(seed):
61
- generator = torch.Generator(device).manual_seed(seed)
62
- output_image = pipe("A futuristic city skyline at sunset", generator=generator).images[0]
63
-
64
- # Convert to tensor
65
- image_tensor = transform(output_image).to(device)
66
-
67
- loss_images = []
68
- loss_values = []
69
-
70
- # Compute losses and generate modified images
71
- for loss_name, loss_fn in losses.items():
72
- loss_value = loss_fn(image_tensor)
73
-
74
- # Resize to thumbnail size
75
- thumbnail_image = output_image.copy()
76
- thumbnail_image.thumbnail((128, 128))
77
-
78
- # Save loss image with thumbnail
79
- loss_images.append(thumbnail_image)
80
- loss_values.append(f"{loss_name}: {loss_value.item():.4f}")
81
-
82
- return loss_images, loss_values
83
-
84
- # Gradio Interface
85
- def gradio_interface(seed):
86
- loss_images, loss_values = generate_images(int(seed))
87
- return loss_images, loss_values
88
-
89
- # Set up Gradio UI
90
- interface = gr.Interface(
91
- fn=gradio_interface,
92
- inputs=gr.inputs.Textbox(label="Enter Seed"),
93
- outputs=[gr.outputs.Gallery(label="Loss Images"), gr.outputs.Textbox(label="Loss Values")]
94
- )
95
-
96
- # Launch the interface
97
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
+ import gc
4
  import gradio as gr
 
 
 
5
  import torch.nn.functional as F
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
  from diffusers import StableDiffusionPipeline
9
+ from huggingface_hub import login
10
+
11
+ # Initialize model once and reuse
12
+ def init_pipeline():
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
15
+
16
+ # Authenticate with HF token
17
+ login(token=os.environ.get("HF_AUTH_TOKEN"))
18
+
19
+ pipeline = StableDiffusionPipeline.from_pretrained(
20
+ "CompVis/stable-diffusion-v1-4",
21
+ torch_dtype=torch_dtype,
22
+ use_auth_token=True
23
+ ).to(device)
24
+
25
+ # Optimize for performance
26
+ if device == "cuda":
27
+ pipeline.enable_attention_slicing()
28
+ pipeline.enable_xformers_memory_efficient_attention()
29
+
30
+ return pipeline, device
31
+
32
+ # Initialize pipeline at startup
33
+ pipe, device = init_pipeline()
34
+
35
+ # Define your original loss functions
36
  def edge_loss(image_tensor):
37
  grayscale = image_tensor.mean(dim=0, keepdim=True)
38
+ sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]],
39
+ device=device).float().unsqueeze(0).unsqueeze(0)
40
  sobel_y = sobel_x.transpose(2, 3)
41
  gx = F.conv2d(grayscale, sobel_x, padding=1)
42
  gy = F.conv2d(grayscale, sobel_y, padding=1)
43
+ return -torch.mean(torch.sqrt(gx**2 + gy**2 + 1e-6))
44
 
45
  def texture_loss(image_tensor):
46
+ return F.mse_loss(image_tensor, torch.rand_like(image_tensor))
47
 
48
  def entropy_loss(image_tensor):
49
+ hist = torch.histc(image_tensor, bins=256, min=0, max=1)
50
  hist = hist / hist.sum()
51
  return -torch.sum(hist * torch.log(hist + 1e-7))
52
 
53
  def symmetry_loss(image_tensor):
54
  width = image_tensor.shape[-1]
55
+ left = image_tensor[..., :width//2]
56
+ right = torch.flip(image_tensor[..., width//2:], dims=[-1])
57
+ return F.mse_loss(left, right)
58
 
59
  def contrast_loss(image_tensor):
60
+ return -torch.std(image_tensor)
 
 
 
 
 
 
 
 
 
61
 
62
+ # Modified generation function
 
 
 
 
 
 
 
 
 
 
 
 
63
  def generate_images(seed):
64
+ # Clear memory before generation
65
+ if torch.cuda.is_available():
66
+ torch.cuda.empty_cache()
67
+ gc.collect()
68
+
69
+ # Generate base image
70
+ generator = torch.Generator(device).manual_seed(int(seed))
71
+ image = pipe(
72
+ "A futuristic city skyline at sunset",
73
+ generator=generator,
74
+ num_inference_steps=50,
75
+ guidance_scale=7.5
76
+ ).images[0]
77
+
78
+ # Process image
79
+ transform = T.Compose([
80
+ T.ToTensor(),
81
+ T.Normalize([0.5], [0.5])
82
+ ])
83
+ image_tensor = transform(image).unsqueeze(0).to(device)
84
+
85
+ # Calculate losses
86
+ losses = {
87
+ "Edge": edge_loss(image_tensor),
88
+ "Texture": texture_loss(image_tensor),
89
+ "Entropy": entropy_loss(image_tensor),
90
+ "Symmetry": symmetry_loss(image_tensor),
91
+ "Contrast": contrast_loss(image_tensor)
92
+ }
93
+
94
+ # Create thumbnail grid
95
+ thumb = image.copy()
96
+ thumb.thumbnail((256, 256))
97
+ grid = Image.new('RGB', (256 * 2, 256))
98
+ grid.paste(thumb, (0, 0))
99
+
100
+ # Add visualization placeholder
101
+ vis = Image.new('RGB', (256, 256), color='white')
102
+ grid.paste(vis, (256, 0))
103
+
104
+ return grid, "\n".join([f"{k}: {v.item():.4f}" for k,v in losses.items()])
105
+
106
+ # Gradio interface
107
+ def create_interface():
108
+ return gr.Interface(
109
+ fn=generate_images,
110
+ inputs=gr.Number(label="Seed", value=42),
111
+ outputs=[
112
+ gr.Image(label="Generated Image & Visualizations"),
113
+ gr.Textbox(label="Loss Values")
114
+ ],
115
+ title="Stable Diffusion Loss Analysis",
116
+ description="Generate images and visualize different loss metrics"
117
+ )
118
+
119
+ if __name__ == "__main__":
120
+ interface = create_interface()
121
+ interface.launch(server_port=7860, share=False)