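"""Chinese-CLIP image-to-vocabulary matcher.

Loads a fine-tuned OFA-Sys/chinese-clip-vit-base-patch16 checkpoint and scores
an image against a Chinese word list, returning the top-k matching words.
"""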
import os
import torch
from PIL import Image
from transformers import ChineseCLIPProcessor, ChineseCLIPModel

class ClipModel:
    def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None):
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the pretrained base model; its weights are replaced below by the fine-tuned checkpoint
        self.model = ChineseCLIPModel.from_pretrained(model_name)
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()

        self.processor = ChineseCLIPProcessor.from_pretrained(model_name)

    def clip_result(self, image_path, vocab_path='./chiikawa/word_list.txt', top_k=3):
        # Load the query image (convert defensively; the processor usually handles RGB itself)
        image = Image.open(image_path).convert("RGB")

        # Load the Chinese vocabulary, skipping blank lines
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = [line.strip() for line in f if line.strip()]

        # Score the image against the vocabulary in batches
        batch_size = 16  # number of vocabulary words per forward pass
        similarities = []

        # Release cached GPU memory before inference (safe no-op on CPU)
        if self.device == "cuda":
            torch.cuda.empty_cache()

        with torch.no_grad():
            for i in range(0, len(vocab), batch_size):
                batch_vocab = vocab[i:i + batch_size]
                inputs = self.processor(
                    text=batch_vocab,
                    images=image,
                    return_tensors="pt",
                    padding=True
                ).to(self.device)

                # Forward pass, then cosine similarity between the L2-normalized
                # image embedding and each text embedding in the batch
                outputs = self.model(**inputs)
                image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
                text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
                similarity = torch.matmul(image_embeds, text_embeds.T).squeeze(0)  # shape: (batch,)
                similarities.append(similarity)

        # Concatenate per-batch scores into one similarity per vocabulary word
        similarity = torch.cat(similarities, dim=0)

        # Find the top-k most similar words (clamped to the vocabulary size)
        top_k_indices = torch.topk(similarity, min(top_k, len(vocab))).indices.tolist()
        top_k_words = [vocab[idx] for idx in top_k_indices]

        # Return the top-k closest Chinese words
        return top_k_words
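

# --- Usage sketch (hypothetical inputs: "example.jpg" and the default word
# list are assumptions, not files shipped with this module) ---
if __name__ == "__main__":
    clip = ClipModel()  # loads artifacts/models/best_clip_model.pth by default
    words = clip.clip_result("example.jpg", top_k=3)
    print(words)  # e.g. the three vocabulary words closest to the image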