Update app.py
app.py
CHANGED
@@ -14,8 +14,183 @@ NUM_CLASSES = 2
 # Define the device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-#
-
+# Define the SinusoidalPositionEmbeddings class
+class SinusoidalPositionEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        self.register_buffer('embeddings', self._precompute_embeddings(dim))
+
+    def _precompute_embeddings(self, dim):
+        half_dim = dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim) * -emb)
+        return emb
+
+    def forward(self, time):
+        device = time.device
+        embeddings = self.embeddings.to(device)
+        embeddings = time[:, None] * embeddings[None, :]
+        output = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
+        return output
+
+# Define the UNet class
+class UNet(nn.Module):
+    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.label_embedding = nn.Embedding(num_classes, time_dim)
+
+        self.time_mlp = nn.Sequential(
+            SinusoidalPositionEmbeddings(time_dim),
+            nn.Linear(time_dim, time_dim),
+            nn.ReLU(),
+            nn.Linear(time_dim, time_dim)
+        )
+
+        self.inc = self.double_conv(in_channels, 64)
+        self.down1 = self.down(64 + time_dim * 2, 128)
+        self.down2 = self.down(128 + time_dim * 2, 256)
+        self.down3 = self.down(256 + time_dim * 2, 512)
+
+        self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
+
+        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
+        self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
+
+        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+        self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)
+
+        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+        self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)
+
+        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)
+
+    def double_conv(self, in_channels, out_channels):
+        return nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True)
+        )
+
+    def down(self, in_channels, out_channels):
+        return nn.Sequential(
+            nn.MaxPool2d(2),
+            self.double_conv(in_channels, out_channels)
+        )
+
+    def forward(self, x, labels, time):
+        label_indices = torch.argmax(labels, dim=1)
+        label_emb = self.label_embedding(label_indices)
+        t_emb = self.time_mlp(time)
+
+        combined_emb = torch.cat([t_emb, label_emb], dim=1)
+        combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)
+
+        x1 = self.inc(x)
+        x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)
+
+        x2 = self.down1(x1_cat)
+        x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)
+
+        x3 = self.down2(x2_cat)
+        x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)
+
+        x4 = self.down3(x3_cat)
+        x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)
+
+        x5 = self.bottleneck(x4_cat)
+
+        x = self.up1(x5)
+        x = torch.cat([x, x3], dim=1)
+        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+        x = self.upconv1(x)
+
+        x = self.up2(x)
+        x = torch.cat([x, x2], dim=1)
+        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+        x = self.upconv2(x)
+
+        x = self.up3(x)
+        x = torch.cat([x, x1], dim=1)
+        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
+        x = self.upconv3(x)
+
+        output = self.outc(x)
+        return output
+
+# Define the DiffusionModel class
+class DiffusionModel(nn.Module):
+    def __init__(self, model, timesteps=500, time_dim=256):
+        super().__init__()
+        self.model = model
+        self.timesteps = timesteps
+        self.time_dim = time_dim
+
+        self.betas = self.linear_schedule(timesteps)
+        self.alphas = 1. - self.betas
+        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
+
+    def linear_schedule(self, timesteps):
+        scale = 1000 / timesteps
+        beta_start = scale * 0.0001
+        beta_end = scale * 0.02
+        return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+
+    def forward_diffusion(self, x_0, t, noise):
+        x_0 = x_0.float()
+        noise = noise.float()
+        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
+        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
+        return x_t
+
+    def forward(self, x_0, labels):
+        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
+        noise = torch.randn_like(x_0)
+        x_t = self.forward_diffusion(x_0, t, noise)
+        predicted_noise = self.model(x_t, labels, t.float())
+        return predicted_noise, noise, t
+
+@torch.no_grad()
+def sample(model, num_images, timesteps, img_size, num_classes, labels, device):
+    x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
+
+    if labels.ndim == 1:
+        labels_one_hot = torch.zeros(num_images, num_classes).to(device)
+        labels_one_hot[torch.arange(num_images), labels] = 1
+        labels = labels_one_hot
+    else:
+        labels = labels.to(device)
+
+    for t in reversed(range(timesteps)):
+        t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
+
+        predicted_noise = model.model(x_t, labels, t_tensor)
+
+        beta_t = model.betas[t].float().to(device)    # cast the float64 schedule to float32
+        alpha_t = model.alphas[t].float().to(device)  # so x_t keeps the model's dtype
+        alpha_bar_t = model.alpha_bars[t].to(device)
+
+        mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
+        variance = beta_t
+
+        if t > 0:
+            noise = torch.randn_like(x_t)
+        else:
+            noise = torch.zeros_like(x_t)
+
+        x_t = mean + torch.sqrt(variance) * noise
+
+    x_0 = torch.clamp(x_t, -1., 1.)
+
+    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)  # undo ImageNet-style normalization
+    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+    x_0 = std * x_0 + mean
+    x_0 = torch.clamp(x_0, 0., 1.)
+
+    return x_0
 
 # Load the trained model with improved error handling
 def load_model(model_path, device):
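
For context, a minimal sketch of how the added classes fit together at training time. It assumes math, torch, and torch.nn (as nn) are imported near the top of app.py, above this hunk, and the batch size, image size, and optimizer below are illustrative placeholders rather than part of the diff. DiffusionModel.forward draws a random timestep per image, noises the batch with forward_diffusion, and returns the UNet's noise prediction, so the training objective is plain MSE between predicted and true noise:

# Hypothetical training step; shapes and hyperparameters are placeholders.
unet = UNet(in_channels=3, out_channels=3, num_classes=NUM_CLASSES).to(device)
diffusion = DiffusionModel(unet, timesteps=500).to(device)
optimizer = torch.optim.Adam(diffusion.parameters(), lr=1e-4)

images = torch.randn(8, 3, 64, 64, device=device)  # stand-in batch; H and W must be divisible by 8
label_idx = torch.randint(0, NUM_CLASSES, (8,), device=device)
labels = torch.eye(NUM_CLASSES, device=device)[label_idx]  # one-hot, as UNet.forward expects

predicted_noise, noise, t = diffusion(images, labels)
loss = torch.nn.functional.mse_loss(predicted_noise, noise)  # DDPM noise-prediction objective
optimizer.zero_grad()
loss.backward()
optimizer.step()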
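Inference then goes through the new sample helper. Continuing the sketch above with placeholder values (two images, img_size=64, class indices 0 and 1): sample starts from Gaussian noise, runs the reverse process for timesteps steps while conditioning on the labels, then undoes the ImageNet-style normalization, so the returned tensor is in [0, 1] and ready to display:

# Hypothetical inference call; num_images, img_size, and the labels are placeholders.
class_indices = torch.tensor([0, 1], device=device)  # integer labels; sample() one-hot encodes them
generated = sample(diffusion, num_images=2, timesteps=500,
                   img_size=64, num_classes=NUM_CLASSES,
                   labels=class_indices, device=device)
# generated has shape (2, 3, 64, 64) with values in [0, 1]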