traopia commited on
Commit
115de81
·
1 Parent(s): 06ceb44
Files changed (1) hide show
  1. src/visual_qa.py +1 -1
src/visual_qa.py CHANGED
@@ -67,7 +67,7 @@ def main_create_image_collection(df_emb, collection_name="clip_image_embeddings"
67
 
68
 
69
  model_name = "patrickjohncyh/fashion-clip"
70
- device = "mps"
71
  model = CLIPModel.from_pretrained(model_name).to(device)
72
  processor = CLIPProcessor.from_pretrained(model_name)
73
 
 
67
 
68
 
69
  model_name = "patrickjohncyh/fashion-clip"
70
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
71
  model = CLIPModel.from_pretrained(model_name).to(device)
72
  processor = CLIPProcessor.from_pretrained(model_name)
73