Spaces:
Sleeping
Sleeping
File size: 2,315 Bytes
69bd721 ff18bca 69bd721 ff18bca 69bd721 ff18bca 69bd721 ff18bca 69bd721 ff18bca 69bd721 ff18bca 69bd721 ff18bca 69bd721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
import numpy as np
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
# Load the pretrained ViT-MAE model and its matching image preprocessor once,
# at import time; visualize_mae reads both module-level names on every call.
_CHECKPOINT = "facebook/vit-mae-base"
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)
def visualize_mae(image):
    """Run ViT-MAE on an uploaded image and plot it beside its masked version.

    ``outputs.mask`` indexes the patch grid of the *preprocessed* image that
    ``feature_extractor`` produced, not the uploaded image (whose size is
    arbitrary). Both panels are therefore drawn from ``pixel_values`` so the
    mask lines up with the pixels it refers to.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the input
            has been cleared.

    Returns:
        A matplotlib figure with two panels (model input, masked model
        input), or ``None`` if no image was given.
    """
    if image is None:
        return None

    # Force 3 channels: uploads may be RGBA (e.g. PNG) or single-channel.
    image = image.convert("RGB")

    # Preprocess into the tensor format the model expects.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference; ViTMAE applies its patch masking inside the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    # Recover a displayable H x W x 3 image from the model input tensor.
    model_input = inputs["pixel_values"][0].permute(1, 2, 0).numpy()
    if getattr(feature_extractor, "do_normalize", False):
        # Undo the processor's per-channel normalization.
        mean = np.asarray(feature_extractor.image_mean)
        std = np.asarray(feature_extractor.image_std)
        model_input = model_input * std + mean
    image_np = np.clip(model_input, 0.0, 1.0)

    # The mask has one entry per patch: 1 = masked (hidden), 0 = kept.
    patch_size = model.config.patch_size
    num_patches_h = image_np.shape[0] // patch_size
    num_patches_w = image_np.shape[1] // patch_size
    patch_mask = outputs.mask[0].reshape(num_patches_h, num_patches_w).numpy()

    # Expand every patch flag into a patch_size x patch_size pixel block and
    # black out the masked patches (kept patches are copied unchanged).
    pixel_mask = np.kron(patch_mask, np.ones((patch_size, patch_size)))
    mask_image = image_np * (1.0 - pixel_mask)[..., None]

    # Plot the model input and the masked input side by side.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(image_np)
    axes[0].set_title("Original Image")
    axes[0].axis("off")
    axes[1].imshow(mask_image)
    axes[1].set_title("Masked Image (MAE Input)")
    axes[1].axis("off")
    fig.tight_layout()
    # Remove the figure from pyplot's global registry so a long-running
    # server does not accumulate open figures; the Figure object itself can
    # still be rendered by the caller.
    plt.close(fig)
    return fig
# Gradio UI: one PIL image in, the matplotlib figure from visualize_mae out.
iface = gr.Interface(
    fn=visualize_mae,
    inputs=gr.Image(type="pil"),
    outputs="plot",
    title="ViT-MAE Masked Image Visualization",
    description="Upload an image to see how MAE masks patches.",
)
iface.launch()