File size: 4,258 Bytes
620c260
 
 
 
 
b6fa136
a89ed24
 
 
 
 
 
 
 
 
 
 
 
 
 
620c260
 
 
 
 
 
 
b6fa136
620c260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fff6c76
620c260
 
 
 
 
 
a89ed24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620c260
 
 
 
 
e42d91f
 
620c260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfccf5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
from resnest.torch import resnest50
import mysql.connector
from datetime import datetime
import os


def get_db_connection():
    """Open a new MySQL connection configured from environment variables.

    Reads DB_HOST, DB_PORT, DB_USER, DB_PASSWORD and DB_NAME from
    ``os.environ`` (in that order); a missing variable raises ``KeyError``.

    Returns:
        The connection object returned by ``mysql.connector.connect``.
        The caller is responsible for closing it.
    """
    env = os.environ
    settings = {
        'host': env['DB_HOST'],
        'port': env['DB_PORT'],
        'user': env['DB_USER'],
        'password': env['DB_PASSWORD'],
        'database': env['DB_NAME'],
    }
    return mysql.connector.connect(**settings)


# Label list saved at training time; model output index i maps to class_names[i].
# NOTE(review): pickle.load can execute arbitrary code from the file — only
# deploy a trusted, locally produced class_names.pkl alongside this app.
with open('class_names.pkl', 'rb') as pkl_file:
    class_names = pickle.load(pkl_file)

# Load the trained model: prefer CUDA when it is available, otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the training-time architecture so the saved weights line up:
# a ResNeSt-50 backbone whose final `fc` layer is replaced by dropout followed
# by a linear layer with one output per class.
model = resnest50(pretrained=None)
model.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(model.fc.in_features, len(class_names)))

# Load the weights (tensors only, mapped onto `device`), then move the model to
# `device` and switch to inference mode; Module.to() and Module.eval() both
# return the module itself, so the calls can be chained.
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model = model.to(device).eval()

# Preprocessing pipeline that mirrors training: resize to 100x100, convert to a
# tensor, then normalize with the standard ImageNet per-channel mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def _append_prediction_log(img, best_class, best_conf, results):
    """Append a plain-text summary of one prediction to /tmp/prediction_results.txt.

    The file is opened in append mode with an explicit UTF-8 encoding so that
    non-ASCII class names are written the same way regardless of system locale.
    """
    with open('/tmp/prediction_results.txt', 'a', encoding='utf-8') as log_file:
        log_file.write(f"Image: {img}\n"
                       f"Predicted: {best_class}\n"
                       f"Confidence: {best_conf:.2f}%\n"
                       f"Top 3: {results}\n"
                       f"------------------------\n")


def _save_prediction_to_db(best_class, best_conf, results):
    """Insert one prediction row into the MySQL ``predictions`` table.

    Best-effort: any error (missing DB_* environment variable, connection
    failure, SQL error) is printed and swallowed so that a database problem
    never prevents the prediction from being returned to the user.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()

        # Create the table on first use; IF NOT EXISTS makes later calls a no-op.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS predictions (
                id INT AUTO_INCREMENT PRIMARY KEY,
                predicted_class VARCHAR(255),
                confidence FLOAT,
                top3_results TEXT,
                timestamp DATETIME
            )
        ''')

        # Parameterized insert: values are bound via %s placeholders, never
        # spliced into the SQL text.
        insert_query = '''
            INSERT INTO predictions 
            (predicted_class, confidence, top3_results, timestamp)
            VALUES (%s, %s, %s, %s)
        '''
        cursor.execute(insert_query, (
            best_class,
            best_conf,
            str(results),
            # Second-precision local-time string for the DATETIME column.
            datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        ))

        conn.commit()
    except Exception as e:
        print(f"Database error: {str(e)}")
    finally:
        # conn / cursor remain None if get_db_connection() or conn.cursor()
        # raised, so guard before touching them instead of hitting an
        # UnboundLocalError that would mask the original failure.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()


def predict_image(img):
    """Classify a fruit image and record the prediction.

    Args:
        img: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when no image has been uploaded.

    Returns:
        Tuple ``(best_class, best_conf, results)``:
        ``best_class`` is the top-1 class name, ``best_conf`` its softmax
        probability as a percentage (0-100), and ``results`` maps up to three
        top class names to their probabilities (0-1).

    Raises:
        gr.Error: if ``img`` is ``None``; Gradio shows the message in the UI.

    Side effects:
        Appends a summary to ``/tmp/prediction_results.txt`` and, best-effort,
        inserts a row into the ``predictions`` MySQL table.
    """
    if img is None:
        # Clicking the button before uploading passes None; fail with a clear
        # user-facing message instead of an AttributeError on .convert().
        raise gr.Error("请先上传图像")

    img = img.convert('RGB')

    # Preprocess, add a batch dimension of 1, and move to the model's device.
    input_batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # output[0] holds the raw scores for the single image in the batch.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Top-3 predictions; clamp k so a model with fewer than 3 classes still works.
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()

    results = {class_names[i]: p for p, i in zip(top_probs, top_indices)}

    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100  # as a percentage

    _append_prediction_log(img, best_class, best_conf, results)
    _save_prediction_to_db(best_class, best_conf, results)

    return best_class, best_conf, results

# Build the Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the fruit classifier.

    The left column holds the image upload, two example images and the
    "分类" (classify) button. The right column shows the top-1 class, its
    confidence, and a Top-3 label widget. Clicking the button runs
    ``predict_image`` and fills the three outputs in order.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    sample_images = ["r0_0_100.jpg", "r0_18_100.jpg"]

    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🍎 水果识别系统")

        with gr.Row():
            # Input column: image upload, clickable examples, submit button.
            with gr.Column():
                uploaded = gr.Image(type="pil", label="上传图像")
                gr.Examples(examples=sample_images, inputs=uploaded)
                classify_btn = gr.Button("分类", variant="primary")

            # Output column, in the same order as predict_image's return tuple:
            # best class, confidence, Top-3 mapping.
            with gr.Column():
                prediction_outputs = [
                    gr.Textbox(label="预测结果"),
                    gr.Textbox(label="置信度"),
                    gr.Label(label="Top 3", num_top_classes=3),
                ]

        # Wire the classify button to the prediction function.
        classify_btn.click(
            fn=predict_image,
            inputs=uploaded,
            outputs=prediction_outputs,
        )

    return app


if __name__ == "__main__":
    # Build the UI and serve it without creating a public Gradio share link.
    create_interface().launch(share=False)