ning8429 commited on
Commit
4ddb23c
·
verified ·
1 Parent(s): 8e8837a

Update clip.py

Browse files
Files changed (1) hide show
  1. clip.py +57 -66
clip.py CHANGED
@@ -3,70 +3,61 @@ import torch
3
  from PIL import Image
4
  from transformers import ChineseCLIPProcessor, ChineseCLIPModel
5
 
6
- def clip_result(image_path):
7
- # 設置設備
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Get the directory where this script is located
11
- script_dir = os.path.dirname(os.path.abspath(__file__))
12
-
13
- # Construct the full path to the file in the subfolder
14
- model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
15
-
16
- print("model_path:", model_path)
17
-
18
- # 載入訓練好的模型和處理器
19
- model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
20
- model.load_state_dict(torch.load(model_path, map_location=device))
21
- model = model.to(device)
22
- model.eval()
23
-
24
- processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
25
-
26
- # 1. 加載圖片
27
- # image_path = '/content/drive/MyDrive/幽靈吉伊卡哇.png'
28
- image = Image.open(image_path)
29
-
30
- # 2. 加載中文詞彙表
31
- with open('./chiikawa/word_list.txt', 'r', encoding='utf-8') as f:
32
- vocab = [line.strip() for line in f.readlines()]
33
-
34
- # 3. 圖像和文本處理
35
- batch_size = 16 # 每次處理16個詞彙
36
- similarities = []
37
-
38
- # 釋放未使用的顯存
39
- torch.cuda.empty_cache()
40
-
41
- with torch.no_grad():
42
- for i in range(0, len(vocab), batch_size):
43
- batch_vocab = vocab[i:i + batch_size]
44
- inputs = processor(
45
- text=batch_vocab,
46
- images=image,
47
- return_tensors="pt",
48
- padding=True
49
- ).to(device)
50
-
51
- # 推理並進行相似度計算
52
- outputs = model(**inputs)
53
- image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
54
- text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
55
- similarity = torch.matmul(image_embeds, text_embeds.T).squeeze(0)
56
- similarities.append(similarity)
57
-
58
- # 4. 合併所有相似度
59
- similarity = torch.cat(similarities, dim=0)
60
-
61
- # 5. 找到相似度最高的詞彙
62
- top_k = 3
63
- top_k_indices = torch.topk(similarity, top_k).indices.tolist()
64
- top_k_words = [vocab[idx] for idx in top_k_indices]
65
-
66
- # 6. 輸出最接近的前3名中文詞彙
67
-
68
- # print("最接近的前3名中文詞彙是:")
69
- # for rank, word in enumerate(top_k_words, 1):
70
- # print(f"{rank}. {word}")
71
-
72
- return top_k_words
 
3
  from PIL import Image
4
  from transformers import ChineseCLIPProcessor, ChineseCLIPModel
5
 
6
+ class ClipModel:
7
+ def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None):
8
+ # Set device
9
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load model and processor
12
+ self.model = ChineseCLIPModel.from_pretrained(model_name)
13
+ if model_path is None:
14
+ script_dir = os.path.dirname(os.path.abspath(__file__))
15
+ model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
16
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device))
17
+ self.model = self.model.to(self.device)
18
+ self.model.eval()
19
+
20
+ self.processor = ChineseCLIPProcessor.from_pretrained(model_name)
21
+
22
+ def clip_result(self, image_path, vocab_path='./chiikawa/word_list.txt', top_k=3):
23
+ # Load image
24
+ image = Image.open(image_path)
25
+
26
+ # Load Chinese vocabulary
27
+ with open(vocab_path, 'r', encoding='utf-8') as f:
28
+ vocab = [line.strip() for line in f.readlines()]
29
+
30
+ # Process images and texts
31
+ batch_size = 16 # Process 16 vocab at a time
32
+ similarities = []
33
+
34
+ # Release unused memory
35
+ torch.cuda.empty_cache()
36
+
37
+ with torch.no_grad():
38
+ for i in range(0, len(vocab), batch_size):
39
+ batch_vocab = vocab[i:i + batch_size]
40
+ inputs = self.processor(
41
+ text=batch_vocab,
42
+ images=image,
43
+ return_tensors="pt",
44
+ padding=True
45
+ ).to(self.device)
46
+
47
+ # Inference and similarity calculation
48
+ outputs = self.model(**inputs)
49
+ image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
50
+ text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
51
+ similarity = torch.matmul(image_embeds, text_embeds.T).squeeze(0)
52
+ similarities.append(similarity)
53
+
54
+ # Merge all similarities
55
+ similarity = torch.cat(similarities, dim=0)
56
+
57
+ # Find top-3 similarities
58
+ top_k_indices = torch.topk(similarity, top_k).indices.tolist()
59
+ top_k_words = [vocab[idx] for idx in top_k_indices]
60
+
61
+ # 6. 輸出最接近的前3名中文詞彙
62
+ return top_k_words
63