ning8429 committed on
Commit 329a484 · verified · 1 Parent(s): 4ddd633

Delete clip_model.py

Files changed (1)
  1. clip_model.py +0 -94
clip_model.py DELETED
@@ -1,94 +0,0 @@
- import os
- import torch
- from PIL import Image
- from transformers import ChineseCLIPProcessor, ChineseCLIPModel
- import psutil
-
-
- def check_memory_usage():
-     """Print current system memory statistics (in MB) via psutil."""
-     memory_info = psutil.virtual_memory()
-
-     total_memory = memory_info.total / (1024 * 1024)  # Convert bytes to MB
-     available_memory = memory_info.available / (1024 * 1024)
-     used_memory = memory_info.used / (1024 * 1024)
-     memory_usage_percent = memory_info.percent
-
-     print(f"^^^^^^ Total Memory: {total_memory:.2f} MB ^^^^^^")
-     print(f"^^^^^^ Available Memory: {available_memory:.2f} MB ^^^^^^")
-     print(f"^^^^^^ Used Memory: {used_memory:.2f} MB ^^^^^^")
-     print(f"^^^^^^ Memory Usage (%): {memory_usage_percent}% ^^^^^^")
-
-
- class ClipModel:
-     def __init__(self, model_name="OFA-Sys/chinese-clip-vit-base-patch16", model_path=None, vocab_path='./chiikawa/word_list.txt'):
-         # Set device
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-         # Load the base model, then the fine-tuned weights
-         self.model = ChineseCLIPModel.from_pretrained(model_name)
-         if model_path is None:
-             script_dir = os.path.dirname(os.path.abspath(__file__))
-             model_path = os.path.join(script_dir, 'artifacts/models', 'best_clip_model.pth')
-         self.model.load_state_dict(torch.load(model_path, map_location=self.device))
-         self.model = self.model.to(self.device)
-         self.model.eval()
-
-         self.processor = ChineseCLIPProcessor.from_pretrained(model_name)
-
-         print("***** Clip Model LOAD DONE *****")
-         check_memory_usage()
-
-         # Load the Chinese vocabulary, skipping empty lines
-         with open(vocab_path, 'r', encoding='utf-8') as f:
-             self.vocab = [line.strip() for line in f if line.strip()]
-
-     def clip_result(self, image_path, top_k=3):
-         """
-         Given an image path, return the top_k vocabulary words closest to the image.
-         """
-         # Load image
-         image = Image.open(image_path)
-
-         print(f"===== Clip Model clip_result: {image_path} =====")
-         # Log memory usage before inference
-         check_memory_usage()
-
-         # Score the image against the vocabulary in batches
-         batch_size = 16  # Encode 16 vocabulary words at a time
-         similarities = []
-
-         # Optionally this could compare torch.cuda.memory_reserved() against the
-         # device's total memory and only empty the cache above a threshold; here
-         # the cache is cleared unconditionally (a no-op when CUDA is unavailable).
-         torch.cuda.empty_cache()
-
-         with torch.no_grad():
-             for i in range(0, len(self.vocab), batch_size):
-                 batch_vocab = self.vocab[i:i + batch_size]
-                 inputs = self.processor(
-                     text=batch_vocab,
-                     images=image,
-                     return_tensors="pt",
-                     padding=True
-                 ).to(self.device)
-
-                 # Inference and cosine-similarity calculation on normalized embeddings
-                 outputs = self.model(**inputs)
-                 image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
-                 text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
-                 similarity = torch.matmul(image_embeds, text_embeds.T).squeeze(0)
-                 similarities.append(similarity)
-
-         # Merge the per-batch similarities into one vector over the whole vocabulary
-         similarity = torch.cat(similarities, dim=0)
-
-         # Find the top_k most similar vocabulary words
-         top_k_indices = torch.topk(similarity, top_k).indices.tolist()
-         top_k_words = [self.vocab[idx] for idx in top_k_indices]
-
-         # Return the top_k closest Chinese words
-         return top_k_words
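
For context, a minimal sketch of how the deleted class might have been used; the image path and printed output are illustrative assumptions, not taken from this commit.

# Hypothetical usage of the deleted ClipModel class; "example.jpg" is an assumed
# sample path, and the defaults point at the fine-tuned weights under
# artifacts/models and the ./chiikawa/word_list.txt vocabulary.
from clip_model import ClipModel

clip = ClipModel()
top_words = clip.clip_result("example.jpg", top_k=3)
print(top_words)  # three vocabulary words ranked by image-text similarity

Batching the vocabulary (16 words per forward pass) keeps the text tower's memory footprint bounded on small GPUs; a common alternative would be to encode the vocabulary once at load time and reuse the cached text embeddings for every image.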