Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor | |
from PIL import Image | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
# Load the pretrained ViT-MAE model and its matching image processor from
# one Hugging Face checkpoint, so preprocessing matches what the model expects.
# NOTE(review): newer transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor; consider switching when the pinned version allows.
_CHECKPOINT = "facebook/vit-mae-base"
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
def _patch_mask_to_pixel_mask(patch_mask, height, width):
    """Expand a flat per-patch mask into a per-pixel boolean mask.

    The processor resizes every input to the model's fixed square input
    resolution, so patch ``(r, c)`` of the ``grid x grid`` patch grid covers
    rows ``[r*height/grid, (r+1)*height/grid)`` and the analogous columns of
    the *original* image, whatever its size.

    Args:
        patch_mask: 1-D boolean array of length ``grid * grid`` in row-major
            patch order; True marks a masked patch.
        height: Height in pixels of the image the mask is painted onto.
        width: Width in pixels of the image the mask is painted onto.

    Returns:
        Boolean array of shape ``(height, width)``, True where the pixel lies
        inside a masked patch.

    Raises:
        ValueError: If ``patch_mask`` does not describe a square patch grid.
    """
    num_patches = patch_mask.shape[0]
    grid = int(round(num_patches ** 0.5))
    if grid * grid != num_patches:
        raise ValueError(
            f"Expected a square patch grid, got {num_patches} patches"
        )
    grid_mask = patch_mask.reshape(grid, grid)
    # Map each pixel row/column to the patch row/column it falls into; this
    # covers every pixel, including any remainder when the size is not a
    # multiple of the grid.
    rows = np.arange(height) * grid // height
    cols = np.arange(width) * grid // width
    return grid_mask[rows[:, None], cols[None, :]]


def _side_by_side_figure(original, masked):
    """Build a 1x2 figure showing *original* next to *masked*.

    Uses ``matplotlib.figure.Figure`` directly rather than ``pyplot`` so the
    figure is never registered with pyplot's global figure manager. With
    ``plt.subplots`` every request left one open figure behind for the
    lifetime of the server process.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 5))
    axes = fig.subplots(1, 2)
    for ax, img, title in zip(
        axes,
        (original, masked),
        ("Original Image", "Masked Image (MAE Input)"),
    ):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    return fig


def visualize_mae(image):
    """Run ViT-MAE on *image* and plot it next to its patch-masked version.

    The model's forward pass chooses which patches to mask. Patches it masks
    are painted black on the original image; visible patches keep their
    pixels. The mask is produced on the processor-resized image and mapped
    back onto the original resolution, so uploads of any size work.

    Args:
        image: The uploaded image. The Gradio UI passes a PIL image; a NumPy
            array is also accepted and passed through unchanged.

    Returns:
        A matplotlib ``Figure`` with the original and masked images.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("No image provided; please upload an image.")
    if not isinstance(image, np.ndarray):
        # PIL input: force 3-channel RGB so RGBA/greyscale uploads match the
        # processor's 3-channel normalization and the plotted panels.
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs.mask has shape (1, num_patches) over the *resized* model input;
    # 1 marks a masked patch, 0 a visible one.
    patch_mask = outputs.mask[0].bool().cpu().numpy()

    image_np = np.array(image)
    height, width = image_np.shape[:2]
    masked_image = image_np.copy()
    masked_image[_patch_mask_to_pixel_mask(patch_mask, height, width)] = 0
    return _side_by_side_figure(image_np, masked_image)
# Gradio UI: one PIL image in, the matplotlib comparison figure out.
iface = gr.Interface(
    fn=visualize_mae,
    inputs=gr.Image(type="pil"),
    outputs="plot",
    title="ViT-MAE Masked Image Visualization",
    description="Upload an image to see how MAE masks patches.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or other tools) does not launch it.
if __name__ == "__main__":
    iface.launch()