SlothLoader committed on
Commit
ff18bca
·
verified ·
1 Parent(s): cc86bec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -13
app.py CHANGED
@@ -1,34 +1,69 @@
1
  import torch
 
2
  from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
3
  from PIL import Image
4
  import gradio as gr
 
5
 
6
  # 加载模型和处理器
7
  model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
8
  feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
9
 
10
- def predict(image):
11
  # 预处理图像
12
  inputs = feature_extractor(images=image, return_tensors="pt")
13
 
14
- # 模型推理
15
  with torch.no_grad():
16
  outputs = model(**inputs)
17
 
18
- # 获取重建的图像(MAE 的输出是像素值)
19
- reconstructed_pixel_values = outputs.logits # [1, 196, 768]
20
-
21
- # 这里需要将输出转换为可显示的图像(示例简化,实际需调整)
22
- # 注意:MAE 的输出需要后处理才能可视化,这里仅展示原始输出
23
- return f"Output shape: {reconstructed_pixel_values.shape}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # 创建 Gradio 界面
26
  iface = gr.Interface(
27
- fn=predict,
28
  inputs=gr.Image(type="pil"),
29
- outputs="text",
30
- title="MAE (Masked Autoencoder) Demo",
31
- description="Upload an image to see ViT-MAE model output.",
32
  )
33
 
34
  iface.launch()
 
1
  import torch
2
+ import numpy as np
3
  from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
4
  from PIL import Image
5
  import gradio as gr
6
+ import matplotlib.pyplot as plt
7
 
8
  # 加载模型和处理器
9
  model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
10
  feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
11
 
12
+ def visualize_mae(image):
13
  # 预处理图像
14
  inputs = feature_extractor(images=image, return_tensors="pt")
15
 
16
+ # 模型推理(启用掩码)
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
 
20
+ # 获取掩码和重建的像素
21
+ mask = outputs.mask # [1, 196]
22
+ reconstructed_pixels = outputs.logits # [1, 196, 768]
23
+
24
+ # 将掩码应用到原始图像(模拟掩码效果)
25
+ patch_size = model.config.patch_size
26
+ image_np = np.array(image)
27
+ h, w = image_np.shape[0], image_np.shape[1]
28
+ num_patches_h = h // patch_size
29
+ num_patches_w = w // patch_size
30
+
31
+ # 创建一个掩码图像(黑色表示被掩码的部分)
32
+ mask_image = np.zeros_like(image_np)
33
+ mask = mask[0].reshape(num_patches_h, num_patches_w)
34
+ for i in range(num_patches_h):
35
+ for j in range(num_patches_w):
36
+ if mask[i, j] == 1: # 被掩码的patch
37
+ mask_image[
38
+ i * patch_size : (i + 1) * patch_size,
39
+ j * patch_size : (j + 1) * patch_size,
40
+ ] = 0
41
+ else:
42
+ mask_image[
43
+ i * patch_size : (i + 1) * patch_size,
44
+ j * patch_size : (j + 1) * patch_size,
45
+ ] = image_np[
46
+ i * patch_size : (i + 1) * patch_size,
47
+ j * patch_size : (j + 1) * patch_size,
48
+ ]
49
+
50
+ # 可视化结果(原始图像 + 掩码图像)
51
+ fig, axes = plt.subplots(1, 2, figsize=(10, 5))
52
+ axes[0].imshow(image_np)
53
+ axes[0].set_title("Original Image")
54
+ axes[0].axis("off")
55
+ axes[1].imshow(mask_image)
56
+ axes[1].set_title("Masked Image (MAE Input)")
57
+ axes[1].axis("off")
58
+ plt.tight_layout()
59
+ return fig
60
 
 
61
  iface = gr.Interface(
62
+ fn=visualize_mae,
63
  inputs=gr.Image(type="pil"),
64
+ outputs="plot",
65
+ title="ViT-MAE Masked Image Visualization",
66
+ description="Upload an image to see how MAE masks patches.",
67
  )
68
 
69
  iface.launch()