ning8429 commited on
Commit
c2fcab3
·
verified ·
1 Parent(s): 82e2346

Update api_server.py

Browse files
Files changed (1) hide show
  1. api_server.py +31 -32
api_server.py CHANGED
@@ -13,13 +13,14 @@ from tensorflow import keras
13
  from flask import Flask, jsonify, request, render_template, send_file
14
  import torch
15
  from collections import Counter
 
16
 
17
  # Disable tensorflow warnings
18
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
19
 
20
  load_type = 'local'
21
 
22
- MODEL_NAME = "yolo11_detect_best_241024_1.pt"
23
  MODEL_DIR = "./artifacts/models"
24
  YOLO_DIR = "./artifacts/yolo"
25
  #REPO_ID = "1vash/mnist_demo_model"
@@ -27,7 +28,7 @@ YOLO_DIR = "./artifacts/yolo"
27
  # Load the saved YOLO model into memory
28
  if load_type == 'local':
29
  # 本地模型路徑
30
- model_path = f'{MODEL_DIR}/{MODEL_NAME}'
31
  if not os.path.exists(model_path):
32
  raise FileNotFoundError(f"Model file not found at {model_path}")
33
 
@@ -37,7 +38,7 @@ elif load_type == 'remote_hub_download':
37
  from huggingface_hub import hf_hub_download
38
 
39
  # 從 Hugging Face Hub 下載模型
40
- model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_NAME)
41
  model = torch.load(model_path)
42
  #model.eval()
43
  elif load_type == 'remote_hub_from_pretrained':
@@ -45,7 +46,7 @@ elif load_type == 'remote_hub_from_pretrained':
45
  os.environ['TRANSFORMERS_CACHE'] = str(Path(MODEL_DIR).absolute())
46
  from huggingface_hub import from_pretrained
47
 
48
- model = from_pretrained(REPO_ID, filename=MODEL_NAME, cache_dir=MODEL_DIR)
49
  #model.eval()
50
  else:
51
  raise AssertionError('No load type is specified!')
@@ -118,56 +119,54 @@ def predict():
118
 
119
  encoded_images=[]
120
  element_list =[]
 
121
 
122
  for element, count in element_counts.items():
123
 
124
- output_path = f"{YOLO_DIR}/{message_id}/{element}"
125
- output_file = get_jpg_files(output_path)
126
 
127
  element_list.append(element)
128
 
129
- for output_img in output_file: # 取得每張圖的路徑
130
- encoded_images.append(image_to_base64(output_img))
 
131
 
132
  # if element_counts[element] > 1: #某隻角色的數量>1
133
- # output_path = f"{YOLO_DIR}/{message_id}/{element}"
134
- # output_file = get_jpg_files(output_path)
135
 
136
- # for output_img in output_file: # 取得每張圖的路徑
137
- # encoded_images.append(image_to_base64(output_img))
138
 
139
  # else : #某隻角色的數量=1
140
- # output_path = f"{YOLO_DIR}/{message_id}/{element}/im.jpg.jpg"
141
- # encoded_images.append(image_to_base64(output_path))
142
 
143
  # 建立回應資料
144
  response_data = {
145
  'message_id': message_id,
146
- 'images': encoded_images,
147
  'description': element_list
 
 
 
 
 
 
 
148
  }
 
 
 
 
 
149
 
150
  return jsonify(response_data)
151
 
152
  # for label_name in label_names:
153
- # output_file=f"{YOLO_DIR}/{message_id}/{label_name}/im.jpg.jpg"
154
  # # 將圖片轉換為 base64 編碼
155
- # encoded_images.append(image_to_base64(output_file))
156
-
157
-
158
- # # 渲染推理結果到圖像
159
- # img_with_boxes = results[0].plot() # 使用 results[0],假設只有一張圖像作推理
160
-
161
- # # 將 numpy array 轉換為 PIL Image
162
- # img = Image.fromarray(img_with_boxes)
163
-
164
- # # 儲存圖片到內存緩衝區
165
- # img_io = io.BytesIO()
166
- # img.save(img_io, 'PNG')
167
- # img_io.seek(0)
168
-
169
- # # 返回處理後的圖像
170
- # return send_file(img_io, mimetype='image/png')
171
 
