yqcyqc commited on
Commit
ba701dd
·
verified ·
1 Parent(s): 626d8a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -32
app.py CHANGED
@@ -1,28 +1,29 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
- from torchvision import transforms, models
5
  import pickle
6
  from resnest.torch import resnest50
 
 
 
7
 
 
8
  with open('class_names.pkl', 'rb') as f:
9
  class_names = pickle.load(f)
10
 
11
- # 加载训练好的模型
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
-
14
- model = resnest50(pretrained=None)
15
  model.fc = nn.Sequential(
16
  nn.Dropout(0.2),
17
  nn.Linear(model.fc.in_features, len(class_names))
18
  )
19
-
20
- # 加载模型权重
21
- model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
22
  model = model.to(device)
23
  model.eval()
24
 
25
- # 定义与训练时相同的预处理流程
26
  preprocess = transforms.Compose([
27
  transforms.Resize((100, 100)),
28
  transforms.ToTensor(),
@@ -30,23 +31,42 @@ preprocess = transforms.Compose([
30
  ])
31
 
32
 
33
- def predict_image(img):
34
- img = img.convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # 应用预处理
37
- input_tensor = preprocess(img)
38
 
39
- # 添加批次维度并移动到设备
 
 
 
 
 
 
 
 
 
40
  input_batch = input_tensor.unsqueeze(0).to(device)
41
 
42
  # 预测
43
  with torch.no_grad():
44
  output = model(input_batch)
45
 
46
- # 计算概率
47
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
48
-
49
- # 获取前3个预测结果
50
  top3_probs, top3_indices = torch.topk(probabilities, 3)
51
 
52
  results = {
@@ -54,38 +74,66 @@ def predict_image(img):
54
  for p, i in zip(top3_probs, top3_indices)
55
  }
56
 
57
- # 获取最佳预测结果
58
  best_class = class_names[top3_indices[0]]
59
  best_conf = top3_probs[0].item() * 100
60
-
61
- return best_class, best_conf, results
62
 
63
- # 创建Gradio界面
 
 
64
  def create_interface():
65
  examples = [
66
  "r0_0_100.jpg",
67
- "r0_18_100.jpg"
 
 
 
68
  ]
69
 
70
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
71
- gr.Markdown("# 🍎 水果识别系统")
 
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  with gr.Row():
74
  with gr.Column():
75
- image_input = gr.Image(type="pil", label="上传图像")
76
- gr.Examples(examples=examples, inputs=image_input)
77
- submit_btn = gr.Button("分类", variant="primary")
 
 
 
 
 
 
 
 
 
 
78
 
79
  with gr.Column():
80
- best_pred = gr.Textbox(label="预测结果")
81
- confidence = gr.Textbox(label="置信度")
82
- full_results = gr.Label(label="Top 3", num_top_classes=3)
 
83
 
84
- # ‘分类’按钮点击事件
85
  submit_btn.click(
86
  fn=predict_image,
87
- inputs=image_input,
88
- outputs=[best_pred, confidence, full_results]
89
  )
90
 
91
  return demo
@@ -93,4 +141,4 @@ def create_interface():
93
 
94
  if __name__ == "__main__":
95
  interface = create_interface()
96
- interface.launch(share=False)
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ from torchvision import transforms
5
  import pickle
6
  from resnest.torch import resnest50
7
+ from rembg import remove
8
+ from PIL import Image
9
+ import io
10
 
11
+ # 加载类别名称
12
  with open('class_names.pkl', 'rb') as f:
13
  class_names = pickle.load(f)
14
 
15
+ # 初始化模型
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model = resnest50(pretrained=False)
 
18
  model.fc = nn.Sequential(
19
  nn.Dropout(0.2),
20
  nn.Linear(model.fc.in_features, len(class_names))
21
  )
22
+ model.load_state_dict(torch.load('best_model.pth', map_location=device))
 
 
23
  model = model.to(device)
24
  model.eval()
25
 
26
+ # 预处理流程
27
  preprocess = transforms.Compose([
28
  transforms.Resize((100, 100)),
29
  transforms.ToTensor(),
 
31
  ])
32
 
33
 
34
+ def remove_background(img):
35
+ """使用rembg去除背景并添加白色背景"""
36
+ # 转换图像为字节流
37
+ img_byte_arr = io.BytesIO()
38
+ img.save(img_byte_arr, format='PNG')
39
+ img_bytes = img_byte_arr.getvalue()
40
+
41
+ # 去除背景
42
+ removed_bg_bytes = remove(img_bytes)
43
+
44
+ # 转换为PIL图像并处理透明度
45
+ removed_bg_img = Image.open(io.BytesIO(removed_bg_bytes)).convert('RGBA')
46
+
47
+ # 创建白色背景
48
+ white_bg = Image.new('RGBA', removed_bg_img.size, (255, 255, 255, 255))
49
+ combined = Image.alpha_composite(white_bg, removed_bg_img)
50
+ return combined.convert('RGB')
51
 
 
 
52
 
53
+ def predict_image(img, remove_bg=False):
54
+ """分类预测主函数"""
55
+ # 根据选择处理图像
56
+ if remove_bg:
57
+ processed_img = remove_background(img)
58
+ else:
59
+ processed_img = img.convert('RGB') # 确保为RGB格式
60
+
61
+ # 预处理
62
+ input_tensor = preprocess(processed_img)
63
  input_batch = input_tensor.unsqueeze(0).to(device)
64
 
65
  # 预测
66
  with torch.no_grad():
67
  output = model(input_batch)
68
 
 
69
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
 
 
70
  top3_probs, top3_indices = torch.topk(probabilities, 3)
71
 
72
  results = {
 
74
  for p, i in zip(top3_probs, top3_indices)
75
  }
76
 
77
+ # 记录结果
78
  best_class = class_names[top3_indices[0]]
79
  best_conf = top3_probs[0].item() * 100
 
 
80
 
81
+ return processed_img, best_class, f"{best_conf:.2f}%", results
82
+
83
+
84
  def create_interface():
85
  examples = [
86
  "r0_0_100.jpg",
87
+ "r0_18_100.jpg",
88
+ "9_100.jpg",
89
+ "127_100.jpg",
90
+ "r0_1_100.jpg",
91
  ]
92
 
93
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
94
+ gr.Markdown("""
95
+ # 🍎 智能水果识别系统
96
+ """)
97
 
98
+ # 新增:模式选择卡片(视觉强化)
99
+ with gr.Row():
100
+ with gr.Column(scale=3):
101
+ with gr.Group():
102
+ gr.Markdown("### ⚙️ 处理模式选择")
103
+ with gr.Row():
104
+ bg_removal = gr.Checkbox(
105
+ label="背景去除",
106
+ value=False,
107
+ interactive=True
108
+ )
109
+
110
+ # 主操作区域
111
  with gr.Row():
112
  with gr.Column():
113
+ original_image = gr.Image(label="📤 上传图片", type="pil")
114
+ gr.Examples(examples=examples, inputs=original_image)
115
+ submit_btn = gr.Button("🚀 开始识别", variant="primary")
116
+
117
+ # 添加模式说明提示
118
+ gr.Markdown("""
119
+ <div style="background: #f3f4f6; padding: 15px; border-radius: 8px; margin-top: 10px">
120
+ <b>💡 使用建议:</b><br>
121
+ • 上传图片:选择一张图片,点击'开始识别'按钮<br>
122
+ • 勾选背景去除:适合杂乱背景的图片(识别更准确)<br>
123
+ • 不勾选:适合纯色背景的图片(速度更快)
124
+ </div>
125
+ """)
126
 
127
  with gr.Column():
128
+ processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
129
+ best_pred = gr.Textbox(label="🔍 识别结果")
130
+ confidence = gr.Textbox(label="📊 置信度")
131
+ full_results = gr.Label(label="🏆 Top 3 可能结果", num_top_classes=3)
132
 
 
133
  submit_btn.click(
134
  fn=predict_image,
135
+ inputs=[original_image, bg_removal],
136
+ outputs=[processed_image, best_pred, confidence, full_results]
137
  )
138
 
139
  return demo
 
141
 
142
  if __name__ == "__main__":
143
  interface = create_interface()
144
+ interface.launch(share=True)