Spaces:
Sleeping
Sleeping
import os | |
import time | |
import numpy as np | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from pathlib import Path | |
from ultralytics import YOLO | |
import io | |
import base64 | |
import uuid | |
import glob | |
from tensorflow import keras | |
from flask import Flask, jsonify, request, render_template, send_file | |
import torch | |
from collections import Counter | |
import psutil | |
from gradio_client import Client, handle_file | |
# Disable tensorflow warnings
# NOTE(review): TensorFlow is already imported above this line, so this setting
# likely arrives too late to silence its import-time native logs — confirm, and
# consider moving it above the tensorflow import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# How the YOLO model is obtained: 'local', 'remote_hub_download' or
# 'remote_hub_from_pretrained'.
load_type = 'local'
MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
MODEL_DIR = "./artifacts/models"
YOLO_DIR = "./artifacts/yolo"  # root directory for saved YOLO crops
# gradio.live share links are temporary, so the CLIP service URL can be
# overridden via the GRADIO_URL environment variable without editing code.
GRADIO_URL = os.environ.get("GRADIO_URL", "https://50094cfbc694a82dea.gradio.live/")
# Hugging Face Hub repository used by the remote_hub_* load types. It used to be
# commented out, which made those branches fail with NameError; it is now read
# from the REPO_ID environment variable (None when unset).
# Earlier placeholder value: "1vash/mnist_demo_model".
REPO_ID = os.environ.get("REPO_ID")
# Load the YOLO model into memory once, at import time; the request handlers
# reuse this module-level `model`.
if load_type == 'local':
    # Local model path
    model_path = f'{MODEL_DIR}/{MODEL_YOLO}'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = YOLO(model_path)
    print("***** 1. LOAD YOLO MODEL DONE *****")
elif load_type == 'remote_hub_download':
    from huggingface_hub import hf_hub_download
    # Download the weights file from the Hugging Face Hub, then wrap it with
    # YOLO() exactly like the local branch: the handlers rely on the
    # Ultralytics API (model(image), model.names, result.save_crop), which a
    # raw torch.load() result does not provide.
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_YOLO)
    model = YOLO(model_path)
elif load_type == 'remote_hub_from_pretrained':
    from huggingface_hub import hf_hub_download
    # huggingface_hub has no top-level `from_pretrained` function; download the
    # weights using MODEL_DIR as the cache directory and load them with YOLO().
    model_path = hf_hub_download(
        repo_id=REPO_ID, filename=MODEL_YOLO, cache_dir=MODEL_DIR
    )
    model = YOLO(model_path)
else:
    raise AssertionError(
        f"Unknown load_type {load_type!r}; expected 'local', "
        "'remote_hub_download' or 'remote_hub_from_pretrained'"
    )
