File size: 3,739 Bytes
2184c4e
 
 
 
f1a51ff
2184c4e
8185e78
 
 
 
 
 
 
 
 
 
 
 
 
 
4ddb23c
531c3cd
4ddb23c
 
 
 
 
 
 
 
 
 
 
 
 
 
531c3cd
eba1ae5
4ddb23c
 
 
531c3cd
 
 
 
 
 
 
 
 
 
eba1ae5
 
4ddb23c
 
 
 
 
4273422
 
 
ce7e2d8
4273422
 
 
 
4ddb23c
 
531c3cd
 
4ddb23c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531c3cd
4ddb23c
 
 
2184c4e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import torch
from PIL import Image
from transformers import ChineseCLIPProcessor, ChineseCLIPModel
import psutil

def check_memory_usage():
    """Print a snapshot of system-wide virtual memory usage.

    Total, available and used memory are shown in megabytes
    (1 MB = 1024 * 1024 bytes) with two decimals, followed by the usage
    percentage exactly as psutil reports it.
    """
    bytes_per_mb = 1024 * 1024
    vm = psutil.virtual_memory()

    total_mb, available_mb, used_mb = (
        value / bytes_per_mb for value in (vm.total, vm.available, vm.used)
    )

    print(f"^^^^^^ Total Memory: {total_mb:.2f} MB ^^^^^^")
    print(f"^^^^^^ Available Memory: {available_mb:.2f} MB ^^^^^^")
    print(f"^^^^^^ Used Memory: {used_mb:.2f} MB ^^^^^^")
    print(f"^^^^^^ Memory Usage (%): {vm.percent}% ^^^^^^")

class ClipModel:
    """Rank a fixed Chinese vocabulary against images using a fine-tuned Chinese-CLIP model."""

    def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None, vocab_path='./chiikawa/word_list.txt'):
        """Load the base model, apply fine-tuned weights, and read the vocabulary.

        Args:
            model_name: Hugging Face model id used for both the model
                architecture and the processor.
            model_path: Path to a state-dict checkpoint. Defaults to
                ``artifacts/models/best_clip_model.pth`` next to this file.
            vocab_path: UTF-8 text file with one candidate word per line.
                Unlike ``model_path``, this default is resolved relative to
                the current working directory.
        """
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build the architecture from the hub id, then overwrite its weights
        # with the fine-tuned checkpoint.
        self.model = ChineseCLIPModel.from_pretrained(model_name)
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. weights_only=True (torch >= 1.13) would restrict it to
        # tensors if that is acceptable for this checkpoint.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()

        self.processor = ChineseCLIPProcessor.from_pretrained(model_name)

        print("***** Clip Model LOAD DONE *****")
        check_memory_usage()

        # Load the Chinese vocabulary. Blank lines are dropped: an empty string
        # would otherwise be embedded as a candidate and could be returned as
        # one of the "closest words".
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = [word for word in (line.strip() for line in f) if word]

    def clip_result(self, image_path, top_k=3):
        """Return the ``top_k`` vocabulary words closest to the image at ``image_path``.

        Args:
            image_path: Path of the image file to classify.
            top_k: Number of words to return.

        Returns:
            A list of words, most similar first. The list is shorter than
            ``top_k`` when the vocabulary is smaller, and empty when the
            vocabulary is empty.
        """
        # PIL's Image.open is lazy and keeps the file open. Decode the pixels
        # into an RGB copy inside ``with`` so the handle is released right away.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        print(f"===== Clip Model_clip_result : {image_path} ===== ")
        check_memory_usage()

        if not self.vocab:
            return []

        batch_size = 16  # vocabulary words encoded per text forward pass
        similarities = []

        # Release cached, unused CUDA memory held by PyTorch's allocator
        # before running inference.
        torch.cuda.empty_cache()

        with torch.no_grad():
            # The image embedding does not depend on the text batch. Encode it
            # once instead of re-running the vision tower for every batch.
            image_inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            image_embeds = self.model.get_image_features(**image_inputs)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

            for start in range(0, len(self.vocab), batch_size):
                batch_vocab = self.vocab[start:start + batch_size]
                text_inputs = self.processor(
                    text=batch_vocab,
                    return_tensors="pt",
                    padding=True,
                ).to(self.device)
                text_embeds = self.model.get_text_features(**text_inputs)
                text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
                # A single image gives (1, B) cosine scores; squeeze to (B,) so
                # the batches concatenate into one score per vocabulary word.
                similarities.append(torch.matmul(image_embeds, text_embeds.T).squeeze(0))

        # One score per vocabulary word, in vocabulary order.
        similarity = torch.cat(similarities, dim=0)
        return self._select_top_k(similarity, self.vocab, top_k)

    @staticmethod
    def _select_top_k(similarity, vocab, top_k):
        """Return the words of ``vocab`` with the ``top_k`` highest scores.

        ``similarity`` is a 1-D tensor aligned index-for-index with ``vocab``.
        ``top_k`` is capped at the number of scores so that a vocabulary
        smaller than ``top_k`` does not make ``torch.topk`` raise.
        """
        k = min(top_k, similarity.numel())
        indices = torch.topk(similarity, k).indices.tolist()
        return [vocab[idx] for idx in indices]