Vedansh-7 committed
Commit a4cf41d · 1 Parent(s): d5e5728

Update app.py

Files changed (1): app.py (+177 -2)
app.py CHANGED
@@ -14,8 +14,183 @@ NUM_CLASSES = 2
 # Define the device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# SinusoidalPositionEmbeddings and UNet classes remain the same as your original code
-# DiffusionModel class remains the same as your original code
+# Define the SinusoidalPositionEmbeddings class
+class SinusoidalPositionEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        self.register_buffer('embeddings', self._precompute_embeddings(dim))
+
+    def _precompute_embeddings(self, dim):
+        half_dim = dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim) * -emb)
+        return emb
+
+    def forward(self, time):
+        device = time.device
+        embeddings = self.embeddings.to(device)
+        embeddings = time[:, None] * embeddings[None, :]
+        output = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
+        return output
+
+# Define the UNet class
+class UNet(nn.Module):
+    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.label_embedding = nn.Embedding(num_classes, time_dim)
+
+        self.time_mlp = nn.Sequential(
+            SinusoidalPositionEmbeddings(time_dim),
+            nn.Linear(time_dim, time_dim),
+            nn.ReLU(),
+            nn.Linear(time_dim, time_dim)
+        )
+
+        self.inc = self.double_conv(in_channels, 64)
+        self.down1 = self.down(64 + time_dim * 2, 128)
+        self.down2 = self.down(128 + time_dim * 2, 256)
+        self.down3 = self.down(256 + time_dim * 2, 512)
+
+        self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
+
+        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
+        self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
+
+        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+        self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)
+
+        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+        self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)
+
+        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)
+
+    def double_conv(self, in_channels, out_channels):
+        return nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True)
+        )
+
+    def down(self, in_channels, out_channels):
+        return nn.Sequential(
+            nn.MaxPool2d(2),
+            self.double_conv(in_channels, out_channels)
+        )
+
+    def forward(self, x, labels, time):
+        label_indices = torch.argmax(labels, dim=1)
+        label_emb = self.label_embedding(label_indices)
+        t_emb = self.time_mlp(time)
+
+        combined_emb = torch.cat([t_emb, label_emb], dim=1)
+        combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)
+
+        x1 = self.inc(x)
+        x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)
+
+        x2 = self.down1(x1_cat)
+        x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)
+
+        x3 = self.down2(x2_cat)
+        x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)
+
+        x4 = self.down3(x3_cat)
+        x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)
+
+        x5 = self.bottleneck(x4_cat)
+
+        x = self.up1(x5)
+        x = torch.cat([x, x3], dim=1)
+        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+        x = self.upconv1(x)
+
+        x = self.up2(x)
+        x = torch.cat([x, x2], dim=1)
+        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+        x = self.upconv2(x)
+
+        x = self.up3(x)
+        x = torch.cat([x, x1], dim=1)
+        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+        x = self.upconv3(x)
+
+        output = self.outc(x)
+        return output
+
+# Define the DiffusionModel class
+class DiffusionModel(nn.Module):
+    def __init__(self, model, timesteps=500, time_dim=256):
+        super().__init__()
+        self.model = model
+        self.timesteps = timesteps
+        self.time_dim = time_dim
+
+        self.betas = self.linear_schedule(timesteps)
+        self.alphas = 1. - self.betas
+        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
+
+    def linear_schedule(self, timesteps):
+        scale = 1000 / timesteps
+        beta_start = scale * 0.0001
+        beta_end = scale * 0.02
+        return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+
+    def forward_diffusion(self, x_0, t, noise):
+        x_0 = x_0.float()
+        noise = noise.float()
+        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
+        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
+        return x_t
+
+    def forward(self, x_0, labels):
+        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
+        noise = torch.randn_like(x_0)
+        x_t = self.forward_diffusion(x_0, t, noise)
+        predicted_noise = self.model(x_t, labels, t.float())
+        return predicted_noise, noise, t
+
+    @torch.no_grad()
+    def sample(model, num_images, timesteps, img_size, num_classes, labels, device):
+        x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
+
+        if labels.ndim == 1:
+            labels_one_hot = torch.zeros(num_images, num_classes).to(device)
+            labels_one_hot[torch.arange(num_images), labels] = 1
+            labels = labels_one_hot
+        else:
+            labels = labels.to(device)
+
+        for t in reversed(range(timesteps)):
+            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
+
+            predicted_noise = model.model(x_t, labels, t_tensor)
+
+            beta_t = model.betas[t].to(device)
+            alpha_t = model.alphas[t].to(device)
+            alpha_bar_t = model.alpha_bars[t].to(device)
+
+            mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
+            variance = beta_t
+
+            if t > 0:
+                noise = torch.randn_like(x_t)
+            else:
+                noise = torch.zeros_like(x_t)
+
+            x_t = mean + torch.sqrt(variance) * noise
+
+        x_0 = torch.clamp(x_t, -1., 1.)
+
+        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+        x_0 = std * x_0 + mean
+        x_0 = torch.clamp(x_0, 0., 1.)
+
+        return x_0
 
 # Load the trained model with improved error handling
 def load_model(model_path, device):
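
For reference, a minimal usage sketch of the classes added in this commit (not part of the diff). The `img_size` value and the class label are illustrative assumptions; use the resolution the model was trained on. Note that `sample` names its first parameter `model` rather than `self`, but since it is an instance method the parameter binds to the `DiffusionModel` instance, so it is called like an ordinary method.

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the conditional UNet and wrap it in the diffusion model
    # (these values match the defaults defined in this commit).
    unet = UNet(in_channels=3, out_channels=3, num_classes=2, time_dim=256).to(device)
    diffusion = DiffusionModel(unet, timesteps=500).to(device)

    # Generate 4 images of class 1. `labels` is 1-D here, so `sample`
    # one-hot encodes it internally before conditioning the UNet.
    # img_size=64 is an assumption, not something this diff specifies.
    labels = torch.tensor([1, 1, 1, 1], device=device)
    images = diffusion.sample(num_images=4, timesteps=500, img_size=64,
                              num_classes=2, labels=labels, device=device)
    # `images` is a (4, 3, 64, 64) tensor de-normalized to [0, 1].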