Ivan Shelonik committed on
Commit
4acae3e
·
1 Parent(s): 0792f4b

add: remote_hub_from_pretrained

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. api_server.py +16 -7
  3. model.py +0 -46
Dockerfile CHANGED
@@ -16,7 +16,7 @@ EXPOSE 5000
16
 
17
  # Set the environment variable for Flask
18
  ENV FLASK_APP=api_server.py
19
- ENV TRANSFORMERS_CACHE=$(pwd)/.transformers_cache
20
 
21
  # Run the Flask application
22
  CMD ["flask", "run", "--host=0.0.0.0"]
 
16
 
17
  # Set the environment variable for Flask
18
  ENV FLASK_APP=api_server.py
19
+ ENV TRANSFORMERS_CACHE=transformers_cache
20
 
21
  # Run the Flask application
22
  CMD ["flask", "run", "--host=0.0.0.0"]
api_server.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import time
3
  import numpy as np
4
- from huggingface_hub import hf_hub_download
5
 
6
  # Disable tensorflow warnings
7
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -9,15 +8,25 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
9
  from tensorflow import keras
10
  from flask import Flask, jsonify, request
11
 
12
- LOCAL = False
 
 
 
 
 
 
 
13
 
14
  # Load the saved model into memory
15
- if LOCAL is True:
16
  model = keras.models.load_model('artifacts/models/mnist_model.h5')
17
- else:
18
- REPO_ID = "1vash/mnist_demo_model"
19
- FILENAME = "saved_model.pb"
20
- model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename=FILENAME))
 
 
 
21
 
22
  # Initialize the Flask application
23
  app = Flask(__name__)
 
1
  import os
2
  import time
3
  import numpy as np
 
4
 
5
  # Disable tensorflow warnings
6
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
8
  from tensorflow import keras
9
  from flask import Flask, jsonify, request
10
 
11
+ load_type = ''
12
+ """
13
+ local,
14
+ remote_hub_download - /cache error even when TRANSFORMERS_CACHE is set to the root folder
15
+ remote_hub_from_pretrained
16
+ """
17
+
18
+ REPO_ID = "1vash/mnist_demo_model"
19
 
20
  # Load the saved model into memory
21
+ if load_type == 'local':
22
  model = keras.models.load_model('artifacts/models/mnist_model.h5')
23
+ elif load_type == 'remote_hub_download':
24
+ from huggingface_hub import hf_hub_download
25
+ model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename="saved_model.pb"))
26
+ elif load_type == 'remote_hub_from_pretrained':
27
+ from huggingface_hub import from_pretrained_keras
28
+ model = from_pretrained_keras(REPO_ID)
29
+
30
 
31
  # Initialize the Flask application
32
  app = Flask(__name__)
model.py DELETED
@@ -1,46 +0,0 @@
1
- import os
2
- import random
3
- import numpy as np
4
-
5
- # disable tensorflow warnings
6
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
7
-
8
- import tensorflow as tf
9
- from tensorflow import keras
10
- from keras.datasets import mnist
11
-
12
- # Set the random seed for reproducibility, remember these lines :)
13
- SEED = 42
14
- random.seed(SEED)
15
- np.random.seed(SEED)
16
- tf.random.set_seed(SEED)
17
-
18
- # Load the dataset from keras.datasets (so no one needs to download it manually from any source)
19
- (x_train, y_train), (x_test, y_test) = mnist.load_data()
20
-
21
- # Preprocess the dataset
22
- x_train = x_train.astype('float32') / 255.0
23
- x_test = x_test.astype('float32') / 255.0
24
-
25
- # Define the model architecture
26
- model = keras.Sequential([
27
- keras.layers.Flatten(input_shape=(28, 28)),
28
- keras.layers.Dense(128, activation='relu'),
29
- keras.layers.Dense(10, activation='softmax')
30
- ])
31
-
32
- # Compile and train the model
33
- # target as one-hot for categorical_crossentropy -> [0,0,0,1,0,0,0,0,0,0]
34
- # target as an integer for sparse_categorical_crossentropy -> 3
35
- model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
36
-
37
- # The 4th epoch overfits; the 3rd is okay
38
- model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=4, shuffle=True, batch_size=32)
39
-
40
- # Evaluate the model
41
- print('\n')
42
- _, test_accuracy = model.evaluate(x_test, y_test)
43
- print('Test accuracy:', test_accuracy)
44
-
45
- # Save the model
46
- model.save('artifacts/models/mnist_model.h5')