# app.py
import torch
from transformers import ViTMAEForPreTraining, ViTImageProcessor
from PIL import Image
import gradio as gr
# Load the pretrained ViT-MAE model and its image processor
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
def predict(image):
    # Preprocess the input image (resize, normalize, convert to tensors)
    inputs = image_processor(images=image, return_tensors="pt")
    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    # The MAE decoder predicts per-patch pixel values: shape [1, 196, 768]
    # (196 patches, each 16 x 16 x 3 = 768 values, for a 224 x 224 input)
    reconstructed_pixel_values = outputs.logits
    # Turning this into a viewable image requires post-processing (see the
    # sketch below); here we only report the raw output shape.
    return f"Output shape: {reconstructed_pixel_values.shape}"
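
# --- Post-processing sketch (not wired into the Gradio app below) ---
# A minimal, hedged example of turning the decoder logits into a viewable
# image. `reconstruct_image` is a hypothetical helper, not part of the
# transformers API (though `model.unpatchify` is). It assumes the checkpoint
# predicts raw (not per-patch normalized) pixels and that the processor uses
# the standard ImageNet mean/std.
def reconstruct_image(outputs):
    # Fold [1, num_patches, patch_size**2 * 3] back into [1, 3, 224, 224]
    pixels = model.unpatchify(outputs.logits)
    pixels = pixels.squeeze(0).permute(1, 2, 0)  # -> [224, 224, 3]
    # Undo the processor's normalization
    mean = torch.tensor(image_processor.image_mean)
    std = torch.tensor(image_processor.image_std)
    pixels = pixels * std + mean
    pixels = pixels.clamp(0, 1).mul(255).byte().numpy()
    return Image.fromarray(pixels)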
# Build the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="MAE (Masked Autoencoder) Demo",
    description="Upload an image to see ViT-MAE model output.",
)
iface.launch()