ning8429 commited on
Commit
f66696f
·
verified ·
1 Parent(s): 4273422

Update api_server.py

Browse files
Files changed (1) hide show
  1. api_server.py +11 -4
api_server.py CHANGED
@@ -13,8 +13,8 @@ from tensorflow import keras
13
  from flask import Flask, jsonify, request, render_template, send_file
14
  import torch
15
  from collections import Counter
16
- from clip_model import ClipModel
17
  import psutil
 
18
 
19
  # Disable tensorflow warnings
20
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -24,6 +24,7 @@ load_type = 'local'
24
  MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
25
  MODEL_DIR = "./artifacts/models"
26
  YOLO_DIR = "./artifacts/yolo"
 
27
  #REPO_ID = "1vash/mnist_demo_model"
28
 
29
  # Load the saved YOLO model into memory
@@ -97,8 +98,8 @@ check_memory_usage()
97
 
98
  # Initialize the Flask application
99
  app = Flask(__name__)
100
- # Initialize the ClipModel at the start
101
- clip_model = ClipModel()
102
 
103
 
104
 
@@ -168,7 +169,13 @@ def predict():
168
 
169
  for yolo_img in yolo_file: # 每張切圖yolo_img
170
  print("***** 4. START CLIP *****")
171
- top_k_words.append(clip_model.clip_result(yolo_img)) # CLIP預測3個結果(top_k_words)
 
 
 
 
 
 
172
  #encoded_images.append(image_to_base64(yolo_img))
173
  print(f"===== CLIP result:{top_k_words} =====\n")
174
 
 
13
  from flask import Flask, jsonify, request, render_template, send_file
14
  import torch
15
  from collections import Counter
 
16
  import psutil
17
+ from gradio_client import Client, handle_file
18
 
19
  # Disable tensorflow warnings
20
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
24
  MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
25
  MODEL_DIR = "./artifacts/models"
26
  YOLO_DIR = "./artifacts/yolo"
27
+ GRADIO_URL = "https://50094cfbc694a82dea.gradio.live/"
28
  #REPO_ID = "1vash/mnist_demo_model"
29
 
30
  # Load the saved YOLO model into memory
 
98
 
99
  # Initialize the Flask application
100
  app = Flask(__name__)
101
+ # # Initialize the ClipModel at the start
102
+ # clip_model = ClipModel()
103
 
104
 
105
 
 
169
 
170
  for yolo_img in yolo_file: # 每張切圖yolo_img
171
  print("***** 4. START CLIP *****")
172
+ client = Client(GRADIO_URL)
173
+ clip_result = client.predict(
174
+ image=handle_file(yolo_img),
175
+ top_k=3,
176
+ api_name="/predict"
177
+ )
178
+ top_k_words.append(clip_result) # CLIP預測3個結果(top_k_words)
179
  #encoded_images.append(image_to_base64(yolo_img))
180
  print(f"===== CLIP result:{top_k_words} =====\n")
181