Ivan Shelonik commited on
Commit
e720316
·
1 Parent(s): 75241f1

upd: load types

Browse files
Files changed (1) hide show
  1. api_server.py +13 -5
api_server.py CHANGED
@@ -13,13 +13,16 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13
  from tensorflow import keras
14
  from flask import Flask, jsonify, request
15
 
16
- load_type = 'remote_hub_from_pretrained'
17
  """
18
- local,
19
- remote_hub_download - /cache error even using TRANSFORMERS_CACHE to root folder
20
- remote_hub_from_pretrained
 
 
21
  """
22
 
 
23
  REPO_ID = "1vash/mnist_demo_model"
24
 
25
  # Load the saved model into memory
@@ -29,9 +32,14 @@ elif load_type == 'remote_hub_download':
29
  from huggingface_hub import hf_hub_download
30
  model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename="saved_model.pb"))
31
  elif load_type == 'remote_hub_from_pretrained':
 
32
  from huggingface_hub import from_pretrained_keras
33
  model = from_pretrained_keras(REPO_ID, cache_dir='./artifacts/')
34
-
 
 
 
 
35
 
36
  # Initialize the Flask application
37
  app = Flask(__name__)
 
13
  from tensorflow import keras
14
  from flask import Flask, jsonify, request
15
 
16
+ load_type = ''
17
  """
18
+ local;
19
+ remote_hub_download; - /cache error even using TRANSFORMERS_CACHE & cache_dir to local folder
20
+ remote_hub_from_pretrained; - /cache error even using TRANSFORMERS_CACHE & cache_dir to local folder
21
+ remote_hub_pipeline; - needs config.json and this is not easy to grasp how to do it with custom models
22
+ https://discuss.huggingface.co/t/how-to-create-a-config-json-after-saving-a-model/10459/4
23
  """
24
 
25
+
26
  REPO_ID = "1vash/mnist_demo_model"
27
 
28
  # Load the saved model into memory
 
32
  from huggingface_hub import hf_hub_download
33
  model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename="saved_model.pb"))
34
  elif load_type == 'remote_hub_from_pretrained':
35
+ # https://huggingface.co/docs/hub/keras
36
  from huggingface_hub import from_pretrained_keras
37
  model = from_pretrained_keras(REPO_ID, cache_dir='./artifacts/')
38
+ elif load_type == 'remote_hub_pipeline':
39
+ from transformers import pipeline
40
+ classifier = pipeline("image-classification", model=REPO_ID)
41
+ else:
42
+ pass
43
 
44
  # Initialize the Flask application
45
  app = Flask(__name__)