panchadip commited on
Commit
bbdec87
·
verified ·
1 Parent(s): e8da286

Create load_model.py

Browse files
Files changed (1) hide show
  1. models/load_model.py +21 -0
models/load_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ import torch
3
+
4
+ # Define a custom load function to force CPU loading
5
+ def custom_torch_load(f, *args, **kwargs):
6
+ if 'map_location' not in kwargs:
7
+ kwargs['map_location'] = torch.device("cpu")
8
+ return torch_load_backup(f, *args, **kwargs)
9
+
10
+ # Monkey patch torch.load inside joblib.load
11
+ torch_load_backup = torch.load # Backup the original function
12
+ torch.load = custom_torch_load # Override with CPU-only loading
13
+
14
+ # Load the model
15
+ topic_model = joblib.load("bertopic_model_max_compressed.joblib")
16
+
17
+ # Restore the original torch.load function
18
+ torch.load = torch_load_backup
19
+
20
+ # Verify loading success
21
+ print("Model loaded successfully on CPU")