rohithk-03 commited on
Commit
1e646da
·
1 Parent(s): b500916

update return msg

Browse files
Files changed (1) hide show
  1. model.py +2 -1
model.py CHANGED
@@ -78,6 +78,7 @@ class HybridCNNViT(nn.Module):
78
 
79
 
80
  def load_and_pad_single_image(image_path, img_size=(224, 224)):
 
81
  img = cv2.imread(image_path)
82
  if img is None:
83
  raise ValueError(f"Could not read image: {image_path}")
@@ -145,7 +146,7 @@ def check_file(image_path):
145
  return checkpoint
146
 
147
  model = HybridCNNViT(3, 2)
148
- checkpoint = torch.load('checkpoint32.pth')
149
  checkpoint = remove_module_from_checkpoint(checkpoint)
150
  model.load_state_dict(checkpoint['model_state_dict'])
151
  # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 
78
 
79
 
80
  def load_and_pad_single_image(image_path, img_size=(224, 224)):
81
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
  img = cv2.imread(image_path)
83
  if img is None:
84
  raise ValueError(f"Could not read image: {image_path}")
 
146
  return checkpoint
147
 
148
  model = HybridCNNViT(3, 2)
149
+ checkpoint = torch.load("/home/user/app/checkpoint32.pth")
150
  checkpoint = remove_module_from_checkpoint(checkpoint)
151
  model.load_state_dict(checkpoint['model_state_dict'])
152
  # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])