yqcyqc commited on
Commit
620c260
·
verified ·
1 Parent(s): cbaa73e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -12
app.py CHANGED
@@ -1,12 +1,104 @@
1
- import gradio as gr
2
- def greet(name):
3
- return "Hello " + name + "!"
4
-
5
-
6
- demo = gr.Interface(
7
- fn=greet,
8
- inputs="textbox",
9
- outputs="textbox"
10
- )
11
-
12
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms, models
5
+ import pickle
6
+
7
+
8
+ with open('class_names.pkl', 'rb') as f:
9
+ class_names = pickle.load(f)
10
+
11
+ # 加载训练好的模型
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ model = models.resnet50(weights=None)
15
+ model.fc = nn.Sequential(
16
+ nn.Dropout(0.2),
17
+ nn.Linear(model.fc.in_features, len(class_names))
18
+ )
19
+
20
+ # 加载模型权重
21
+ model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
22
+ model = model.to(device)
23
+ model.eval()
24
+
25
+ # 定义与训练时相同的预处理流程
26
+ preprocess = transforms.Compose([
27
+ transforms.Resize((100, 100)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30
+ ])
31
+
32
+
33
+ def predict_image(img):
34
+ img = img.convert('RGB')
35
+
36
+ # 应用预处理
37
+ input_tensor = preprocess(img)
38
+
39
+ # 添加批次维度并移动到设备
40
+ input_batch = input_tensor.unsqueeze(0).to(device)
41
+
42
+ # 预测
43
+ with torch.no_grad():
44
+ output = model(input_batch)
45
+
46
+ # 计算概率
47
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
48
+
49
+ # 获取前3个预测结果
50
+ top3_probs, top3_indices = torch.topk(probabilities, 3)
51
+
52
+ results = {
53
+ class_names[i]: p.item()
54
+ for p, i in zip(top3_probs, top3_indices)
55
+ }
56
+
57
+ # 获取最佳预测结果
58
+ best_class = class_names[top3_indices[0]]
59
+ best_conf = top3_probs[0].item() * 100
60
+
61
+ # 保存结果
62
+ with open('prediction_results.txt', 'a') as f:
63
+ f.write(f"Image: {img}\n"
64
+ f"Predicted: {best_class}\n"
65
+ f"Confidence: {best_conf:.2f}%\n"
66
+ f"Top 3: {results}\n"
67
+ f"------------------------\n")
68
+
69
+ return best_class, best_conf, results
70
+
71
+ # 创建Gradio界面
72
+ def create_interface():
73
+ examples = [
74
+ "data/r0_0_100.jpg",
75
+ "data/r0_18_100.jpg"
76
+ ]
77
+
78
+ with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
79
+ gr.Markdown("# 🍎 水果识别系统")
80
+
81
+ with gr.Row():
82
+ with gr.Column():
83
+ image_input = gr.Image(type="pil", label="上传图像")
84
+ gr.Examples(examples=examples, inputs=image_input)
85
+ submit_btn = gr.Button("分类", variant="primary")
86
+
87
+ with gr.Column():
88
+ best_pred = gr.Textbox(label="预测结果")
89
+ confidence = gr.Textbox(label="置信度")
90
+ full_results = gr.Label(label="Top 3", num_top_classes=3)
91
+
92
+ # ‘分类’按钮点击事件
93
+ submit_btn.click(
94
+ fn=predict_image,
95
+ inputs=image_input,
96
+ outputs=[best_pred, confidence, full_results]
97
+ )
98
+
99
+ return demo
100
+
101
+
102
+ if __name__ == "__main__":
103
+ interface = create_interface()
104
+ interface.launch(share=True)