import torch
import clip
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Any, Optional, Union
from clip_prompts import (
SCENE_TYPE_PROMPTS,
CULTURAL_SCENE_PROMPTS,
COMPARATIVE_PROMPTS,
LIGHTING_CONDITION_PROMPTS,
SPECIALIZED_SCENE_PROMPTS,
VIEWPOINT_PROMPTS,
OBJECT_COMBINATION_PROMPTS,
ACTIVITY_PROMPTS
)
class CLIPAnalyzer:
"""
Use Clip to intergrate scene understanding function
"""
    def __init__(self, model_name: str = "ViT-L/14", device: Optional[str] = None):
        """
        Initialize the CLIP analyzer.
        Args:
            model_name: CLIP model name; defaults to "ViT-L/14"
            device: Device to run on; defaults to CUDA when available, else CPU
        """
        # Automatically select the device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
print(f"Loading CLIP model {model_name} on {self.device}...")
try:
self.model, self.preprocess = clip.load(model_name, device=self.device)
print(f"CLIP model loaded successfully.")
except Exception as e:
print(f"Error loading CLIP model: {e}")
raise
self.scene_type_prompts = SCENE_TYPE_PROMPTS
self.cultural_scene_prompts = CULTURAL_SCENE_PROMPTS
self.comparative_prompts = COMPARATIVE_PROMPTS
self.lighting_condition_prompts = LIGHTING_CONDITION_PROMPTS
self.specialized_scene_prompts = SPECIALIZED_SCENE_PROMPTS
self.viewpoint_prompts = VIEWPOINT_PROMPTS
self.object_combination_prompts = OBJECT_COMBINATION_PROMPTS
self.activity_prompts = ACTIVITY_PROMPTS
        # Pre-tokenize all text prompts into CLIP token tensors
self._prepare_text_prompts()
    def _prepare_text_prompts(self):
        """Tokenize every text prompt for CLIP and store the results in self.text_features_cache."""
        self.text_features_cache = {}
        # Base scene types (SCENE_TYPE_PROMPTS)
        if hasattr(self, 'scene_type_prompts') and self.scene_type_prompts:
            scene_texts = list(self.scene_type_prompts.values())
            if scene_texts:
                self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
                try:
                    self.text_features_cache["scene_type_tokens"] = clip.tokenize(scene_texts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing scene_type_prompts: {e}")
                    self.text_features_cache["scene_type_tokens"] = None  # mark as failed/empty
            else:
                self.text_features_cache["scene_type_keys"] = []
                self.text_features_cache["scene_type_tokens"] = None
        else:
            self.text_features_cache["scene_type_keys"] = []
            self.text_features_cache["scene_type_tokens"] = None
        # Cultural scenes (CULTURAL_SCENE_PROMPTS)
        # cultural_tokens_dict stores the tokenized prompts
        cultural_tokens_dict_val = {}
        if hasattr(self, 'cultural_scene_prompts') and self.cultural_scene_prompts:
            for scene_type, prompts in self.cultural_scene_prompts.items():
                if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                    try:
                        cultural_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                    except Exception as e:
                        print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
                        cultural_tokens_dict_val[scene_type] = None  # mark as failed/empty
                else:
                    cultural_tokens_dict_val[scene_type] = None  # prompts are malformed
        self.text_features_cache["cultural_tokens_dict"] = cultural_tokens_dict_val
        # Lighting conditions (LIGHTING_CONDITION_PROMPTS)
        if hasattr(self, 'lighting_condition_prompts') and self.lighting_condition_prompts:
            lighting_texts = list(self.lighting_condition_prompts.values())
            if lighting_texts:
                self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
                try:
                    self.text_features_cache["lighting_tokens"] = clip.tokenize(lighting_texts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
                    self.text_features_cache["lighting_tokens"] = None
            else:
                self.text_features_cache["lighting_condition_keys"] = []
                self.text_features_cache["lighting_tokens"] = None
        else:
            self.text_features_cache["lighting_condition_keys"] = []
            self.text_features_cache["lighting_tokens"] = None
        # Specialized scenes (SPECIALIZED_SCENE_PROMPTS)
        specialized_tokens_dict_val = {}
        if hasattr(self, 'specialized_scene_prompts') and self.specialized_scene_prompts:
            for scene_type, prompts in self.specialized_scene_prompts.items():
                if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                    try:
                        specialized_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                    except Exception as e:
                        print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
                        specialized_tokens_dict_val[scene_type] = None
                else:
                    specialized_tokens_dict_val[scene_type] = None
        self.text_features_cache["specialized_tokens_dict"] = specialized_tokens_dict_val
        # Viewpoints (VIEWPOINT_PROMPTS)
        if hasattr(self, 'viewpoint_prompts') and self.viewpoint_prompts:
            viewpoint_texts = list(self.viewpoint_prompts.values())
            if viewpoint_texts:
                self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
                try:
                    self.text_features_cache["viewpoint_tokens"] = clip.tokenize(viewpoint_texts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
                    self.text_features_cache["viewpoint_tokens"] = None
            else:
                self.text_features_cache["viewpoint_keys"] = []
                self.text_features_cache["viewpoint_tokens"] = None
        else:
            self.text_features_cache["viewpoint_keys"] = []
            self.text_features_cache["viewpoint_tokens"] = None
        # Object combinations (OBJECT_COMBINATION_PROMPTS)
        if hasattr(self, 'object_combination_prompts') and self.object_combination_prompts:
            object_combination_texts = list(self.object_combination_prompts.values())
            if object_combination_texts:
                self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
                try:
                    self.text_features_cache["object_combination_tokens"] = clip.tokenize(object_combination_texts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing object_combination_prompts: {e}")
                    self.text_features_cache["object_combination_tokens"] = None
            else:
                self.text_features_cache["object_combination_keys"] = []
                self.text_features_cache["object_combination_tokens"] = None
        else:
            self.text_features_cache["object_combination_keys"] = []
            self.text_features_cache["object_combination_tokens"] = None
        # Activities (ACTIVITY_PROMPTS)
        if hasattr(self, 'activity_prompts') and self.activity_prompts:
            activity_texts = list(self.activity_prompts.values())
            if activity_texts:
                self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
                try:
                    self.text_features_cache["activity_tokens"] = clip.tokenize(activity_texts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing activity_prompts: {e}")
                    self.text_features_cache["activity_tokens"] = None
            else:
                self.text_features_cache["activity_keys"] = []
                self.text_features_cache["activity_tokens"] = None
        else:
            self.text_features_cache["activity_keys"] = []
            self.text_features_cache["activity_tokens"] = None
self.scene_type_tokens = self.text_features_cache["scene_type_tokens"]
self.lighting_tokens = self.text_features_cache["lighting_tokens"]
self.viewpoint_tokens = self.text_features_cache["viewpoint_tokens"]
self.object_combination_tokens = self.text_features_cache["object_combination_tokens"]
self.activity_tokens = self.text_features_cache["activity_tokens"]
self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]
print("CLIP text_features_cache prepared.")
    def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
        """
        Analyze an image, predicting its scene type and lighting conditions.
        Args:
            image: Input image (PIL Image or numpy array)
            include_cultural_analysis: Whether to include detailed cultural-scene analysis
            exclude_categories: List of categories to exclude
            enable_landmark: Whether to enable landmark detection
            places365_guidance: Scene guidance provided by Places365 (optional)
        Returns:
            Dict: Analysis results containing scene-type predictions and lighting conditions
        """
        try:
            self.enable_landmark = enable_landmark  # update the instance's enable_landmark state
            # Make sure the image is in PIL format
            if not isinstance(image, Image.Image):
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                else:
                    raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
            # Preprocess the image
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            # Extract image features
            with torch.no_grad():
                image_features = self.model.encode_image(image_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            places365_focus_areas = []
            places365_scene_context = ""  # holds the scene description provided by Places365
            if places365_guidance and isinstance(places365_guidance, dict) and places365_guidance.get('confidence', 0) > 0.4:
                mapped_scene = places365_guidance.get('mapped_scene_type', '')
                scene_label = places365_guidance.get('scene_label', '')
                # is_indoor = places365_guidance.get('is_indoor', None)  # unused; kept commented out
                attributes = places365_guidance.get('attributes', [])
                places365_scene_context = f"Scene identified by Places365 as {scene_label}"
                # Adjust the CLIP analysis focus based on the Places365 scene type
                if mapped_scene in ['kitchen', 'dining_area', 'restaurant']:
                    places365_focus_areas.extend(['food preparation', 'dining setup', 'kitchen appliances'])
                elif mapped_scene in ['office_workspace', 'educational_setting', 'library', 'conference_room']:
                    places365_focus_areas.extend(['work environment', 'professional setting', 'learning space', 'study area'])
                elif mapped_scene in ['retail_store', 'shopping_mall', 'market', 'supermarket']:  # extended matching
                    places365_focus_areas.extend(['commercial space', 'shopping environment', 'retail display', 'goods for sale'])
                elif mapped_scene in ['park_area', 'beach', 'natural_outdoor_area', 'playground', 'sports_field']:  # extended matching
                    places365_focus_areas.extend(['outdoor recreation', 'natural environment', 'leisure activity', 'open space'])
                # Add more generic focus areas based on the attributes
                if isinstance(attributes, list):  # make sure attributes is a list
                    if 'commercial' in attributes:
                        places365_focus_areas.append('business activity')
                    if 'recreational' in attributes:
                        places365_focus_areas.append('entertainment or leisure')
                    if 'residential' in attributes:
                        places365_focus_areas.append('living space')
                # Deduplicate
                places365_focus_areas = list(set(places365_focus_areas))
                if places365_focus_areas:  # only print when there actually are focus areas
                    print(f"CLIP analysis guided by Places365: {places365_scene_context}, focus areas: {places365_focus_areas}")
            # Analyze the scene type, passing the enable_landmark flag and the Places365 guidance
            scene_scores = self._analyze_scene_type(image_features,
                                                    enable_landmark=self.enable_landmark,  # use the updated instance attribute
                                                    places365_focus=places365_focus_areas)
            # If landmark detection is disabled, make sure landmark-related categories are excluded
            current_exclude_categories = list(exclude_categories) if exclude_categories is not None else []
            if not self.enable_landmark:  # use the updated instance attribute
                landmark_related_terms = ["landmark", "monument", "tower", "tourist", "attraction", "historical", "famous", "iconic"]
                for term in landmark_related_terms:
                    if term not in current_exclude_categories:
                        current_exclude_categories.append(term)
            if current_exclude_categories:
                filtered_scores = {}
                for scene, score in scene_scores.items():
                    # Check whether the scene key (usually English) contains any excluded term
                    if not any(cat.lower() in scene.lower() for cat in current_exclude_categories):
                        filtered_scores[scene] = score
                if filtered_scores:
                    total_score = sum(filtered_scores.values())
                    if total_score > 1e-5:  # avoid dividing by zero or a very small number
                        scene_scores = {k: v / total_score for k, v in filtered_scores.items()}
                    else:  # if the total is close to zero, zero out the scores
                        scene_scores = {k: 0.0 for k in filtered_scores.keys()}
                else:  # if nothing survives the filter
                    scene_scores = {k: (0.0 if any(cat.lower() in k.lower() for cat in current_exclude_categories) else v) for k, v in scene_scores.items()}
                    if not any(s > 1e-5 for s in scene_scores.values()):  # still all zero
                        scene_scores = {"unknown": 1.0}  # default value so the dict is never empty
            lighting_scores = self._analyze_lighting_condition(image_features)
            cultural_analysis = {}
            if include_cultural_analysis and self.enable_landmark:  # use the updated instance attribute
                for scene_type_cultural_key in self.text_features_cache.get("cultural_tokens_dict", {}).keys():
                    # The key must exist in SCENE_TYPE_PROMPTS (or have a mapping to it)
                    if scene_type_cultural_key in scene_scores and scene_scores[scene_type_cultural_key] > 0.2:
                        cultural_analysis[scene_type_cultural_key] = self._analyze_cultural_scene(
                            image_features, scene_type_cultural_key
                        )
            specialized_analysis = {}
            for scene_type_specialized_key in self.text_features_cache.get("specialized_tokens_dict", {}).keys():
                if scene_type_specialized_key in scene_scores and scene_scores[scene_type_specialized_key] > 0.2:
                    specialized_analysis[scene_type_specialized_key] = self._analyze_specialized_scene(
                        image_features, scene_type_specialized_key
                    )
viewpoint_scores = self._analyze_viewpoint(image_features)
object_combination_scores = self._analyze_object_combinations(image_features)
activity_scores = self._analyze_activities(image_features)
            if scene_scores:  # make sure scene_scores is not empty
                top_scene = max(scene_scores.items(), key=lambda x: x[1])
                # With landmarks disabled, double-check that top_scene is not landmark-related
                if not self.enable_landmark and any(cat.lower() in top_scene[0].lower() for cat in current_exclude_categories):
                    non_excluded_scores = {k: v for k, v in scene_scores.items() if not any(cat.lower() in k.lower() for cat in current_exclude_categories)}
                    if non_excluded_scores:
                        top_scene = max(non_excluded_scores.items(), key=lambda x: x[1])
                    else:
                        top_scene = ("unknown", 0.0)  # or another suitable default
            else:
                top_scene = ("unknown", 0.0)
            result = {
                "scene_scores": scene_scores,
                "top_scene": top_scene,
                "lighting_condition": max(lighting_scores.items(), key=lambda x: x[1]) if lighting_scores else ("unknown", 0.0),
                "embedding": image_features.cpu().numpy().tolist()[0],  # flattened to a plain list
                "viewpoint": max(viewpoint_scores.items(), key=lambda x: x[1]) if viewpoint_scores else ("unknown", 0.0),
                "object_combinations": sorted(object_combination_scores.items(), key=lambda x: x[1], reverse=True)[:3] if object_combination_scores else [],
                "activities": sorted(activity_scores.items(), key=lambda x: x[1], reverse=True)[:3] if activity_scores else []
            }
            if places365_guidance and isinstance(places365_guidance, dict) and places365_focus_areas:  # check that focus areas were populated
                result["places365_guidance"] = {
                    "scene_context": places365_scene_context,
                    "focus_areas": places365_focus_areas,  # populated from the guidance above
                    "guided_analysis": True,
                    "original_places365_scene": places365_guidance.get('scene_label', 'N/A'),
                    "original_places365_confidence": places365_guidance.get('confidence', 0.0)
                }
if cultural_analysis and self.enable_landmark:
result["cultural_analysis"] = cultural_analysis
if specialized_analysis:
result["specialized_analysis"] = specialized_analysis
return result
except Exception as e:
print(f"Error analyzing image with CLIP: {e}")
import traceback
traceback.print_exc()
return {"error": str(e), "scene_scores": {}, "top_scene": ("error", 0.0)}
    def _analyze_scene_type(self, image_features: torch.Tensor, enable_landmark: bool = True, places365_focus: List[str] = None) -> Dict[str, float]:
        """
        Compute the similarity between the image features and each scene type, optionally excluding landmark-related scenes.
        Args:
            image_features: Image features encoded by CLIP
            enable_landmark: Whether landmark recognition is enabled
            places365_focus: Optional focus areas derived from Places365 guidance
        Returns:
            Dict[str, float]: Similarity scores per scene type
        """
with torch.no_grad():
            # Encode the scene-type text features
text_features = self.model.encode_text(self.scene_type_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Apply Places365 guidance if available
if places365_focus and len(places365_focus) > 0:
# Create enhanced prompts that incorporate Places365 guidance
enhanced_prompts = []
for scene_type in self.scene_type_prompts.keys():
base_prompt = self.scene_type_prompts[scene_type]
# Check if this scene type should be emphasized based on Places365 guidance
scene_lower = scene_type.lower()
should_enhance = False
for focus_area in places365_focus:
if any(keyword in scene_lower for keyword in focus_area.split()):
should_enhance = True
enhanced_prompts.append(f"{base_prompt} with {focus_area}")
break
if not should_enhance:
enhanced_prompts.append(base_prompt)
# Re-tokenize and encode enhanced prompts
enhanced_tokens = clip.tokenize(enhanced_prompts).to(self.device)
text_features = self.model.encode_text(enhanced_tokens)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]  # .cpu() is a no-op on CPU tensors
            # Build the scene-score dictionary
            scene_scores = {}
            for i, scene_type in enumerate(self.scene_type_prompts.keys()):
                # Skip landmark-related scene types when landmark detection is disabled
                if not enable_landmark and scene_type in ["tourist_landmark", "natural_landmark", "historical_monument"]:
                    scene_scores[scene_type] = 0.0  # zero out landmark scene scores
else:
base_score = float(similarity[i])
# Apply Places365 guidance boost if applicable
if places365_focus:
scene_lower = scene_type.lower()
boost_factor = 1.0
for focus_area in places365_focus:
if any(keyword in scene_lower for keyword in focus_area.split()):
boost_factor = 1.15 # 15% boost for matching scenes
break
scene_scores[scene_type] = base_score * boost_factor
else:
scene_scores[scene_type] = base_score
            # When landmark detection is disabled, renormalize the remaining scene scores
            if not enable_landmark:
                # Collect all non-zero scores
                non_zero_scores = {k: v for k, v in scene_scores.items() if v > 0}
                if non_zero_scores:
                    # Compute the total and normalize
                    total_score = sum(non_zero_scores.values())
                    if total_score > 0:
                        for scene_type in non_zero_scores:
                            scene_scores[scene_type] = non_zero_scores[scene_type] / total_score
return scene_scores
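    # Scoring pattern shared by the _analyze_* helpers below (a sketch, not
    # library documentation): with L2-normalized image features I (1 x D) and
    # text features T (N x D), I @ T.T gives cosine similarities; scaling by
    # 100 and taking softmax over the N prompts yields probability-like scores.
    # E.g. cosines (0.30, 0.25) become logits (30, 25), and softmax turns that
    # small gap into roughly (0.993, 0.007), so rankings stay sharp even when
    # the raw cosines are close.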
    def _analyze_lighting_condition(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze the lighting conditions of the image."""
        with torch.no_grad():
            # Encode the lighting-condition text features
            text_features = self.model.encode_text(self.lighting_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the lighting-condition score dictionary
            lighting_scores = {}
            for i, lighting_type in enumerate(self.lighting_condition_prompts.keys()):
                lighting_scores[lighting_type] = float(similarity[i])
            return lighting_scores
    def _analyze_cultural_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run an in-depth analysis for a specific cultural scene."""
        if scene_type not in self.cultural_tokens_dict:
            return {"error": f"No cultural analysis available for {scene_type}"}
        with torch.no_grad():
            # Get the text features for this cultural scene
            cultural_tokens = self.cultural_tokens_dict[scene_type]
            text_features = self.model.encode_text(cultural_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores (raw logits; no softmax here)
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]
            # Find the best-matching cultural description
            prompts = self.cultural_scene_prompts[scene_type]
            scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
            scores.sort(key=lambda x: x[1], reverse=True)
            return {
                "best_description": scores[0][0],
                "confidence": scores[0][1],
                "all_matches": scores
            }
    def _analyze_specialized_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run an in-depth analysis for a specific specialized scene."""
        if scene_type not in self.specialized_tokens_dict:
            return {"error": f"No specialized analysis available for {scene_type}"}
        with torch.no_grad():
            # Get the text features for this specialized scene
            specialized_tokens = self.specialized_tokens_dict[scene_type]
            text_features = self.model.encode_text(specialized_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores (raw logits; no softmax here)
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]
            # Find the best-matching specialized description
            prompts = self.specialized_scene_prompts[scene_type]
            scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
            scores.sort(key=lambda x: x[1], reverse=True)
            return {
                "best_description": scores[0][0],
                "confidence": scores[0][1],
                "all_matches": scores
            }
    def _analyze_viewpoint(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze the viewpoint from which the image was captured."""
        with torch.no_grad():
            # Encode the viewpoint text features
            text_features = self.model.encode_text(self.viewpoint_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the viewpoint score dictionary
            viewpoint_scores = {}
            for i, viewpoint in enumerate(self.viewpoint_prompts.keys()):
                viewpoint_scores[viewpoint] = float(similarity[i])
            return viewpoint_scores
    def _analyze_object_combinations(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze object combinations present in the image."""
        with torch.no_grad():
            # Encode the object-combination text features
            text_features = self.model.encode_text(self.object_combination_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the object-combination score dictionary
            combination_scores = {}
            for i, combination in enumerate(self.object_combination_prompts.keys()):
                combination_scores[combination] = float(similarity[i])
            return combination_scores
    def _analyze_activities(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze activities taking place in the image."""
        with torch.no_grad():
            # Encode the activity text features
            text_features = self.model.encode_text(self.activity_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
            # Build the activity score dictionary
            activity_scores = {}
            for i, activity in enumerate(self.activity_prompts.keys()):
                activity_scores[activity] = float(similarity[i])
            return activity_scores
    def get_image_embedding(self, image) -> np.ndarray:
        """
        Get the CLIP embedding of an image.
        Args:
            image: PIL Image or numpy array
        Returns:
            np.ndarray: The image's CLIP feature vector
        """
        # Make sure the image is in PIL format
        if not isinstance(image, Image.Image):
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            else:
                raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")
        # Preprocess and encode
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Convert to numpy and return (.cpu() is a no-op on CPU tensors)
        return image_features.cpu().numpy()[0]
    def text_to_embedding(self, text: str) -> np.ndarray:
        """
        Convert text to its CLIP embedding.
        Args:
            text: Input text
        Returns:
            np.ndarray: The text's CLIP feature vector
        """
        text_token = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_token)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy()[0]
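    # Example (a sketch): the two embedding helpers compose into a cosine
    # similarity, since both vectors are already L2-normalized:
    #   img_vec = analyzer.get_image_embedding(img)            # shape (D,)
    #   txt_vec = analyzer.text_to_embedding("a city street at night")
    #   cosine = float(np.dot(img_vec, txt_vec))               # in [-1, 1]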
    def calculate_similarity(self, image, text_queries: List[str]) -> Dict[str, float]:
        """
        Compute the similarity between an image and multiple text queries.
        Args:
            image: PIL Image or numpy array
            text_queries: List of text queries
        Returns:
            Dict: Similarity score for each query
        """
        # Get the image embedding
        if isinstance(image, np.ndarray) and len(image.shape) == 1:
            # Already an embedding vector
            image_features = torch.tensor(image).unsqueeze(0).to(self.device)
        else:
            # An image; extract its embedding first
            image_features = torch.tensor(self.get_image_embedding(image)).unsqueeze(0).to(self.device)
        # Calculate similarity
        text_tokens = clip.tokenize(text_queries).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Align dtypes before the matmul: encode_text returns float16 on CUDA,
            # while the numpy round-trip above yields float32
            similarity = (100.0 * image_features.float() @ text_features.float().T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]
        # Collect results
        result = {}
        for i, query in enumerate(text_queries):
            result[query] = float(similarity[i])
        return result
    def get_clip_instance(self):
        """
        Get the initialized CLIP model instance so other modules can reuse it.
        Returns:
            tuple: (model instance, preprocess function, device name)
        """
        return self.model, self.preprocess, self.device
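
if __name__ == "__main__":
    # Minimal smoke test (a hypothetical usage sketch, not part of the module's
    # public contract): assumes clip_prompts is importable and that a local
    # "example.jpg" exists next to this script.
    analyzer = CLIPAnalyzer()
    img = Image.open("example.jpg").convert("RGB")
    analysis = analyzer.analyze_image(img)
    print("Top scene:", analysis["top_scene"])
    print("Lighting:", analysis["lighting_condition"])
    # Free-form similarity queries against the same image
    scores = analyzer.calculate_similarity(
        img, ["a busy street", "an empty room", "a beach at sunset"]
    )
    for query, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
        print(f"{score:.3f}  {query}")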