import torch
from transformers import ViTMAEForPreTraining, ViTImageProcessor
import gradio as gr
import matplotlib.pyplot as plt

# Load the pretrained ViT-MAE model and its image processor
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
feature_extractor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")

def visualize_mae(image):
    # Resize the image and preprocess it into model inputs
    image = image.convert("RGB").resize((224, 224))
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Run model inference without gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the patch mask (1 = masked, 0 = visible) and reshape it to the 14x14 patch grid
    mask = outputs.mask[0].reshape(14, 14).cpu().numpy()

    # Plot the mask as a grayscale image and return the figure for Gradio
    fig, ax = plt.subplots()
    ax.imshow(mask, cmap="gray")
    ax.set_title("MAE Mask (14x14)")
    ax.axis("off")
    return fig

iface = gr.Interface(
    fn=visualize_mae,
    inputs=gr.Image(type="pil"),
    outputs="plot",
    title="ViT-MAE Mask Visualization",
    description="Upload an image to see the MAE masking pattern.",
)

iface.launch()