ning8429 commited on
Commit
3762ea2
·
verified ·
1 Parent(s): e2292d4

Update api_server.py

Browse files
Files changed (1) hide show
  1. api_server.py +47 -19
api_server.py CHANGED
@@ -15,6 +15,7 @@ import torch
15
  from collections import Counter
16
  import psutil
17
  from gradio_client import Client, handle_file
 
18
 
19
  # Disable tensorflow warnings
20
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -24,7 +25,8 @@ load_type = 'local'
24
  MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
25
  MODEL_DIR = "./artifacts/models"
26
  YOLO_DIR = "./artifacts/yolo"
27
- GRADIO_URL = "https://fd39e54bcb191a37bf.gradio.live/"
 
28
 
29
 
30
  # Load the saved YOLO model into memory
@@ -36,7 +38,7 @@ if load_type == 'local':
36
 
37
  model = YOLO(model_path)
38
 
39
- print("***** 1. LOAD YOLO MODEL DONE *****")
40
  #model.eval() # 設定模型為推理模式
41
  elif load_type == 'remote_hub_download':
42
  from huggingface_hub import hf_hub_download
@@ -62,6 +64,18 @@ def image_to_base64(image_path):
62
  encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
63
  return encoded_string
64
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # 抓取指定路徑下的所有 JPG 檔案
67
  def get_jpg_files(path):
@@ -117,11 +131,11 @@ def predict():
117
  except Exception as e:
118
  return jsonify({'error': str(e)}), 400
119
 
120
- print("***** 2. Start YOLO predict *****")
121
  # Make a prediction using YOLO
122
  results = model(image_data)
123
- print ("===== YOLO predict result:",results,"=====")
124
- print("***** YOLO predict DONE *****")
125
 
126
  check_memory_usage()
127
 
@@ -145,7 +159,7 @@ def predict():
145
  labels = result.boxes.cls # Get predicted label IDs
146
  label_names = [model.names[int(label)] for label in labels] # Convert to names
147
 
148
- print(f"====== 3. YOLO label_names: {label_names}======")
149
 
150
  element_counts = Counter(label_names)
151
 
@@ -154,15 +168,15 @@ def predict():
154
  yolo_path = f"{YOLO_DIR}/{message_id}/{element}"
155
  yolo_file = get_jpg_files(yolo_path)
156
 
157
- print(f"***** 處理:{yolo_path} *****")
158
 
159
  if len(yolo_file) == 0:
160
- print(f"警告:{element} 沒有找到相關的 JPG 檔案")
161
  continue
162
 
163
  for yolo_img in yolo_file: # 每張切圖yolo_img
