tonyassi commited on
Commit
d86aec5
·
verified ·
1 Parent(s): 8c881da

Upload 2 files

Browse files
Files changed (2) hide show
  1. FootDetection.py +71 -0
  2. requirements.txt +4 -0
FootDetection.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
5
+ from PIL import Image, ImageDraw
6
+ from torchvision.transforms import functional as F
7
+ from huggingface_hub import hf_hub_download
8
+
9
+
10
+ class FootDetection:
11
+ def __init__(self, device="cpu"):
12
+ self.device = torch.device(device)
13
+ self.checkpoint_dir = "checkpoints"
14
+ self.checkpoint_file = "fasterrcnn_foot.pth"
15
+ self.model = self._load_model()
16
+ self.last_detection = None
17
+
18
+ def _load_model(self):
19
+ local_path = os.path.join(self.checkpoint_dir, self.checkpoint_file)
20
+
21
+ # Download if not exists
22
+ if not os.path.exists(local_path):
23
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
24
+ print("Downloading model from Hugging Face...")
25
+ local_path = hf_hub_download(
26
+ repo_id="tonyassi/foot-detection",
27
+ filename=self.checkpoint_file,
28
+ local_dir=self.checkpoint_dir
29
+ )
30
+
31
+ # Load model
32
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
33
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
34
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
35
+ model.load_state_dict(torch.load(local_path, map_location=self.device))
36
+ model.to(self.device)
37
+ model.eval()
38
+ return model
39
+
40
+ def detect(self, image, threshold=0.1):
41
+ """Run foot detection on a PIL image"""
42
+ image_tensor = F.to_tensor(image).unsqueeze(0).to(self.device)
43
+ with torch.no_grad():
44
+ outputs = self.model(image_tensor)[0]
45
+
46
+ boxes = []
47
+ scores = []
48
+ for box, score in zip(outputs["boxes"], outputs["scores"]):
49
+ if score >= threshold:
50
+ boxes.append(box.tolist())
51
+ scores.append(score.item())
52
+
53
+ self.last_detection = {
54
+ "boxes": boxes,
55
+ "scores": scores
56
+ }
57
+ return self.last_detection
58
+
59
+ def draw_boxes(self, image):
60
+ """Draw the most recent detection boxes on a copy of the image"""
61
+ if self.last_detection is None:
62
+ raise ValueError("No detection results found. Run .detect(image) first.")
63
+ image_copy = image.copy()
64
+ draw = ImageDraw.Draw(image_copy)
65
+
66
+ for box, score in zip(self.last_detection["boxes"], self.last_detection["scores"]):
67
+ x0, y0, x1, y1 = box
68
+ draw.rectangle([x0, y0, x1, y1], outline="red", width=3)
69
+ draw.text((x0, y0), f"{score:.2f}", fill="red")
70
+
71
+ return image_copy
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ huggingface_hub