SlothLoader committed on
Commit
69bd721
·
verified ·
1 Parent(s): d7e575a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
# Load the pretrained model and its matching feature extractor once, at import
# time, from the same checkpoint so that every request reuses them.
_CHECKPOINT = "facebook/vit-mae-base"
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
9
+
10
def predict(image):
    """Run the ViT-MAE model on an image and report the output tensor shape.

    Args:
        image: The PIL image handed over by the Gradio image input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        A string of the form ``"Output shape: ..."`` describing the shape of
        the model's ``logits``, or a short prompt asking for an image when
        ``image`` is ``None``.
    """
    # Gradio passes None when nothing was uploaded; answer with a readable
    # message instead of letting the feature extractor raise.
    if image is None:
        return "No image provided. Please upload an image."

    # Preprocess the image into model-ready PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**inputs)

    # Per the original author's note, the MAE output holds the reconstructed
    # pixel values, expected shape [1, 196, 768].
    reconstructed_pixel_values = outputs.logits

    # NOTE: the raw output needs post-processing before it can be shown as an
    # image; this simplified demo only reports its shape.
    return f"Output shape: {reconstructed_pixel_values.shape}"
24
+
25
# Build the Gradio interface: a PIL image goes in, predict() runs on it, and
# the returned text is displayed.
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    "text",
    title="MAE (Masked Autoencoder) Demo",
    description="Upload an image to see ViT-MAE model output.",
)

# Start the web app.
iface.launch()