164
- print("***** 4. START CLIP *****")
165
- client = Client(GRADIO_URL)
166
  clip_result = client.predict(
167
  image=handle_file(yolo_img),
168
  top_k=3,
@@ -171,7 +185,7 @@ def predict():
171
  top_k_words.append(clip_result) # CLIP預測3個結果(top_k_words)
172
  encoded_images.append(image_to_base64(yolo_img))
173
  element_list.append(element)
174
- print(f"===== CLIP result:{top_k_words} =====\n")
175
 
176
  # 建立回應資料
177
  response_data = {
@@ -193,14 +207,28 @@ def predict():
193
 
194
 
195
  # API route for health check
196
- @app.route('/health', methods=['GET'])
197
- def health():
198
- """
199
- Health check API to ensure the application is running.
200
- Returns "OK" if the application is healthy.
201
- Demo Usage: "curl http://localhost:5000/health" or using alias "curl http://127.0.0.1:5000/health"
202
- """
203
- return 'OK'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
 
206
  # API route for version
 
15
  from collections import Counter
16
  import psutil
17
  from gradio_client import Client, handle_file
18
+ from io import BytesIO
19
 
20
  # Disable tensorflow warnings
21
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
25
  MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
26
  MODEL_DIR = "./artifacts/models"
27
  YOLO_DIR = "./artifacts/yolo"
28
+ IMG2TEXT_URL = "https://fd39e54bcb191a37bf.gradio.live/"
29
+ TEXT2IMG_URL = "https://91698ded8ba92d3bb0.gradio.live/"
30
 
31
 
32
  # Load the saved YOLO model into memory
 
38
 
39
  model = YOLO(model_path)
40
 
41
+ print("***** FLASK API---LOAD YOLO MODEL DONE *****")
42
  #model.eval() # 設定模型為推理模式
43
  elif load_type == 'remote_hub_download':
44
  from huggingface_hub import hf_hub_download
 
64
  encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
65
  return encoded_string
66
 
67
+ def convert_webp_to_base64(webp_path):
68
+ # 開啟 .webp 圖片檔
69
+ with Image.open(webp_path) as img:
70
+ # 將圖片存到 BytesIO 物件中,以便轉換為 base64
71
+ buffered = BytesIO()
72
+ img.save(buffered, format="WEBP")
73
+
74
+ # 取得 base64 編碼的字串
75
+ img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
76
+
77
+ return img_base64
78
+
79
 
80
  # 抓取指定路徑下的所有 JPG 檔案
81
  def get_jpg_files(path):
 
131
  except Exception as e:
132
  return jsonify({'error': str(e)}), 400
133
 
134
+ print("***** FLASK API---/predict Start YOLO predict *****")
135
  # Make a prediction using YOLO
136
  results = model(image_data)
137
+ print ("===== FLASK API---/predict YOLO predict result:",results,"=====")
138
+ print("***** FLASK API---/predict YOLO predict DONE *****")
139
 
140
  check_memory_usage()
141
 
 
159
  labels = result.boxes.cls # Get predicted label IDs
160
  label_names = [model.names[int(label)] for label in labels] # Convert to names
161
 
162
+ print(f"====== FLASK API---/predict 3. YOLO label_names: {label_names}======")
163
 
164
  element_counts = Counter(label_names)
165
 
 
168
  yolo_path = f"{YOLO_DIR}/{message_id}/{element}"
169
  yolo_file = get_jpg_files(yolo_path)
170
 
171
+ print(f"***** FLASK API---/predict 處理:{yolo_path} *****")
172
 
173
  if len(yolo_file) == 0:
174
+ print(f" FLASK API---/predict 警告:{element} 沒有找到相關的 JPG 檔案")
175
  continue
176
 
177
  for yolo_img in yolo_file: # 每張切圖yolo_img
178
+ print("***** FLASK API---/predict 4. START CLIP *****")
179
+ client = Client(IMG2TEXT_URL)
180
  clip_result = client.predict(
181
  image=handle_file(yolo_img),
182
  top_k=3,
 
185
  top_k_words.append(clip_result) # CLIP預測3個結果(top_k_words)
186
  encoded_images.append(image_to_base64(yolo_img))
187
  element_list.append(element)
188
+ print(f"===== FLASK API---/predict CLIP result:{top_k_words} =====\n")
189
 
190
  # 建立回應資料
191
  response_data = {
 
207
 
208
 
209
  # API route for health check
210
+ @app.route('/text2img', methods=['POST'])
211
+ def text2img():
212
+ text_message = request.form.get('text_message')
213
+ message_id = request.form.get('message_id')
214
+
215
+ client = Client(TEXT2IMG_URL)
216
+ result = client.predict(
217
+ word= text_message,
218
+ api_name="/predict"
219
+ )
220
+ print(f"===== FLASK API---/text2img 文字轉圖片result[0]:{result[0]} =====")
221
+ result_img = convert_webp_to_base64(result[0])
222
+ print(f"===== FLASK API---/text2img 文字轉圖片轉base64:{result_img} =====")
223
+
224
+ # 建立回應資料
225
+ response_data = {
226
+ 'message_id': message_id,
227
+ 'encoded_image': result_img,
228
+ 'description': result[1]
229
+ }
230
+
231
+ return jsonify(response_data), 200
232
 
233
 
234
  # API route for version