Commit 266f0ac
noeedc committed
Parent(s): 07b368a
Add initial implementation of Surgical Contaminant Classifier-Mix with model, config, and inference script
Files changed:
- README.md +69 -0
- classifier.py +86 -0
- config.json +9 -0
- example_inference.py +22 -0
- pytorch_model.bin +3 -0
README.md
ADDED
@@ -0,0 +1,69 @@
# Surgical Contaminant Classifier-Mix

This repository contains a PyTorch-based image classifier for identifying visual contaminants in surgical footage. The model distinguishes between five classes: `blur`, `smoke`, `clear`, `fluid`, and `oob` (out-of-body). It uses a MobileNetV2 backbone via [timm](https://github.com/huggingface/pytorch-image-models) and is compatible with Hugging Face Transformers' `AutoModel` and `AutoConfig` using `trust_remote_code=True`.

The name **"classifier-mix"** refers to the training data source: a mix of DaVinci and Medtronic RARP surgical frames.

> Training log:
> `gs://noee/mobileNet/Medtronic_28-04-2025/Run_13h20_Finetune_lr0.0001_ReduceLROnPlateau/training.log`

## Files

- `classifier.py`: Model and config implementation.
- `config.json`: Hugging Face model configuration.
- `pytorch_model.bin`: Model weights.
- `sample_img.png`: Example image for inference.
- `example_inference.py`: Example script for running inference.

## Usage

### Installation

Install the required dependencies:

```sh
pip install torch torchvision timm transformers pillow
```

### Model Details

- **Backbone:** MobileNetV2 (`mobilenetv2_100`)
- **Classes:** blur, smoke, clear, fluid, oob
- **Input size:** 224x224 RGB images
- **Normalization:** mean=[0.6075, 0.4093, 0.3609], std=[0.2066, 0.2036, 0.1991]

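For reference, the preprocessing implied by the values above can be reproduced with torchvision. This minimal sketch is only needed if you run the backbone outside the provided wrapper; the wrapper normalizes internally, so do not feed it an already-normalized tensor.

```python
# Sketch only: ClassifierWrapper already applies equivalent preprocessing to PIL inputs.
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),                       # model input size
    transforms.ToTensor(),                               # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.6075, 0.4093, 0.3609],  # dataset statistics listed above
                         std=[0.2066, 0.2036, 0.1991]),
])

img = Image.open("sample_img.png").convert("RGB")
batch = preprocess(img).unsqueeze(0)  # shape: (1, 3, 224, 224)
```
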
### Inference Example

You can run the provided inference script to classify the sample image:

```python
# example_inference.py
from transformers import AutoModel
from PIL import Image

# Load model
model = AutoModel.from_pretrained(
    "vopeai/classifier-mix",
    trust_remote_code=True
)
model.eval()

# Load and preprocess image
img = Image.open("sample_img.png").convert("RGB")

# Run inference
outputs = model(img)

print("Predicted class:", outputs[0]['label'])
print("Confidences:", outputs[0]['confidences'])
```

Or use the model in your own code by loading it as follows:

```python
from transformers import AutoModel

# Load model
model = AutoModel.from_pretrained("vopeai/classifier-mix", trust_remote_code=True)
```

For more details, see the code files in this repository.
classifier.py
ADDED
@@ -0,0 +1,86 @@
import torch.nn as nn
import timm
from torchvision import transforms
import torch
from PIL import Image
from torch.nn.functional import softmax
from transformers import PretrainedConfig, PreTrainedModel

LABEL_MAP = ["blur", "smoke", "clear", "fluid", "oob"]


class ClassifierConfig(PretrainedConfig):
    model_type = "classifier"

    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.num_classes = num_classes


class ClassifierModel(nn.Module):
    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), pretrained=True):
        super().__init__()
        self.base_model = timm.create_model(model_name, pretrained=pretrained)
        num_features = self.base_model.classifier.in_features
        # Use Sequential to match saved model structure
        self.base_model.classifier = nn.Sequential(
            nn.Linear(num_features, num_classes)
        )
        if "mobilenetv2" in model_name:
            self.target_layer = self.base_model.conv_head
        else:
            raise NotImplementedError(f"Grad-CAM target layer not defined for model: {model_name}")

    def forward(self, x):
        return self.base_model(x)


class ClassifierWrapper(PreTrainedModel):
    config_class = ClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierModel(
            model_name=config.model_name,
            num_classes=config.num_classes,
            pretrained=False  # Weights are loaded by from_pretrained
        )

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.6075, 0.4093, 0.3609], std=[0.2066, 0.2036, 0.1991])
        ])

    def forward(self, input):
        # Ensure input is a tensor
        if isinstance(input, Image.Image):
            x = transforms.ToTensor()(input).unsqueeze(0)  # Convert PIL Image to tensor
        elif isinstance(input, torch.Tensor):
            if input.dim() == 3:
                x = input.unsqueeze(0)  # Single tensor image
            elif input.dim() == 4:
                x = input  # Batch
            else:
                raise ValueError("Unsupported tensor shape.")
        else:
            raise TypeError(f"Unsupported input type: {type(input)}. Expected PIL.Image or torch.Tensor.")

        # Apply transformations
        x = self.transform(x)

        # Forward pass through the model
        outputs = self.model(x)

        confs = softmax(outputs, dim=1)
        preds = torch.argmax(confs, dim=1)

        results = []
        for i in range(len(preds)):
            label = LABEL_MAP[preds[i]]
            confidences = {}
            for j in range(len(LABEL_MAP)):
                confidences[LABEL_MAP[j]] = round(float(confs[i][j]), 3)  # per-sample confidence for class j

            results.append({"label": label, "confidences": confidences})
        return results
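A minimal usage sketch for `ClassifierWrapper` on a tensor batch (assuming `classifier.py` is importable locally; weights here are untrained, since real weights are loaded via `from_pretrained`):

```python
# Sketch: build the wrapper from its default config and run a dummy batch.
# Assumes classifier.py is on the Python path; weights are untrained here.
import torch
from classifier import ClassifierConfig, ClassifierWrapper

model = ClassifierWrapper(ClassifierConfig())
model.eval()

dummy_batch = torch.rand(2, 3, 224, 224)  # two random RGB images in [0, 1]
with torch.no_grad():
    results = model(dummy_batch)          # list of {"label", "confidences"} dicts

for r in results:
    print(r["label"], r["confidences"])
```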
config.json
ADDED
@@ -0,0 +1,9 @@
{
  "model_type": "classifier",
  "architectures": ["ClassifierWrapper"],
  "auto_map": {
    "AutoModel": "classifier.ClassifierWrapper",
    "AutoConfig": "classifier.ClassifierConfig"
  },
  "trust_remote_code": true
}
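The `auto_map` entries are what route the generic `AutoConfig`/`AutoModel` classes to the custom code in `classifier.py` when `trust_remote_code=True` is passed; a brief sketch of that flow, using the repo id from the README:

```python
# Sketch: AutoConfig/AutoModel follow config.json's auto_map to the custom classes.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("vopeai/classifier-mix", trust_remote_code=True)
print(config.model_name, config.num_classes)  # expected: mobilenetv2_100 5

model = AutoModel.from_pretrained("vopeai/classifier-mix", trust_remote_code=True)
model.eval()
```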
example_inference.py
ADDED
@@ -0,0 +1,22 @@
import torch
from transformers import AutoModel, AutoConfig
from torchvision import transforms
from PIL import Image
import os

os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Load model and config
model = AutoModel.from_pretrained(
    "./classifier-mix",  # or path to your model directory
    trust_remote_code=True
)
model.eval()

# Load and preprocess image
img = Image.open("classifier-mix/sample_img.png").convert("RGB")

outputs = model(img)

print("Predicted class:", outputs[0]['label'])
print("Confidences:", outputs[0]['confidences'])
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c834b0f9b606b5f04d377015d3fab1976097639eed81b6f91360838409b1d24
size 9158475