172
 
173
  # # dictionary is not a JSON: https://www.quora.com/What-is-the-difference-between-JSON-and-a-dictionary
 
13
  from flask import Flask, jsonify, request, render_template, send_file
14
  import torch
15
  from collections import Counter
16
+ from clip import clip_result
17
 
18
  # Disable tensorflow warnings
19
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
20
 
21
  load_type = 'local'
22
 
23
+ MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
24
  MODEL_DIR = "./artifacts/models"
25
  YOLO_DIR = "./artifacts/yolo"
26
  #REPO_ID = "1vash/mnist_demo_model"
 
28
  # Load the saved YOLO model into memory
29
  if load_type == 'local':
30
  # 本地模型路徑
31
+ model_path = f'{MODEL_DIR}/{MODEL_YOLO}'
32
  if not os.path.exists(model_path):
33
  raise FileNotFoundError(f"Model file not found at {model_path}")
34
 
 
38
  from huggingface_hub import hf_hub_download
39
 
40
  # 從 Hugging Face Hub 下載模型
41
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_YOLO)
42
  model = torch.load(model_path)
43
  #model.eval()
44
  elif load_type == 'remote_hub_from_pretrained':
 
46
  os.environ['TRANSFORMERS_CACHE'] = str(Path(MODEL_DIR).absolute())
47
  from huggingface_hub import from_pretrained
48
 
49
+ model = from_pretrained(REPO_ID, filename=MODEL_YOLO, cache_dir=MODEL_DIR)
50
  #model.eval()
51
  else:
52
  raise AssertionError('No load type is specified!')
 
119
 
120
  encoded_images=[]
121
  element_list =[]
122
+ top_k_words =[]
123
 
124
  for element, count in element_counts.items():
125
 
126
+ yolo_path = f"{YOLO_DIR}/{message_id}/{element}"
127
+ yolo_file = get_jpg_files(yolo_path)
128
 
129
  element_list.append(element)
130
 
131
+ for yolo_img in yolo_file: # 每張切圖yolo_img
132
+ top_k_words.append(clip_result(yolo_img)) # CLIP預測3個結果(top_k_words)
133
+ encoded_images.append(image_to_base64(yolo_img))
134
 
135
  # if element_counts[element] > 1: #某隻角色的數量>1
136
+ # yolo_path = f"{YOLO_DIR}/{message_id}/{element}"
137
+ # yolo_file = get_jpg_files(yolo_path)
138
 
139
+ # for yolo_img in yolo_file: # 取得每張圖的路徑
140
+ # encoded_images.append(image_to_base64(yolo_img))
141
 
142
  # else : #某隻角色的數量=1
143
+ # yolo_path = f"{YOLO_DIR}/{message_id}/{element}/im.jpg.jpg"
144
+ # encoded_images.append(image_to_base64(yolo_path))
145
 
146
  # 建立回應資料
147
  response_data = {
148
  'message_id': message_id,
 
149
  'description': element_list
150
+ 'images': [
151
+ {
152
+ 'encoded_image': encoded_image,
153
+ 'description_list': top_k_words
154
+ }
155
+ for encoded_image, elements in zip(encoded_images, element_list)
156
+ ]
157
  }
158
+ # response_data = {
159
+ # 'message_id': message_id,
160
+ # 'images': encoded_images,
161
+ # 'description': element_list
162
+ # }
163
 
164
  return jsonify(response_data)
165
 
166
  # for label_name in label_names:
167
+ # yolo_file=f"{YOLO_DIR}/{message_id}/{label_name}/im.jpg.jpg"
168
  # # 將圖片轉換為 base64 編碼
169
+ # encoded_images.append(image_to_base64(yolo_file))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  # # dictionary is not a JSON: https://www.quora.com/What-is-the-difference-between-JSON-and-a-dictionary