# flask-docker / clip_model.py
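"""Chinese-CLIP inference helper.

Loads OFA-Sys/chinese-clip-vit-base-patch16 with a fine-tuned checkpoint
and ranks a fixed Chinese vocabulary list against an input image.
"""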
import os

import psutil
import torch
from PIL import Image
from transformers import ChineseCLIPProcessor, ChineseCLIPModel

def check_memory_usage():
    # Report system memory usage (converted from bytes to MB)
    memory_info = psutil.virtual_memory()
    total_memory = memory_info.total / (1024 * 1024)
    available_memory = memory_info.available / (1024 * 1024)
    used_memory = memory_info.used / (1024 * 1024)
    memory_usage_percent = memory_info.percent
    print(f"^^^^^^ Total Memory: {total_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Available Memory: {available_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Used Memory: {used_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Memory Usage (%): {memory_usage_percent}% ^^^^^^")

class ClipModel:
    def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None, vocab_path='./chiikawa/word_list.txt'):
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load base model, then overwrite with fine-tuned weights
        self.model = ChineseCLIPModel.from_pretrained(model_name)
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()
        self.processor = ChineseCLIPProcessor.from_pretrained(model_name)
        print("***** Clip Model LOAD DONE *****")
        check_memory_usage()

        # Load the Chinese vocabulary (one word per line)
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = [line.strip() for line in f]
    def clip_result(self, image_path, top_k=3):
        """
        Given an image path, return the top_k closest vocabulary words.
        """
        # Load image; convert to RGB so grayscale/RGBA inputs don't break the processor
        image = Image.open(image_path).convert("RGB")
        print(f"===== Clip Model_clip_result : {image_path} =====")
        check_memory_usage()

        # Score the vocabulary in batches to bound peak memory
        batch_size = 16  # Process 16 vocabulary words at a time
        similarities = []

        # Optionally, empty the cache only when reserved memory exceeds a
        # threshold, e.g.:
        #   if torch.cuda.memory_reserved() > 0.8 * torch.cuda.get_device_properties(0).total_memory:
        #       torch.cuda.empty_cache()
        # Here we release cached GPU memory unconditionally, skipping CPU-only hosts
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        with torch.no_grad():
            for i in range(0, len(self.vocab), batch_size):
                batch_vocab = self.vocab[i:i + batch_size]
                inputs = self.processor(
                    text=batch_vocab,
                    images=image,
                    return_tensors="pt",
                    padding=True
                ).to(self.device)

                # Inference, then cosine similarity on L2-normalized embeddings
                outputs = self.model(**inputs)
                image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
                text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
                similarity = torch.matmul(image_embeds, text_embeds.T).squeeze(0)
                similarities.append(similarity)

        # Merge per-batch scores into one similarity per vocabulary word
        similarity = torch.cat(similarities, dim=0)

        # Return the top_k closest Chinese words
        top_k_indices = torch.topk(similarity, top_k).indices.tolist()
        top_k_words = [self.vocab[idx] for idx in top_k_indices]
        return top_k_words
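
# A minimal usage sketch, assuming the default checkpoint and vocabulary
# files are present on disk; 'test.jpg' is a hypothetical local image path.
if __name__ == "__main__":
    clip_model = ClipModel()
    print(clip_model.clip_result("test.jpg", top_k=3))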