yqcyqc's picture
Update app.py
a89ed24 verified
raw
history blame
4.26 kB
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
from resnest.torch import resnest50
import mysql.connector
from datetime import datetime
import os
def get_db_connection():
    """Open and return a new MySQL connection configured from the environment.

    Reads DB_HOST, DB_PORT, DB_USER, DB_PASSWORD and DB_NAME from ``os.environ``.
    DB_PORT is passed through as the raw environment string.

    Raises:
        KeyError: if any of the required environment variables is missing.
    """
    env = os.environ
    settings = {
        'host': env['DB_HOST'],
        'port': env['DB_PORT'],
        'user': env['DB_USER'],
        'password': env['DB_PASSWORD'],
        'database': env['DB_NAME'],
    }
    return mysql.connector.connect(**settings)
# Class names used at training time; position i is the name for output index i.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only ever load a trusted class_names.pkl from the working directory.
with open('class_names.pkl', 'rb') as f:
    class_names = pickle.load(f)

# Load the trained model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `pretrained` is a boolean flag. Our own fine-tuned weights are loaded below,
# so no pretrained weights are requested here.
model = resnest50(pretrained=False)
# Replace the classification head so it matches the architecture used during
# training; load_state_dict (strict by default) requires identical parameter keys.
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, len(class_names)),
)
# Load model weights. weights_only=True restricts torch.load to tensor data,
# and map_location lets GPU-saved weights load on a CPU-only host.
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model = model.to(device)
# Inference mode: disables dropout.
model.eval()
# Preprocessing pipeline; must be identical to the one used during training.
# Steps: resize to 100x100, convert to a float tensor, then normalize each RGB
# channel with the given mean/std (the commonly used ImageNet statistics).
preprocess = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def _append_prediction_log(img, best_class, best_conf, results):
    """Append one human-readable prediction record to /tmp/prediction_results.txt.

    The file is opened with an explicit UTF-8 encoding, so any non-ASCII class
    names are written the same way regardless of the host locale.
    """
    with open('/tmp/prediction_results.txt', 'a', encoding='utf-8') as f:
        f.write(f"Image: {img}\n"
                f"Predicted: {best_class}\n"
                f"Confidence: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")


def _save_prediction_to_db(best_class, best_conf, results):
    """Best-effort insert of one prediction into the MySQL ``predictions`` table.

    Creates the table if it does not exist. Any exception (including failure to
    connect) is printed and swallowed, so a database problem never breaks the
    prediction returned to the UI.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        # Create the table if it does not exist yet.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS predictions (
                id INT AUTO_INCREMENT PRIMARY KEY,
                predicted_class VARCHAR(255),
                confidence FLOAT,
                top3_results TEXT,
                timestamp DATETIME
            )
        ''')
        # Parameterized insert: values are passed separately, never interpolated.
        insert_query = '''
            INSERT INTO predictions
            (predicted_class, confidence, top3_results, timestamp)
            VALUES (%s, %s, %s, %s)
        '''
        cursor.execute(insert_query, (
            best_class,
            best_conf,
            str(results),
            datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        ))
        conn.commit()
    except Exception as e:
        print(f"Database error: {str(e)}")
    finally:
        # conn/cursor remain None when connecting (or cursor creation) failed,
        # so guard before using them instead of raising UnboundLocalError here.
        if conn is not None and conn.is_connected():
            if cursor is not None:
                cursor.close()
            conn.close()


def predict_image(img):
    """Classify an image, log the result, and return it for display.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or None when the
            button is clicked without an uploaded image.

    Returns:
        tuple: ``(best_class, best_conf, results)``.
        - ``best_class`` is the top class name.
        - ``best_conf`` is its probability as a percentage (0-100).
        - ``results`` maps class name to probability for the top (up to) 3 classes.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if img is None:
        # Fail with a readable UI message instead of an AttributeError on convert().
        raise gr.Error("请先上传图像 / Please upload an image first.")
    img = img.convert('RGB')

    # Apply preprocessing, add a batch dimension and move the tensor to the device.
    input_batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # Turn the logits of the single batch item into class probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Top-3 predictions; clamp k so models with fewer than 3 classes still work.
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()
    results = {class_names[i]: p for p, i in zip(top_probs, top_indices)}

    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100

    _append_prediction_log(img, best_class, best_conf, results)
    _save_prediction_to_db(best_class, best_conf, results)
    return best_class, best_conf, results
# Build the Gradio interface.
def create_interface():
    """Assemble the fruit-classification Gradio Blocks UI.

    Layout:
    - Left column: image upload, two example images and the classify button.
    - Right column: predicted class, confidence and a top-3 label panel.

    Returns:
        The ``gr.Blocks`` demo, not yet launched.
    """
    sample_images = ["r0_0_100.jpg", "r0_18_100.jpg"]

    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🍎 水果识别系统")
        with gr.Row():
            # Input side.
            with gr.Column():
                image_input = gr.Image(type="pil", label="上传图像")
                gr.Examples(examples=sample_images, inputs=image_input)
                submit_btn = gr.Button("分类", variant="primary")
            # Output side; order must match predict_image's return tuple.
            with gr.Column():
                prediction_outputs = [
                    gr.Textbox(label="预测结果"),
                    gr.Label(label="Top 3", num_top_classes=3)
                    if False else gr.Textbox(label="置信度"),
                    gr.Label(label="Top 3", num_top_classes=3),
                ][::1]
        # Wire the classify button to the prediction function.
        submit_btn.click(fn=predict_image, inputs=image_input, outputs=prediction_outputs)
    return demo
if __name__ == "__main__":
    # Build the UI and start a local server without a public share link.
    create_interface().launch(share=False)