File size: 1,102 Bytes
69bd721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
from PIL import Image
import gradio as gr

# Load the pretrained ViT-MAE model and its matching feature extractor.
# NOTE: from_pretrained downloads weights on first run (network required).
# NOTE(review): ViTFeatureExtractor is deprecated in newer transformers
# releases in favor of ViTImageProcessor — confirm the installed version.
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")

def predict(image):
    """Run a single ViT-MAE forward pass on *image* and report the logits shape.

    MAE logits are per-patch pixel values; turning them back into a viewable
    image requires unpatchifying/post-processing, which this demo skips —
    it only reports the raw output shape as text.
    """
    # Convert the incoming PIL image into model-ready tensors.
    model_inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only — disable gradient tracking.
    with torch.no_grad():
        output = model(**model_inputs)

    # logits: per-patch reconstructed pixel values, e.g. [1, 196, 768]
    # for the base model — presumably [batch, num_patches, patch_pixels].
    reconstruction = output.logits

    return f"Output shape: {reconstruction.shape}"

# Build the Gradio UI: a single PIL-image input wired to predict(),
# with a plain-text output showing the model's logits shape.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="MAE (Masked Autoencoder) Demo",
    description="Upload an image to see ViT-MAE model output.",
)

# Start the local Gradio server (blocks until interrupted).
iface.launch()