rohithk-03 commited on
Commit
c81bbd5
·
1 Parent(s): 0eb824f

update model code

Browse files
Files changed (1) hide show
  1. model.py +6 -1
model.py CHANGED
@@ -137,6 +137,11 @@ def check_file(image_path):
137
  outputs = model(images)
138
  _, predicted = torch.max(outputs.data, 1)
139
  output = predicted
 
 
 
 
 
140
  return predicted
141
 
142
  def remove_module_from_checkpoint(checkpoint):
@@ -157,4 +162,4 @@ def check_file(image_path):
157
  model = nn.DataParallel(model)
158
  output = test_model(model, test_loader, device)
159
  print(output)
160
- return "Control" if output.item() == 0 else "Axial"
 
137
  outputs = model(images)
138
  _, predicted = torch.max(outputs.data, 1)
139
  output = predicted
140
+ # Convert logits to probabilities
141
+ probabilities = F.softmax(outputs, dim=1)
142
+ # Get confidence score and prediction
143
+ confidence, d = torch.max(probabilities, 1)
144
+ print(confidence)
145
  return predicted
146
 
147
  def remove_module_from_checkpoint(checkpoint):
 
162
  model = nn.DataParallel(model)
163
  output = test_model(model, test_loader, device)
164
  print(output)
165
+ return "control" if output.item() == 0 else "ms"