jpterry commited on
Commit
e5c4282
·
1 Parent(s): 520d2ff

making model first

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -17,7 +17,7 @@ import torch
17
  sys.path.insert(1, "../")
18
  # from utils import model_utils, train_utils, data_utils, run_utils
19
  # from model_utils import jason_regnet_maker, jason_efficientnet_maker
20
- from model_utils.efficientnet_config import EfficientNetConfig, EfficientNetPreTrained
21
 
22
  model_path = 'chlab/'
23
  # model_path = './models/'
@@ -223,13 +223,22 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
223
  )
224
  EfficientNetConfig.model_type = "efficientnet_%s_planet_detection" % (hparams.num_channels)
225
 
226
- config.save_pretrained(save_directory=model_loading_name)
227
  # config = EfficientNetConfig.from_pretrained(model_loading_name)
228
 
229
  # model = EfficientNetPreTrained.from_pretrained(model_loading_name)
230
  # model = AutoModel.from_pretrained(model_loading_name, trust_remote_code=True)
231
 
232
- model = cached_download(hf_hub_url(model_loading_name, filename="pytorch_model.bin"))
 
 
 
 
 
 
 
 
 
233
 
234
  print(model)
235
 
 
17
  sys.path.insert(1, "../")
18
  # from utils import model_utils, train_utils, data_utils, run_utils
19
  # from model_utils import jason_regnet_maker, jason_efficientnet_maker
20
+ from model_utils.efficientnet_config import EfficientNetConfig, EfficientNetPreTrained, EfficientNet
21
 
22
  model_path = 'chlab/'
23
  # model_path = './models/'
 
223
  )
224
  EfficientNetConfig.model_type = "efficientnet_%s_planet_detection" % (hparams.num_channels)
225
 
226
+ # config.save_pretrained(save_directory=model_loading_name)
227
  # config = EfficientNetConfig.from_pretrained(model_loading_name)
228
 
229
  # model = EfficientNetPreTrained.from_pretrained(model_loading_name)
230
  # model = AutoModel.from_pretrained(model_loading_name, trust_remote_code=True)
231
 
232
+ model = EfficientNet(dropout=hparams.dropout,
233
+ num_channels=hparams.num_channels,
234
+ num_classes=hparams.num_classes,
235
+ size=hparams.size,
236
+ stochastic_depth_prob=hparams.stochastic_depth_prob,
237
+ width_mult=hparams.width_mult,
238
+ depth_mult=hparams.depth_mult,)
239
+
240
+ model_url = cached_download(hf_hub_url(model_loading_name, filename="pytorch_model.bin"))
241
+ model.load_state_dict(torch.load(model_url, map_location='cpu'))
242
 
243
  print(model)
244