jpterry commited on
Commit
513f826
·
1 Parent(s): a5000fc

going back to automodel

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. model_utils/efficientnet_config.py +1 -1
app.py CHANGED
@@ -10,7 +10,7 @@ import sys
10
  # import timm
11
  from types import SimpleNamespace
12
  # from transformers import AutoModel, pipeline
13
- from transformers import AutoModelForImageClassification
14
  import torch
15
 
16
  sys.path.insert(1, "../")
@@ -220,12 +220,13 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
220
  width_mult=hparams.width_mult,
221
  depth_mult=hparams.depth_mult,
222
  )
 
223
 
224
  config.save_pretrained(save_directory=model_loading_name)
225
  # config = EfficientNetConfig.from_pretrained(model_loading_name)
226
 
227
- model = EfficientNetPreTrained.from_pretrained(model_loading_name)
228
- # model = AutoModelForImageClassification.from_pretrained(model_loading_name, trust_remote_code=True)
229
 
230
  # model = EfficientNetPreTrained(config)
231
  # config.register_for_auto_class()
 
10
  # import timm
11
  from types import SimpleNamespace
12
  # from transformers import AutoModel, pipeline
13
+ from transformers import AutoModelForImageClassification, AutoModel
14
  import torch
15
 
16
  sys.path.insert(1, "../")
 
220
  width_mult=hparams.width_mult,
221
  depth_mult=hparams.depth_mult,
222
  )
223
+ EfficientNetConfig.model_type = "efficientnet_%s_planet_detection" % (hparams.num_channels)
224
 
225
  config.save_pretrained(save_directory=model_loading_name)
226
  # config = EfficientNetConfig.from_pretrained(model_loading_name)
227
 
228
+ # model = EfficientNetPreTrained.from_pretrained(model_loading_name)
229
+ model = AutoModel.from_pretrained(model_loading_name, trust_remote_code=True)
230
 
231
  # model = EfficientNetPreTrained(config)
232
  # config.register_for_auto_class()
model_utils/efficientnet_config.py CHANGED
@@ -242,7 +242,7 @@ class FusedMBConv(nn.Module):
242
 
243
  class EfficientNetConfig(PretrainedConfig):
244
 
245
- model_type = "efficientnet_61_planet_detection"
246
 
247
  def __init__(
248
  self,
 
242
 
243
  class EfficientNetConfig(PretrainedConfig):
244
 
245
+ model_type = "efficientnet"
246
 
247
  def __init__(
248
  self,