Spaces:
Sleeping
Sleeping
File size: 4,258 Bytes
620c260 b6fa136 a89ed24 620c260 b6fa136 620c260 fff6c76 620c260 a89ed24 620c260 e42d91f 620c260 dfccf5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
from resnest.torch import resnest50
import mysql.connector
from datetime import datetime
import os
def get_db_connection():
    """Open a new MySQL connection configured from environment variables.

    Reads DB_HOST, DB_USER, DB_PASSWORD and DB_NAME (all required) and
    DB_PORT (optional, defaults to 3306) from the environment.

    Returns:
        A new connection object from ``mysql.connector.connect``; the caller
        is responsible for closing it.

    Raises:
        KeyError: if a required environment variable is not set.
        ValueError: if DB_PORT is set but is not an integer.
        Any error raised by ``mysql.connector.connect`` propagates unchanged.
    """
    return mysql.connector.connect(
        host=os.environ['DB_HOST'],
        # Environment values are always strings, but the connector expects an
        # integer port (the C extension rejects a string outright).
        port=int(os.environ.get('DB_PORT', '3306')),
        user=os.environ['DB_USER'],
        password=os.environ['DB_PASSWORD'],
        database=os.environ['DB_NAME'],
    )
# Class labels: entry i is the name for output logit i of the model's head
# (the head below has len(class_names) outputs, and predict_image maps
# indices back to names with class_names[i]).
# NOTE(review): pickle.load can execute arbitrary code; this is only safe
# because class_names.pkl is a trusted file shipped with the app.
with open('class_names.pkl', 'rb') as f:
    class_names = pickle.load(f)
# Load the trained model, running on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the bare ResNeSt-50 architecture. pretrained=None is falsy, presumably
# so no pretrained weights are fetched; the fine-tuned weights are loaded below.
model = resnest50(pretrained=None)
# Replace the classification head so its structure matches the checkpoint:
# dropout followed by a linear layer with one output per class. This must run
# before load_state_dict, or the head's parameter shapes will not match.
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, len(class_names))
)
# Load the fine-tuned weights. map_location places the tensors on `device`;
# weights_only=True restricts unpickling to tensors and plain data types.
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model = model.to(device)
# Evaluation mode: among other things, disables the Dropout in the head.
model.eval()
# Preprocessing pipeline, stated to be the same one used at training time
# (TODO confirm against the training script). Images are resized to a fixed
# 100x100 (aspect ratio is not preserved), converted to a tensor, and
# normalized with the standard ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def _append_prediction_log(img, best_class, best_conf, results):
    """Append one human-readable prediction record to /tmp/prediction_results.txt."""
    # Explicit UTF-8 so non-ASCII class names do not depend on the host locale.
    with open('/tmp/prediction_results.txt', 'a', encoding='utf-8') as f:
        f.write(f"Image: {img}\n"
                f"Predicted: {best_class}\n"
                f"Confidence: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")


def _save_prediction_to_db(best_class, best_conf, results):
    """Insert one prediction row into the MySQL `predictions` table.

    Best effort: any database error is printed and swallowed so that an
    unavailable database never breaks classification.
    """
    # Pre-bind so the cleanup below is safe even if connecting fails.
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        # Create the table if it does not exist yet.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS predictions (
                id INT AUTO_INCREMENT PRIMARY KEY,
                predicted_class VARCHAR(255),
                confidence FLOAT,
                top3_results TEXT,
                timestamp DATETIME
            )
        ''')
        # Parameterized insert; values are never interpolated into the SQL.
        insert_query = '''
            INSERT INTO predictions
            (predicted_class, confidence, top3_results, timestamp)
            VALUES (%s, %s, %s, %s)
        '''
        cursor.execute(insert_query, (
            best_class,
            best_conf,
            str(results),
            datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        ))
        conn.commit()
    except Exception as e:
        print(f"Database error: {str(e)}")
    finally:
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()


def predict_image(img):
    """Classify a fruit image and record the result to a log file and MySQL.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or None when
            the button is clicked without an uploaded image.

    Returns:
        A tuple ``(best_class, best_conf, results)``: the top-1 class name,
        its confidence as a percentage (probability * 100), and a
        ``{class_name: probability}`` dict of the top predictions (at most 3)
        for the ``gr.Label`` output.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if img is None:
        raise gr.Error("请先上传图像")
    img = img.convert('RGB')
    # Preprocess, add a batch dimension, and move to the model's device.
    input_batch = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_batch)
    # Softmax over the logits of the single image in the batch.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # topk raises if k exceeds the number of classes, so cap it.
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()
    results = {class_names[i]: p for p, i in zip(top_probs, top_indices)}
    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100
    _append_prediction_log(img, best_class, best_conf, results)
    _save_prediction_to_db(best_class, best_conf, results)
    return best_class, best_conf, results
# 创建Gradio界面
def create_interface():
    """Build and return the Gradio Blocks UI for the fruit classifier.

    The left column holds the image upload, two example images and the
    classify button. The right column shows the top-1 prediction, its
    confidence, and a top-3 label view. Clicking the button runs
    ``predict_image`` on the uploaded image.
    """
    sample_images = ["r0_0_100.jpg", "r0_18_100.jpg"]
    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🍎 水果识别系统")
        with gr.Row():
            with gr.Column():
                uploaded = gr.Image(type="pil", label="上传图像")
                gr.Examples(examples=sample_images, inputs=uploaded)
                classify_button = gr.Button("分类", variant="primary")
            with gr.Column():
                prediction_box = gr.Textbox(label="预测结果")
                confidence_box = gr.Textbox(label="置信度")
                top3_label = gr.Label(label="Top 3", num_top_classes=3)
        # Wire the classify button; output order matches predict_image's
        # (best_class, best_conf, results) return tuple.
        classify_button.click(
            fn=predict_image,
            inputs=uploaded,
            outputs=[prediction_box, confidence_box, top3_label],
        )
    return app
if __name__ == "__main__":
    # Build the UI and start the Gradio server without requesting a public
    # share link.
    create_interface().launch(share=False)
|