import torch
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
from PIL import Image
import gradio as gr

# Load the model and processor
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")

def predict(image):
    # Preprocess the image
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Run model inference
    with torch.no_grad():
        outputs = model(**inputs)
    # The MAE decoder returns per-patch pixel predictions
    reconstructed_pixel_values = outputs.logits  # [1, 196, 768]
    # Note: the MAE output needs post-processing (unpatchify + unnormalize)
    # before it can be visualized; this demo only reports the raw output shape.
    return f"Output shape: {reconstructed_pixel_values.shape}"
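
# --- Optional: visualizing the reconstruction --------------------------------
# A minimal sketch (not wired into the demo above) of how the raw logits could
# be turned back into a viewable image. It assumes the `unpatchify` helper
# exposed by ViTMAEForPreTraining and the normalization statistics stored on
# the feature extractor; treat it as a starting point, not the final demo.
def reconstruction_to_pil(outputs):
    # [1, 196, 768] patch predictions -> [1, 3, 224, 224] pixel values
    pixels = model.unpatchify(outputs.logits)
    # Undo the ImageNet normalization applied during preprocessing
    mean = torch.tensor(feature_extractor.image_mean).view(1, 3, 1, 1)
    std = torch.tensor(feature_extractor.image_std).view(1, 3, 1, 1)
    pixels = (pixels * std + mean).clamp(0, 1)
    # For a faithful visualization one would also paste the visible (unmasked)
    # patches from the original image back in, using outputs.mask.
    array = (pixels[0].permute(1, 2, 0).numpy() * 255).astype("uint8")
    return Image.fromarray(array)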

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="MAE (Masked Autoencoder) Demo",
    description="Upload an image to see the ViT-MAE model output.",
)

iface.launch()