# mae / app.py — author: SlothLoader
# Last commit: 791fec0 "Update app.py" (verified)
import torch
import numpy as np
from transformers import ViTMAEForPreTraining, ViTImageProcessor
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
# Hub checkpoint shared by the MAE model and its matching image processor.
_CHECKPOINT = "facebook/vit-mae-base"

# Both objects are loaded once, at import time, and reused by every request
# handled by visualize_mae.
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
def visualize_mae(image):
    """Render the patch mask that ViT-MAE samples for one input image.

    The image is preprocessed by the module-level ``feature_extractor``,
    passed through the module-level ``model``, and the per-patch mask from
    ``outputs.mask`` is reshaped into its square patch grid and drawn in
    grayscale.

    NOTE: in transformers' ViTMAE the mask is drawn from random noise rather
    than from pixel content, so repeated calls generally show different masks.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A new ``matplotlib.figure.Figure`` containing the mask plot.

    Raises:
        gr.Error: If no image was provided.
    """
    # Local import: standalone Figure avoids pyplot's global figure state.
    from matplotlib.figure import Figure

    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    # The processor handles resizing and normalization itself; converting to
    # RGB guards against grayscale/RGBA inputs whose channel count would not
    # match the 3-channel normalization statistics.
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="pt")

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # One mask value per patch (1 = masked, 0 = kept, per the transformers
    # docs). Infer the grid side from the patch count instead of hard-coding
    # 14; reshape raises if the count is not a perfect square.
    mask = outputs.mask[0].cpu().numpy()
    side = int(round(mask.size ** 0.5))
    mask = mask.reshape(side, side)

    # A fresh Figure per call: nothing is drawn on top of a previous
    # request's plot, pyplot does not accumulate open figures, and concurrent
    # Gradio requests never share a figure object.
    fig = Figure()
    ax = fig.subplots()
    # Fixed color limits so 0 is always black and 1 always white.
    ax.imshow(mask, cmap="gray", vmin=0, vmax=1)
    ax.set_title(f"MAE Mask ({side}x{side})")
    ax.axis("off")
    return fig
# Gradio UI: one uploaded image in, one matplotlib figure (the mask) out.
iface = gr.Interface(
    fn=visualize_mae,
    inputs=gr.Image(type="pil"),
    outputs="plot",
    title="ViT-MAE Mask Visualization",
    description="Upload an image to see the MAE masking pattern.",
)

# Start the web server only when executed as a script, so importing this
# module (e.g. from tests or another app) does not block inside launch().
if __name__ == "__main__":
    iface.launch()