# Encode a file's raw bytes as base64 text.
def image_to_base64(image_path):
    """Return the contents of the file at ``image_path`` as a base64 string.

    The file is read in binary mode and the base64 bytes are decoded as UTF-8,
    so the result is a plain ``str`` suitable for JSON payloads.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    return base64.b64encode(payload).decode("utf-8")
# Collect every JPG file directly inside a given directory.
def get_jpg_files(path):
    """Return the ``*.jpg`` files located directly inside ``path``.

    Args:
        path: Directory to search (not recursive).

    Returns:
        A sorted list of matching file paths (each joined onto ``path``). It is
        empty when the directory does not exist or contains no match.
    """
    # Escape the directory part so characters such as "[" or "*" in it are
    # matched literally instead of being treated as glob wildcards. Sort the
    # result because glob returns entries in arbitrary, filesystem-dependent order.
    return sorted(glob.glob(os.path.join(glob.escape(path), "*.jpg")))


# Example usage:
# image_folder = '/content/drive/MyDrive/chiikawa'  # replace with your directory path
# jpg_files = get_jpg_files(image_folder)
def check_memory_usage():
    """Print a snapshot of system-wide memory statistics to stdout.

    Figures come from ``psutil.virtual_memory()``. Byte counts are shown in MB
    (1024 * 1024 bytes) with two decimals, followed by the usage percentage.
    """
    snapshot = psutil.virtual_memory()
    bytes_per_mb = 1024 * 1024
    total_memory, available_memory, used_memory = (
        amount / bytes_per_mb
        for amount in (snapshot.total, snapshot.available, snapshot.used)
    )
    memory_usage_percent = snapshot.percent
    print(f"^^^^^^ Total Memory: {total_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Available Memory: {available_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Used Memory: {used_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Memory Usage (%): {memory_usage_percent}% ^^^^^^")
# Report memory usage once at startup, after the model-loading code has run.
check_memory_usage()

# Create the Flask application object.
app = Flask(import_name=__name__)
# API route for prediction (YOLO detection, then CLIP descriptions of each crop).
# NOTE(review): no @app.route decorator is visible on this function, so Flask
# will not serve it as shown — confirm how it is meant to be registered.
def predict():
    """Detect objects in an uploaded image and describe every crop with CLIP.

    Request (multipart form):
        image: the image file (required).
        message_id: optional id used as the crop directory name; a random
            UUID4 string is used when it is missing or empty.

    YOLO crops are written under ``{YOLO_DIR}/{message_id}/<label>/`` via
    ``result.save_crop``. Every ``.jpg`` found in a detected label's directory
    is sent to the Gradio app at ``GRADIO_URL`` (``/predict``, ``top_k=3``).

    Returns:
        ``(response, 200)`` whose JSON body is a list with one CLIP result per
        crop, or ``(response, 400)`` with an ``"error"`` key when the upload is
        missing or unreadable, ``message_id`` is not a safe directory name, or
        YOLO returns no results.
    """
    # Check for the field BEFORE indexing request.files: indexing a missing
    # key raises, so Flask answered with its own 400 page and this JSON error
    # used to be unreachable.
    if 'image' not in request.files:
        return jsonify({"error": "No image part"}), 400
    file = request.files['image']

    # message_id becomes a directory name on disk. Fall back to a fresh UUID
    # (instead of every id-less request sharing ".../None"), and reject values
    # that could escape YOLO_DIR such as "../x" or "a/b".
    message_id = request.form.get('message_id') or str(uuid.uuid4())
    if not _is_safe_dir_name(message_id):
        return jsonify({"error": "Invalid message_id"}), 400

    # Read the image
    try:
        image_data = Image.open(file)
    except Exception as e:
        return jsonify({'error': str(e)}), 400

    print("***** 2. Start YOLO predict *****")
    results = model(image_data)
    print("===== YOLO predict result:", results, "=====")
    print("***** YOLO predict DONE *****")
    check_memory_usage()

    # Make sure YOLO returned something usable.
    if results is None or len(results) == 0:
        return jsonify({'error': 'No results from YOLO model'}), 400

    crop_root = f"{YOLO_DIR}/{message_id}"
    top_k_words = []  # one CLIP result per crop image
    client = None  # Gradio client: created at most once per request, only if a crop exists
    for result in results:
        # Save the cropped detections; crops are read back per label below.
        result.save_crop(crop_root)
        label_names = [model.names[int(label)] for label in result.boxes.cls]
        print(f"====== 3. YOLO label_names: {label_names}======")
        # Visit each distinct label once, in order of first appearance.
        for element in Counter(label_names):
            yolo_path = f"{crop_root}/{element}"
            yolo_file = get_jpg_files(yolo_path)
            print(f"***** 處理:{yolo_path} *****")
            if not yolo_file:
                print(f"警告:{element} 沒有找到相關的 JPG 檔案")
                continue
            for yolo_img in yolo_file:  # each cropped image
                print("***** 4. START CLIP *****")
                if client is None:
                    # Constructing a Client connects to the Gradio server, so
                    # do it once rather than once per crop.
                    client = Client(GRADIO_URL)
                clip_result = client.predict(
                    image=handle_file(yolo_img),
                    top_k=3,
                    api_name="/predict",
                )
                top_k_words.append(clip_result)  # CLIP's top-3 predictions
                print(f"===== CLIP result:{top_k_words} =====\n")

    return jsonify(top_k_words), 200


def _is_safe_dir_name(name):
    """Return True if ``name`` can be used as a single directory name under YOLO_DIR."""
    return (
        name not in ('', '.', '..')
        and '/' not in name
        and '\\' not in name
        and '\x00' not in name
    )
# # Helper function to preprocess the image | |
# def preprocess_image(image_data): | |
# """Preprocess image for YOLO Model Inference | |
# :param image_data: Raw image (PIL.Image) | |
# :return: image: Preprocessed Image (Tensor) | |
# """ | |
# # Define the YOLO input size (example 640x640, you can modify this based on your model) | |
# input_size = (640, 640) | |
# # Define transformation: Resize the image, convert to Tensor, and normalize pixel values | |
# transform = transforms.Compose([ | |
# transforms.Resize(input_size), # Resize to YOLO input size | |
# transforms.ToTensor(), # Convert image to PyTorch Tensor (通道數、影像高度和寬度) | |
# transforms.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]) # Normalization (if needed) | |
# ]) | |
# # Apply transformations to the image | |
# image = transform(image_data) | |
# # Add batch dimension (1, C, H, W) since YOLO expects a batch | |
# image = image.unsqueeze(0) | |
# return image | |
# API route for health check
# NOTE(review): no @app.route decorator is visible here — confirm how this view is registered.
def health():
    """Liveness probe: return the plain string "OK" while the application runs.

    Demo usage: "curl http://localhost:5000/health" or "curl http://127.0.0.1:5000/health".
    """
    status = 'OK'
    return status
# API route for version
# NOTE(review): no @app.route decorator is visible here — confirm how this view is registered.
def version():
    """Return the application version string.

    Demo usage: "curl http://127.0.0.1:5000/version" or "curl http://localhost:5000/version".
    """
    app_version = '1.0'
    return app_version
def hello_world():
    """Render and return the ``index.html`` template."""
    page = render_template("index.html")
    return page
# Start the Flask development server when this file is run as a script.
if __name__ == '__main__':
    # Do not hard-code debug=True: the Werkzeug debugger lets anyone who can
    # reach the server execute arbitrary Python. Without a debug argument,
    # app.run() honours the FLASK_DEBUG environment variable instead, so
    # `FLASK_DEBUG=1 python <this file>` restores debug mode when wanted.
    app.run()