noeedc committed on
Commit 266f0ac · 1 Parent(s): 07b368a

Add initial implementation of Surgical Contaminant Classifier-Mix with model, config, and inference script

Files changed (5)
  1. README.md +69 -0
  2. classifier.py +86 -0
  3. config.json +9 -0
  4. example_inference.py +22 -0
  5. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,69 @@
# Surgical Contaminant Classifier-Mix

This repository contains a PyTorch-based image classifier for identifying visual contaminants in surgical footage. The model distinguishes between five classes: `blur`, `smoke`, `clear`, `fluid`, and `oob` (out-of-body). It uses a MobileNetV2 backbone via [timm](https://github.com/huggingface/pytorch-image-models) and is compatible with Hugging Face Transformers' `AutoModel` and `AutoConfig` using `trust_remote_code=True`.

The name **"classifier-mix"** refers to the training data source: a mix of DaVinci and Medtronic RARP surgical frames.

> Training log:
> `gs://noee/mobileNet/Medtronic_28-04-2025/Run_13h20_Finetune_lr0.0001_ReduceLROnPlateau/training.log`

## Files

- `classifier.py`: Model and config implementation.
- `config.json`: Hugging Face model configuration.
- `pytorch_model.bin`: Model weights.
- `sample_img.png`: Example image for inference.
- `example_inference.py`: Example script for running inference.

## Usage

### Installation

Install the required dependencies:

```sh
pip install torch torchvision timm transformers pillow
```

### Model Details

- **Backbone:** MobileNetV2 (`mobilenetv2_100`)
- **Classes:** blur, smoke, clear, fluid, oob
- **Input size:** 224x224 RGB images
- **Normalization:** mean=[0.6075, 0.4093, 0.3609], std=[0.2066, 0.2036, 0.1991] (see the preprocessing sketch below)
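These statistics only matter if you feed the timm backbone directly rather than going through the wrapper, which applies the resize and normalization itself. A minimal sketch of the equivalent standalone preprocessing (do not pre-normalize inputs you pass to the wrapper, or they will be normalized twice):

```python
from PIL import Image
from torchvision import transforms

# Preprocessing equivalent to what ClassifierWrapper applies internally.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # HWC [0, 255] -> CHW float [0, 1]
    transforms.Normalize(mean=[0.6075, 0.4093, 0.3609],
                         std=[0.2066, 0.2036, 0.1991]),
])

img = Image.open("sample_img.png").convert("RGB")
x = preprocess(img).unsqueeze(0)  # shape: (1, 3, 224, 224)
```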
### Inference Example

You can run the provided inference script to classify the sample image:

```python
# example_inference.py
from transformers import AutoModel
from PIL import Image

# Load model
model = AutoModel.from_pretrained(
    "vopeai/classifier-mix",
    trust_remote_code=True
)
model.eval()

# Load image (the model accepts a PIL image directly)
img = Image.open("sample_img.png").convert("RGB")

# Run inference
outputs = model(img)

print("Predicted class:", outputs[0]['label'])
print("Confidences:", outputs[0]['confidences'])
```

Or load the model in your own code as follows:

```python
from transformers import AutoModel

# Load model
model = AutoModel.from_pretrained("vopeai/classifier-mix", trust_remote_code=True)
```
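The wrapper also accepts `torch.Tensor` input, either a single `(3, H, W)` image or a `(N, 3, H, W)` batch. A minimal batch sketch with placeholder file names (pass raw `[0, 1]` tensors; the wrapper normalizes internally):

```python
import torch
from PIL import Image
from torchvision import transforms

# Placeholder frame paths for illustration.
paths = ["frame_000.png", "frame_001.png"]
to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
batch = torch.stack([to_tensor(Image.open(p).convert("RGB")) for p in paths])  # (2, 3, 224, 224)

with torch.no_grad():
    results = model(batch)  # one {"label": ..., "confidences": ...} dict per image

for r in results:
    print(r["label"], r["confidences"])
```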
For more details, see the code files in this repository.
classifier.py ADDED
@@ -0,0 +1,86 @@
import torch.nn as nn
import timm
from torchvision import transforms
import torch
from PIL import Image
from torch.nn.functional import softmax
from transformers import PretrainedConfig, PreTrainedModel

LABEL_MAP = ["blur", "smoke", "clear", "fluid", "oob"]


class ClassifierConfig(PretrainedConfig):
    model_type = "classifier"

    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.num_classes = num_classes


class ClassifierModel(nn.Module):
    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), pretrained=True):
        super().__init__()
        self.base_model = timm.create_model(model_name, pretrained=pretrained)
        num_features = self.base_model.classifier.in_features
        # Use Sequential to match the saved model structure
        self.base_model.classifier = nn.Sequential(
            nn.Linear(num_features, num_classes)
        )
        if "mobilenetv2" in model_name:
            self.target_layer = self.base_model.conv_head
        else:
            raise NotImplementedError(f"Grad-CAM target layer not defined for model: {model_name}")

    def forward(self, x):
        return self.base_model(x)


class ClassifierWrapper(PreTrainedModel):
    config_class = ClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierModel(
            model_name=config.model_name,
            num_classes=config.num_classes,
            pretrained=False  # Weights are loaded by from_pretrained
        )

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.6075, 0.4093, 0.3609], std=[0.2066, 0.2036, 0.1991])
        ])

    def forward(self, input):
        # Ensure input is a batched tensor
        if isinstance(input, Image.Image):
            x = transforms.ToTensor()(input).unsqueeze(0)  # Convert PIL Image to a (1, C, H, W) tensor
        elif isinstance(input, torch.Tensor):
            if input.dim() == 3:
                x = input.unsqueeze(0)  # Single tensor image
            elif input.dim() == 4:
                x = input  # Batch
            else:
                raise ValueError("Unsupported tensor shape.")
        else:
            raise TypeError(f"Unsupported input type: {type(input)}. Expected PIL.Image or torch.Tensor.")

        # Apply resize and normalization
        x = self.transform(x)

        # Forward pass through the model
        outputs = self.model(x)

        confs = softmax(outputs, dim=1)
        preds = torch.argmax(confs, dim=1)

        # Build one {label, confidences} dict per image in the batch
        results = []
        for i in range(len(preds)):
            label = LABEL_MAP[preds[i]]
            confidences = {}
            for j in range(len(LABEL_MAP)):
                # Use row i so batched inputs get their own per-image confidences
                confidences[LABEL_MAP[j]] = round(float(confs[i][j]), 3)

            results.append({"label": label, "confidences": confidences})
        return results
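For reference, a single-image call returns a list with one entry per input image; the values below are illustrative, not real model output:

```python
[{"label": "clear",
  "confidences": {"blur": 0.012, "smoke": 0.004, "clear": 0.953, "fluid": 0.021, "oob": 0.010}}]
```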
config.json ADDED
@@ -0,0 +1,9 @@
{
  "model_type": "classifier",
  "architectures": ["ClassifierWrapper"],
  "auto_map": {
    "AutoModel": "classifier.ClassifierWrapper",
    "AutoConfig": "classifier.ClassifierConfig"
  },
  "trust_remote_code": true
}
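The `auto_map` entries tell the Transformers Auto classes which custom classes in `classifier.py` to instantiate when `trust_remote_code=True`. A minimal sketch of the equivalent explicit loading, assuming a local checkout of the repository:

```python
from transformers import AutoConfig, AutoModel

# auto_map routes AutoConfig -> classifier.ClassifierConfig
config = AutoConfig.from_pretrained("./classifier-mix", trust_remote_code=True)

# auto_map routes AutoModel -> classifier.ClassifierWrapper
model = AutoModel.from_pretrained("./classifier-mix", config=config, trust_remote_code=True)
```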
example_inference.py ADDED
@@ -0,0 +1,22 @@
from transformers import AutoModel
from PIL import Image
import os

# Run relative to this script's directory so the paths below resolve
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Load model
model = AutoModel.from_pretrained(
    "./classifier-mix",  # or path to your model directory
    trust_remote_code=True
)
model.eval()

# Load image (preprocessing happens inside the wrapper)
img = Image.open("classifier-mix/sample_img.png").convert("RGB")

outputs = model(img)

print("Predicted class:", outputs[0]['label'])
print("Confidences:", outputs[0]['confidences'])
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c834b0f9b606b5f04d377015d3fab1976097639eed81b6f91360838409b1d24
size 9158475