Eric P. Nusbaum commited on
Commit
c62ea31
·
1 Parent(s): 2d23a8b

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -13,11 +13,16 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
13
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
14
 
15
  # Load labels
16
- with open('tensorflow/labels.txt', 'r') as f:
 
 
 
17
  labels = f.read().splitlines()
18
 
19
  # Function to load the frozen TensorFlow graph
20
  def load_frozen_graph(pb_file_path):
 
 
21
  with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
22
  graph_def = tf.compat.v1.GraphDef()
23
  graph_def.ParseFromString(f.read())
@@ -27,16 +32,19 @@ def load_frozen_graph(pb_file_path):
27
  return graph
28
 
29
  # Load the TensorFlow model
30
- MODEL_DIR = 'tensorflow/'
31
  MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
32
  graph = load_frozen_graph(MODEL_PATH)
33
  sess = tf.compat.v1.Session(graph=graph)
34
 
35
  # Define tensor names based on your model's outputs
36
- input_tensor = graph.get_tensor_by_name('image_tensor:0')
37
- detected_boxes = graph.get_tensor_by_name('detected_boxes:0')
38
- detected_classes = graph.get_tensor_by_name('detected_classes:0')
39
- detected_scores = graph.get_tensor_by_name('detected_scores:0')
 
 
 
40
 
41
  # Define the target size based on your model's expected input
42
  TARGET_WIDTH = 320
@@ -53,7 +61,8 @@ def preprocess_image(image):
53
  image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
54
  image_np = np.array(image).astype(np.float32)
55
  image_np = image_np / 255.0 # Normalize to [0,1]
56
- image_np = image_np[:, :, (2, 1, 0)] # Convert RGB to BGR if required
 
57
  image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
58
  return image_np
59
 
@@ -128,6 +137,9 @@ def predict(image):
128
  return annotated_image
129
 
130
  except Exception as e:
 
 
 
131
  # Return an error image with the error message
132
  error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
133
  draw = ImageDraw.Draw(error_image)
 
13
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
14
 
15
  # Load labels
16
+ labels_path = os.path.join('tensorflow', 'labels.txt')
17
+ if not os.path.exists(labels_path):
18
+ raise FileNotFoundError(f"Labels file not found at {labels_path}")
19
+ with open(labels_path, 'r') as f:
20
  labels = f.read().splitlines()
21
 
22
  # Function to load the frozen TensorFlow graph
23
  def load_frozen_graph(pb_file_path):
24
+ if not os.path.exists(pb_file_path):
25
+ raise FileNotFoundError(f"Model file not found at {pb_file_path}")
26
  with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
27
  graph_def = tf.compat.v1.GraphDef()
28
  graph_def.ParseFromString(f.read())
 
32
  return graph
33
 
34
  # Load the TensorFlow model
35
+ MODEL_DIR = 'tensorflow'
36
  MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
37
  graph = load_frozen_graph(MODEL_PATH)
38
  sess = tf.compat.v1.Session(graph=graph)
39
 
40
  # Define tensor names based on your model's outputs
41
+ try:
42
+ input_tensor = graph.get_tensor_by_name('image_tensor:0')
43
+ detected_boxes = graph.get_tensor_by_name('detected_boxes:0')
44
+ detected_classes = graph.get_tensor_by_name('detected_classes:0')
45
+ detected_scores = graph.get_tensor_by_name('detected_scores:0')
46
+ except KeyError as e:
47
+ raise KeyError(f"Tensor not found in the graph: {e}")
48
 
49
  # Define the target size based on your model's expected input
50
  TARGET_WIDTH = 320
 
61
  image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
62
  image_np = np.array(image).astype(np.float32)
63
  image_np = image_np / 255.0 # Normalize to [0,1]
64
+ if image_np.shape[2] == 3:
65
+ image_np = image_np[:, :, (2, 1, 0)] # Convert RGB to BGR if required
66
  image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
67
  return image_np
68
 
 
137
  return annotated_image
138
 
139
  except Exception as e:
140
+ # Log the exception for debugging
141
+ print(f"Exception during prediction: {e}")
142
+
143
  # Return an error image with the error message
144
  error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
145
  draw = ImageDraw.Draw(error_image)