MoinulwithAI commited on
Commit
953eab4
·
verified ·
1 Parent(s): 59ea35e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision.models.detection import FasterRCNN
4
+ from torchvision.models.detection.fasterrcnn_resnet50_fpn import FastRCNNPredictor
5
+ from torchvision.transforms import functional as F
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import gradio as gr
8
+
9
+ # Force CPU
10
+ device = torch.device('cpu')
11
+
12
+ # COCO-style class map
13
+ COCO_CLASSES = {
14
+ 0: "Background",
15
+ 1: "Stand",
16
+ 2: "Sit",
17
+ 3: "Ruku",
18
+ 4: "Sijdah"
19
+ }
20
+
21
+ # Load model
22
+ def get_model(num_classes):
23
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
24
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
25
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
26
+ return model
27
+
28
+ model = get_model(num_classes=5)
29
+ model.load_state_dict(torch.load("Salatfasterrcnn_resnet50_epoch_3.pth", map_location=device))
30
+ model.to(device)
31
+ model.eval()
32
+
33
+ # Prediction function
34
+ def predict(image):
35
+ image = image.convert("RGB")
36
+ image_tensor = F.to_tensor(image).unsqueeze(0).to(device)
37
+
38
+ with torch.no_grad():
39
+ prediction = model(image_tensor)
40
+
41
+ draw = ImageDraw.Draw(image)
42
+ boxes = prediction[0]["boxes"].cpu().numpy()
43
+ labels = prediction[0]["labels"].cpu().numpy()
44
+ scores = prediction[0]["scores"].cpu().numpy()
45
+
46
+ for box, label, score in zip(boxes, labels, scores):
47
+ if score > 0.5:
48
+ x_min, y_min, x_max, y_max = box
49
+ class_name = COCO_CLASSES.get(label, "Unknown")
50
+ draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
51
+ draw.text((x_min, y_min), f"{class_name} ({score:.2f})", fill="red")
52
+
53
+ return image
54
+
55
+ # Gradio interface
56
+ gr.Interface(
57
+ fn=predict,
58
+ inputs=gr.Image(type="pil"),
59
+ outputs=gr.Image(type="pil"),
60
+ title="Salat Posture Detection",
61
+ description="Upload an image to detect salat postures (stand, sit, ruku, sijdah)."
62
+ ).launch()