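"""Chinese-CLIP image-to-vocabulary matcher.

Loads a fine-tuned OFA-Sys/chinese-clip-vit-base-patch16 checkpoint and, given
an image, returns the closest entries from a Chinese word list
(./chiikawa/word_list.txt by default).
"""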
import os

import psutil
import torch
from PIL import Image
from transformers import ChineseCLIPProcessor, ChineseCLIPModel

def check_memory_usage():
    # Get memory details
    memory_info = psutil.virtual_memory()
    total_memory = memory_info.total / (1024 * 1024)  # Convert bytes to MB
    available_memory = memory_info.available / (1024 * 1024)
    used_memory = memory_info.used / (1024 * 1024)
    memory_usage_percent = memory_info.percent

    print(f"^^^^^^ Total Memory: {total_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Available Memory: {available_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Used Memory: {used_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Memory Usage (%): {memory_usage_percent}% ^^^^^^")


class ClipModel:
    def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16",
                 model_path=None, vocab_path='./chiikawa/word_list.txt'):
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load model and processor
        self.model = ChineseCLIPModel.from_pretrained(model_name)
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()
        self.processor = ChineseCLIPProcessor.from_pretrained(model_name)
        print("***** Clip Model LOAD DONE *****")
        check_memory_usage()

        # Load the Chinese vocabulary, skipping blank lines
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = [line.strip() for line in f if line.strip()]

    def clip_result(self, image_path, top_k=3):
        """
        Given an image path, return the top_k closest vocabulary terms.
        """
        # Load image; convert to RGB so RGBA/grayscale inputs don't break preprocessing
        image = Image.open(image_path).convert("RGB")
        print(f"===== Clip Model_clip_result : {image_path} =====")

        # Check memory usage before inference
        check_memory_usage()

        # Score the vocabulary against the image in batches
        batch_size = 16  # Process 16 vocab entries at a time
        similarities = []

        # # Check memory usage before calling empty_cache
        # reserved_memory = torch.cuda.memory_reserved()
        # allocated_memory = torch.cuda.memory_allocated()
        # # Only call empty_cache if reserved memory exceeds a threshold
        # if reserved_memory > 0.8 * torch.cuda.get_device_properties(0).total_memory:
        #     # Release unused memory
        if self.device == "cuda":
            torch.cuda.empty_cache()

        with torch.no_grad():
            for i in range(0, len(self.vocab), batch_size):
                batch_vocab = self.vocab[i:i + batch_size]
                inputs = self.processor(
                    text=batch_vocab,
                    images=image,
                    return_tensors="pt",
                    padding=True
                ).to(self.device)

                # Inference and similarity calculation
                outputs = self.model(**inputs)
                image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
                text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
                similarity = torch.matmul(image_embeds, text_embeds.T).squeeze(0)
                similarities.append(similarity)

        # Merge similarities from all batches into one vector over the vocabulary
        similarity = torch.cat(similarities, dim=0)

        # Find the top-k similarities
        top_k_indices = torch.topk(similarity, top_k).indices.tolist()
        top_k_words = [self.vocab[idx] for idx in top_k_indices]

        # Return the top_k closest Chinese vocabulary terms
        return top_k_words
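

if __name__ == "__main__":
    # Minimal usage sketch: assumes the fine-tuned checkpoint and word list
    # exist at the default paths above; "test.jpg" is an illustrative placeholder.
    clip = ClipModel()
    print(clip.clip_result("test.jpg", top_k=3))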