jpterry commited on
Commit
4483788
·
1 Parent(s): 7c1bba2

registering

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -11,7 +11,7 @@ import sys
11
  # import timm
12
  from types import SimpleNamespace
13
  # from transformers import AutoModel, pipeline
14
- from transformers import AutoModelForImageClassification, AutoModel
15
  import torch
16
 
17
  sys.path.insert(1, "../")
@@ -254,6 +254,17 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
254
  EfficientNetConfig.model_type = hparams.model_type
255
 
256
  config.save_pretrained(save_directory=model_loading_name)
 
 
 
 
 
 
 
 
 
 
 
257
  # config = EfficientNetConfig.from_pretrained(model_loading_name)
258
 
259
  # model = EfficientNetPreTrained.from_pretrained(model_loading_name)
 
11
  # import timm
12
  from types import SimpleNamespace
13
  # from transformers import AutoModel, pipeline
14
+ from transformers import AutoModelForImageClassification, AutoModel, AutoConfig
15
  import torch
16
 
17
  sys.path.insert(1, "../")
 
254
  EfficientNetConfig.model_type = hparams.model_type
255
 
256
  config.save_pretrained(save_directory=model_loading_name)
257
+
258
+ model = EfficientNet(dropout=hparams.dropout,
259
+ num_channels=hparams.num_channels,
260
+ num_classes=hparams.num_classes,
261
+ size=hparams.size,
262
+ stochastic_depth_prob=hparams.stochastic_depth_prob,
263
+ width_mult=hparams.width_mult,
264
+ depth_mult=hparams.depth_mult,)
265
+
266
+ AutoConfig.register(model_loading_name, config)
267
+ AutoModel.register(config, model)
268
  # config = EfficientNetConfig.from_pretrained(model_loading_name)
269
 
270
  # model = EfficientNetPreTrained.from_pretrained(model_loading_name)