ning8429 committed on
Commit
531c3cd
·
verified ·
1 Parent(s): 4dd8b9b

Update clip_model.py

Browse files
Files changed (1) hide show
  1. clip_model.py +15 -8
clip_model.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
4
  from transformers import ChineseCLIPProcessor, ChineseCLIPModel
5
 
6
  class ClipModel:
7
- def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None):
8
  # Set device
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
@@ -19,13 +19,20 @@ class ClipModel:
19
 
20
  self.processor = ChineseCLIPProcessor.from_pretrained(model_name)
21
 
22
- def clip_result(self, image_path, vocab_path='./chiikawa/word_list.txt', top_k=3):
23
- # Load image
24
- image = Image.open(image_path)
25
 
26
  # Load Chinese vocabulary
27
  with open(vocab_path, 'r', encoding='utf-8') as f:
28
- vocab = [line.strip() for line in f.readlines()]
 
 
 
 
 
 
 
 
 
29
 
30
  # Process images and texts
31
  batch_size = 16 # Process 16 vocab at a time
@@ -35,8 +42,8 @@ class ClipModel:
35
  torch.cuda.empty_cache()
36
 
37
  with torch.no_grad():
38
- for i in range(0, len(vocab), batch_size):
39
- batch_vocab = vocab[i:i + batch_size]
40
  inputs = self.processor(
41
  text=batch_vocab,
42
  images=image,
@@ -56,7 +63,7 @@ class ClipModel:
56
 
57
  # Find top-3 similarities
58
  top_k_indices = torch.topk(similarity, top_k).indices.tolist()
59
- top_k_words = [vocab[idx] for idx in top_k_indices]
60
 
61
  # 6. 輸出最接近的前3名中文詞彙
62
  return top_k_words
 
4
  from transformers import ChineseCLIPProcessor, ChineseCLIPModel
5
 
6
  class ClipModel:
7
+ def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None, vocab_path='./chiikawa/word_list.txt'):
8
  # Set device
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
19
 
20
  self.processor = ChineseCLIPProcessor.from_pretrained(model_name)
21
 
22
+ print("***** Clip Model LOAD DONE *****")
 
 
23
 
24
  # Load Chinese vocabulary
25
  with open(vocab_path, 'r', encoding='utf-8') as f:
26
+ self.vocab = [line.strip() for line in f.readlines()]
27
+
28
+ def clip_result(self, image_path, top_k=3):
29
+ """
30
+ 給定圖片路徑,返回最接近的 top_k 詞彙
31
+ """
32
+ # Load image
33
+ image = Image.open(image_path)
34
+
35
+ print(f"===== Clip Model_clip_result : {image_path} ===== ")
36
 
37
  # Process images and texts
38
  batch_size = 16 # Process 16 vocab at a time
 
42
  torch.cuda.empty_cache()
43
 
44
  with torch.no_grad():
45
+ for i in range(0, len(self.vocab), batch_size):
46
+ batch_vocab = self.vocab[i:i + batch_size]
47
  inputs = self.processor(
48
  text=batch_vocab,
49
  images=image,
 
63
 
64
  # Find top-3 similarities
65
  top_k_indices = torch.topk(similarity, top_k).indices.tolist()
66
+ top_k_words = [self.vocab[idx] for idx in top_k_indices]
67
 
68
  # 6. 輸出最接近的前3名中文詞彙
69
  return top_k_words