import torch
import numpy as np
from transformers import ViTMAEForPreTraining, AutoImageProcessor
import gradio as gr
import matplotlib.pyplot as plt

# Load the pretrained MAE model and its image processor
# (AutoImageProcessor replaces the deprecated ViTFeatureExtractor)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model.eval()  # inference only
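
# Note: the model masks a random 75% of patches (model.config.mask_ratio) on
# every forward pass, so repeated calls give different maskings. A minimal way
# to make a run reproducible, if desired, is to seed PyTorch's RNG first:
#
#     torch.manual_seed(0)
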
def visualize_mae(image):
    # Preprocess: the processor resizes the image to the model's input size
    # (224x224 for this checkpoint) and normalizes it with the processor's
    # mean/std statistics
    inputs = image_processor(images=image, return_tensors="pt")

    # Run inference (the model applies random masking internally)
    with torch.no_grad():
        outputs = model(**inputs)

    # Retrieve the binary mask and the per-patch reconstructions
    mask = outputs.mask  # [1, 196]; 1 = masked patch, 0 = visible patch
    reconstructed_pixels = outputs.logits  # [1, 196, 768], i.e. 16*16*3 per patch; unused here (see the optional helper below)

    # Apply the mask to the image to simulate the MAE input. The mask is laid
    # out on the patch grid of the *resized* image the model actually sees,
    # so resize the original image to the model's input size first.
    patch_size = model.config.patch_size  # 16
    img_size = model.config.image_size    # 224
    image_np = np.array(image.convert("RGB").resize((img_size, img_size)))
    num_patches_h = img_size // patch_size  # 14
    num_patches_w = img_size // patch_size  # 14

    # Build the masked image: keep visible patches, black out masked ones
    mask_image = image_np.copy()
    mask = mask[0].reshape(num_patches_h, num_patches_w)
    for i in range(num_patches_h):
        for j in range(num_patches_w):
            if mask[i, j] == 1:  # masked patch -> blacked out
                mask_image[
                    i * patch_size : (i + 1) * patch_size,
                    j * patch_size : (j + 1) * patch_size,
                ] = 0

    # Plot the (resized) original image next to the masked version
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(image_np)
    axes[0].set_title("Original Image")
    axes[0].axis("off")
    axes[1].imshow(mask_image)
    axes[1].set_title("Masked Image (MAE Input)")
    axes[1].axis("off")
    plt.tight_layout()
    return fig
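
# --- Optional extension (not in the original app): view the reconstruction ---
# A minimal sketch assuming the transformers ViTMAE API: model.unpatchify()
# folds the per-patch predictions in outputs.logits back into an image tensor,
# and the processor exposes its normalization statistics as image_mean /
# image_std. Caveat: if model.config.norm_pix_loss is True, targets are
# normalized per patch and the raw output will not look like a natural image.
def reconstruct_image(outputs):
    # [1, 196, 768] patch predictions -> [1, 3, 224, 224] pixel tensor
    pred = model.unpatchify(outputs.logits).detach().cpu()
    # Undo the normalization applied by the image processor
    mean = torch.tensor(image_processor.image_mean).view(1, 3, 1, 1)
    std = torch.tensor(image_processor.image_std).view(1, 3, 1, 1)
    pred = (pred * std + mean).clamp(0.0, 1.0)
    # Return an [H, W, C] float array in [0, 1], ready for plt.imshow
    return pred[0].permute(1, 2, 0).numpy()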

iface = gr.Interface(
    fn=visualize_mae,
    inputs=gr.Image(type="pil"),
    outputs="plot",
    title="ViT-MAE Masked Image Visualization",
    description="Upload an image to see how MAE masks patches.",
)
iface.launch()
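
# Tip: on Hugging Face Spaces the app is served automatically after launch();
# for local runs, iface.launch(share=True) would also create a temporary
# public link.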