Upload 3 files
- clip_analyzer.py +27 -30
- clip_model_manager.py +18 -24
- clip_zero_shot_classifier.py +4 -4
clip_analyzer.py
CHANGED

@@ -1,5 +1,5 @@
 import torch
-import open_clip
+import clip
 import numpy as np
 from PIL import Image
 from typing import Dict, List, Tuple, Any, Optional, Union
@@ -20,14 +20,13 @@ class CLIPAnalyzer:
     Use CLIP to integrate scene understanding functionality
     """

-    def __init__(self, model_name: str = "ViT-B
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
         """
-        Initialize the CLIP
+        Initialize the CLIP analyzer.

         Args:
-            model_name:
-            device:
-            pretrained: pretrained weights, using "laion2b_s34b_b79k"
+            model_name: CLIP model name, defaults to "ViT-B/16"
+            device: running device; uses the GPU if available
         """
         # Automatically select the device
         if device is None:
@@ -35,23 +34,12 @@ class CLIPAnalyzer:
         else:
             self.device = device

-        print(f"Loading
+        print(f"Loading CLIP model {model_name} on {self.device}...")
         try:
-            self.model,
-
-                pretrained=pretrained,
-                device=self.device
-            )
-            self.tokenizer = open_clip.get_tokenizer(model_name)
-            print(f"OpenCLIP model loaded successfully.")
-            import gc
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-            print("Memory cleanup completed after OpenCLIP loading.")
+            self.model, self.preprocess = clip.load(model_name, device=self.device)
+            print(f"CLIP model loaded successfully.")
         except Exception as e:
-            print(f"Error loading
+            print(f"Error loading CLIP model: {e}")
             raise

         self.scene_type_prompts = SCENE_TYPE_PROMPTS
@@ -76,7 +64,7 @@ class CLIPAnalyzer:
         if scene_texts:
             self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
             try:
-                self.text_features_cache["scene_type_tokens"] =
+                self.text_features_cache["scene_type_tokens"] = clip.tokenize(scene_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing scene_type_prompts: {e}")
                 self.text_features_cache["scene_type_tokens"] = None  # mark as error or empty
@@ -94,7 +82,7 @@ class CLIPAnalyzer:
         for scene_type, prompts in self.cultural_scene_prompts.items():
             if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                 try:
-                    cultural_tokens_dict_val[scene_type] =
+                    cultural_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                 except Exception as e:
                     print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
                     cultural_tokens_dict_val[scene_type] = None  # mark as error or empty
@@ -108,7 +96,7 @@ class CLIPAnalyzer:
         if lighting_texts:
             self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
             try:
-                self.text_features_cache["lighting_tokens"] =
+                self.text_features_cache["lighting_tokens"] = clip.tokenize(lighting_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
                 self.text_features_cache["lighting_tokens"] = None
@@ -125,7 +113,7 @@ class CLIPAnalyzer:
         for scene_type, prompts in self.specialized_scene_prompts.items():
             if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                 try:
-                    specialized_tokens_dict_val[scene_type] =
+                    specialized_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                 except Exception as e:
                     print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
                     specialized_tokens_dict_val[scene_type] = None
@@ -139,7 +127,7 @@ class CLIPAnalyzer:
         if viewpoint_texts:
             self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
             try:
-                self.text_features_cache["viewpoint_tokens"] =
+                self.text_features_cache["viewpoint_tokens"] = clip.tokenize(viewpoint_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
                 self.text_features_cache["viewpoint_tokens"] = None
@@ -156,7 +144,7 @@ class CLIPAnalyzer:
         if object_combination_texts:
             self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
             try:
-                self.text_features_cache["object_combination_tokens"] =
+                self.text_features_cache["object_combination_tokens"] = clip.tokenize(object_combination_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing object_combination_prompts: {e}")
                 self.text_features_cache["object_combination_tokens"] = None
@@ -173,7 +161,7 @@ class CLIPAnalyzer:
         if activity_texts:
             self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
             try:
-                self.text_features_cache["activity_tokens"] =
+                self.text_features_cache["activity_tokens"] = clip.tokenize(activity_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing activity_prompts: {e}")
                 self.text_features_cache["activity_tokens"] = None
@@ -192,7 +180,7 @@ class CLIPAnalyzer:
         self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
         self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]

-        print("
+        print("CLIP text_features_cache prepared.")

     def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
         """
@@ -593,7 +581,16 @@ class CLIPAnalyzer:
         return image_features.cpu().numpy()[0] if self.device == "cuda" else image_features.numpy()[0]

     def text_to_embedding(self, text: str) -> np.ndarray:
-
+        """
+        Convert text into a CLIP embedding.
+
+        Args:
+            text: input text
+
+        Returns:
+            np.ndarray: the CLIP feature vector for the text
+        """
+        text_token = clip.tokenize([text]).to(self.device)

         with torch.no_grad():
             text_features = self.model.encode_text(text_token)
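The net effect of the changes above is to replace the OpenCLIP loader and tokenizer with the original `clip` package; the diff also drops the separate `open_clip.get_tokenizer` step and the post-load memory cleanup. A minimal standalone sketch of the pattern the new `__init__` and `text_to_embedding` rely on (the input string is illustrative, and the final normalization mirrors how the file's other encoding paths treat features):

```python
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and its preprocessing transform, as the new __init__ does
model, preprocess = clip.load("ViT-B/16", device=device)

# Mirror text_to_embedding: tokenize one string and encode it
text = "a busy urban intersection at night"  # illustrative input, not from the repo
text_token = clip.tokenize([text]).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_token)
    # L2-normalize, as the analyzer's other feature paths do
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

embedding = text_features.cpu().numpy()[0]
print(embedding.shape)  # (512,) for ViT-B/16
```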
clip_model_manager.py
CHANGED

@@ -1,6 +1,6 @@

 import torch
-import open_clip
+import clip
 import numpy as np
 import logging
 import traceback
@@ -12,7 +12,7 @@ class CLIPModelManager:
     Dedicated manager for CLIP model operations, including model loading, device management, and the core image and text feature-encoding functions
     """

-    def __init__(self, model_name: str = "ViT-B
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
         """
         Initialize the CLIP model manager

@@ -22,8 +22,6 @@ class CLIPModelManager:
         """
         self.logger = logging.getLogger(__name__)
         self.model_name = model_name
-        self.pretrained = pretrained
-        self.tokenizer = None

         # Set the running device
         if device is None:
@@ -31,32 +29,19 @@
         else:
             self.device = device

+        self.model = None
         self.preprocess = None

         self._initialize_model()

     def _initialize_model(self):
         """
-        Initialize
+        Initialize the CLIP model
         """
         try:
-            self.logger.info(f"Initializing
-            self.model,
-
-                pretrained=self.pretrained,
-                device=self.device
-            )
-            self.tokenizer = open_clip.get_tokenizer(self.model_name)
-            self.logger.info("Successfully loaded OpenCLIP model")
-
-            # Immediately clean up memory fragmentation from the OpenCLIP loading process
-            import gc
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-            self.logger.info("Memory cleanup completed after OpenCLIP loading in CLIPModelManager")
-
+            self.logger.info(f"Initializing CLIP model ({self.model_name}) on {self.device}")
+            self.model, self.preprocess = clip.load(self.model_name, device=self.device)
+            self.logger.info("Successfully loaded CLIP model")
         except Exception as e:
            self.logger.error(f"Error loading CLIP model: {e}")
            self.logger.error(traceback.format_exc())
@@ -102,7 +87,7 @@ class CLIPModelManager:

             for i in range(0, len(text_prompts), batch_size):
                 batch_prompts = text_prompts[i:i+batch_size]
-                text_tokens =
+                text_tokens = clip.tokenize(batch_prompts).to(self.device)
                 batch_features = self.model.encode_text(text_tokens)
                 batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
                 features_list.append(batch_features)
@@ -121,9 +106,18 @@
             raise

     def encode_single_text(self, text_prompts: List[str]) -> torch.Tensor:
+        """
+        Encode the features for a single batch of text prompts
+
+        Args:
+            text_prompts: list of text prompts
+
+        Returns:
+            torch.Tensor: the normalized text features
+        """
         try:
             with torch.no_grad():
-                text_tokens =
+                text_tokens = clip.tokenize(text_prompts).to(self.device)
                 text_features = self.model.encode_text(text_tokens)
                 text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                 return text_features
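clip_model_manager.py keeps its batched text-encoding loop and only swaps the tokenizer call. A sketch of that loop as a standalone function, assuming a model already obtained from `clip.load` (the function name and the `batch_size` default are illustrative, not the manager's actual signature):

```python
import torch
import clip
from typing import List

def encode_text_batches(model, prompts: List[str], device: str, batch_size: int = 128) -> torch.Tensor:
    """Tokenize prompts per batch, encode them, L2-normalize, and concatenate the results."""
    features_list = []
    with torch.no_grad():
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]
            text_tokens = clip.tokenize(batch_prompts).to(device)  # replaces the old open_clip tokenizer call
            batch_features = model.encode_text(text_tokens)
            batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
            features_list.append(batch_features)
    return torch.cat(features_list, dim=0)

# Usage (model comes from clip.load, as in _initialize_model):
# model, _ = clip.load("ViT-B/16", device="cuda")
# all_features = encode_text_batches(model, ["a photo of a street"] * 300, "cuda")
```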
clip_zero_shot_classifier.py
CHANGED

@@ -1,6 +1,6 @@

 import torch
-import open_clip
+import clip
 from PIL import Image
 import numpy as np
 import logging
@@ -21,18 +21,18 @@ class CLIPZeroShotClassifier:
     This is a facade class that coordinates the individual components to provide a unified external interface.
     """

-    def __init__(self, model_name: str = "ViT-B
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
         """
         Initialize the CLIP zero-shot classifier

         Args:
-            model_name:
+            model_name: CLIP model name, defaults to "ViT-B/16"
             device: running device; automatically selected when None
         """
         self.logger = logging.getLogger(__name__)

         # Initialize the individual components
-        self.clip_model_manager = CLIPModelManager(model_name, device
+        self.clip_model_manager = CLIPModelManager(model_name, device)
         self.landmark_data_manager = LandmarkDataManager()
         self.image_analyzer = ImageAnalyzer()
         self.confidence_manager = ConfidenceManager()
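Only the import, the default model name, and the `CLIPModelManager` constructor call change in this file. For context, a hedged end-to-end sketch of the zero-shot classification pattern these components implement with the `clip` package (the labels, prompt template, and image path are hypothetical, not the classifier's actual prompts):

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Hypothetical candidate labels; the real classifier builds its prompts from its landmark data
labels = ["Eiffel Tower", "Tokyo Tower", "Golden Gate Bridge"]
text_tokens = clip.tokenize([f"a photo of the {label}" for label in labels]).to(device)

image = preprocess(Image.open("landmark.jpg")).unsqueeze(0).to(device)  # hypothetical image path

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text_tokens)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Scaled cosine similarities turned into per-label probabilities
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

for label, p in zip(labels, probs[0].tolist()):
    print(f"{label}: {p:.3f}")
```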