SlothLoader committed on
Commit
791fec0
·
verified ·
1 Parent(s): ccf6c03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -50
app.py CHANGED
@@ -1,69 +1,36 @@
1
  import torch
2
  import numpy as np
3
- from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
4
  from PIL import Image
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
 
8
- # 加载模型和处理器
9
  model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
10
- feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
11
 
12
  def visualize_mae(image):
13
- # 预处理图像
 
14
  inputs = feature_extractor(images=image, return_tensors="pt")
15
-
16
- # 模型推理(启用掩码)
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
-
20
- # 获取掩码和重建的像素
21
- mask = outputs.mask # [1, 196]
22
- reconstructed_pixels = outputs.logits # [1, 196, 768]
23
-
24
- # 将掩码应用到原始图像(模拟掩码效果)
25
- patch_size = model.config.patch_size
26
- image_np = np.array(image)
27
- h, w = image_np.shape[0], image_np.shape[1]
28
- num_patches_h = h // patch_size
29
- num_patches_w = w // patch_size
30
-
31
- # 创建一个掩码图像(黑色表示被掩码的部分)
32
- mask_image = np.zeros_like(image_np)
33
- mask = mask[0].reshape(num_patches_h, num_patches_w)
34
- for i in range(num_patches_h):
35
- for j in range(num_patches_w):
36
- if mask[i, j] == 1: # 被掩码的patch
37
- mask_image[
38
- i * patch_size : (i + 1) * patch_size,
39
- j * patch_size : (j + 1) * patch_size,
40
- ] = 0
41
- else:
42
- mask_image[
43
- i * patch_size : (i + 1) * patch_size,
44
- j * patch_size : (j + 1) * patch_size,
45
- ] = image_np[
46
- i * patch_size : (i + 1) * patch_size,
47
- j * patch_size : (j + 1) * patch_size,
48
- ]
49
-
50
- # 可视化结果(原始图像 + 掩码图像)
51
- fig, axes = plt.subplots(1, 2, figsize=(10, 5))
52
- axes[0].imshow(image_np)
53
- axes[0].set_title("Original Image")
54
- axes[0].axis("off")
55
- axes[1].imshow(mask_image)
56
- axes[1].set_title("Masked Image (MAE Input)")
57
- axes[1].axis("off")
58
- plt.tight_layout()
59
- return fig
60
 
61
  iface = gr.Interface(
62
  fn=visualize_mae,
63
  inputs=gr.Image(type="pil"),
64
  outputs="plot",
65
- title="ViT-MAE Masked Image Visualization",
66
- description="Upload an image to see how MAE masks patches.",
67
  )
68
-
69
  iface.launch()
 
1
  import torch
2
  import numpy as np
3
+ from transformers import ViTMAEForPreTraining, ViTImageProcessor
4
  from PIL import Image
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
 
 
8
  model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
9
+ feature_extractor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
10
 
11
  def visualize_mae(image):
12
+ # 调整图像大小并预处理
13
+ image = image.resize((224, 224))
14
  inputs = feature_extractor(images=image, return_tensors="pt")
15
+
16
+ # 模型推理
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
+
20
+ # 获取掩码(14x14)
21
+ mask = outputs.mask[0].reshape(14, 14).cpu().numpy()
22
+
23
+ # 可视化
24
+ plt.imshow(mask, cmap="gray")
25
+ plt.title("MAE Mask (14x14)")
26
+ plt.axis("off")
27
+ return plt.gcf()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  iface = gr.Interface(
30
  fn=visualize_mae,
31
  inputs=gr.Image(type="pil"),
32
  outputs="plot",
33
+ title="ViT-MAE Mask Visualization",
34
+ description="Upload an image to see the MAE masking pattern.",
35
  )
 
36
  iface.launch()