Spaces:
Sleeping
Sleeping
File size: 1,102 Bytes
69bd721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
from PIL import Image
import gradio as gr
# Load the pretrained ViT-MAE weights and its preprocessor from the same
# Hugging Face Hub checkpoint, so preprocessing settings match the weights.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favor of ViTImageProcessor — consider migrating.
_CHECKPOINT = "facebook/vit-mae-base"
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
def predict(image):
    """Run ViT-MAE on an uploaded image and report the shape of its output.

    Args:
        image: A PIL image from the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when the form is submitted without an image.

    Returns:
        A string describing the shape of the model's ``logits`` tensor, or
        a prompt asking for an image when ``image`` is ``None``.
    """
    # Gradio hands us None when nothing was uploaded; the preprocessor would
    # raise on that, so reply with a readable message instead of an error.
    if image is None:
        return "Please upload an image."

    # Gradio's Image input normally delivers RGB already, but direct callers
    # may pass other PIL modes (e.g. "L", "RGBA", "P"); the model expects
    # 3-channel input, so normalize the mode first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image (resize/normalize) into PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # For ViTMAEForPreTraining, `logits` holds per-patch pixel predictions,
    # e.g. [1, 196, 768] for this base checkpoint at 224x224 input.
    # MAE masks a random subset of patches on each forward pass, so values
    # vary between runs; turning them into a viewable image needs extra
    # post-processing (unpatchify + de-normalize), which this demo skips.
    reconstructed_pixel_values = outputs.logits
    return f"Output shape: {reconstructed_pixel_values.shape}"
# Build the Gradio UI: one PIL-image input, with the prediction shown as text.
_image_input = gr.Image(type="pil")
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs="text",
    title="MAE (Masked Autoencoder) Demo",
    description="Upload an image to see ViT-MAE model output.",
)

# Start the Gradio app.
iface.launch()