Spaces:
Runtime error
Runtime error
rohithk-03
commited on
Commit
·
c81bbd5
1
Parent(s):
0eb824f
update model code
Browse files
model.py
CHANGED
@@ -137,6 +137,11 @@ def check_file(image_path):
|
|
137 |
outputs = model(images)
|
138 |
_, predicted = torch.max(outputs.data, 1)
|
139 |
output = predicted
|
|
|
|
|
|
|
|
|
|
|
140 |
return predicted
|
141 |
|
142 |
def remove_module_from_checkpoint(checkpoint):
|
@@ -157,4 +162,4 @@ def check_file(image_path):
|
|
157 |
model = nn.DataParallel(model)
|
158 |
output = test_model(model, test_loader, device)
|
159 |
print(output)
|
160 |
-
return "
|
|
|
137 |
outputs = model(images)
|
138 |
_, predicted = torch.max(outputs.data, 1)
|
139 |
output = predicted
|
140 |
+
# Convert logits to probabilities
|
141 |
+
probabilities = F.softmax(outputs, dim=1)
|
142 |
+
# Get confidence score and prediction
|
143 |
+
confidence, d = torch.max(probabilities, 1)
|
144 |
+
print(confidence)
|
145 |
return predicted
|
146 |
|
147 |
def remove_module_from_checkpoint(checkpoint):
|
|
|
162 |
model = nn.DataParallel(model)
|
163 |
output = test_model(model, test_loader, device)
|
164 |
print(output)
|
165 |
+
return "control" if output.item() == 0 else "ms"
|