DawnC committed on
Commit
58a70f4
·
verified ·
1 Parent(s): e83cc4c

Upload llm_enhancer.py

Browse files
Files changed (1) hide show
  1. llm_enhancer.py +32 -122
llm_enhancer.py CHANGED
@@ -34,7 +34,7 @@ class LLMEnhancer:
34
  handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
35
  self.logger.addHandler(handler)
36
 
37
- # 設置默認模型路徑就是用Llama3.2
38
  self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
39
  self.tokenizer_path = tokenizer_path or self.model_path
40
 
@@ -55,7 +55,7 @@ class LLMEnhancer:
55
 
56
  self._initialize_prompts()
57
 
58
- # 只在需要時加載模型
59
  self._model_loaded = False
60
 
61
  try:
@@ -70,7 +70,7 @@ class LLMEnhancer:
70
  self.logger.error(f"Error during Hugging Face login: {e}")
71
 
72
  def _load_model(self):
73
- """懶加載模型 - 僅在首次需要時加載,使用 8 位量化以節省記憶體"""
74
  if self._model_loaded:
75
  return
76
 
@@ -80,18 +80,16 @@ class LLMEnhancer:
80
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
81
  torch.cuda.empty_cache()
82
 
83
- # 打印可用 GPU 記憶體
84
  if torch.cuda.is_available():
85
  free_in_GB = torch.cuda.get_device_properties(0).total_memory / 1024**3
86
  print(f"Total GPU memory: {free_in_GB:.2f} GB")
87
 
88
- # 設置 8 位元量化配置
89
  quantization_config = BitsAndBytesConfig(
90
  load_in_8bit=True,
91
  llm_int8_enable_fp32_cpu_offload=True
92
  )
93
 
94
- # 加載詞元處理器
95
  self.tokenizer = AutoTokenizer.from_pretrained(
96
  self.tokenizer_path,
97
  padding_side="left",
@@ -99,14 +97,14 @@ class LLMEnhancer:
99
  token=self.hf_token
100
  )
101
 
102
- # 設置特殊標記
103
  self.tokenizer.pad_token = self.tokenizer.eos_token
104
 
105
  # 加載 8 位量化模型
106
  self.model = AutoModelForCausalLM.from_pretrained(
107
  self.model_path,
108
  quantization_config=quantization_config,
109
- device_map="auto",
110
  low_cpu_mem_usage=True,
111
  token=self.hf_token
112
  )
@@ -122,7 +120,7 @@ class LLMEnhancer:
122
 
123
  def _initialize_prompts(self):
124
  """Return an optimized prompt template specifically for Zephyr model"""
125
- # the prompt for the model
126
  self.enhance_description_template = """
127
  <|system|>
128
  You are an expert visual analyst. Your task is to improve the readability and fluency of scene descriptions using STRICT factual accuracy.
@@ -153,7 +151,7 @@ class LLMEnhancer:
153
  """
154
 
155
 
156
- # 錯誤檢測提示
157
  self.verify_detection_template = """
158
  Task: You are an advanced vision system that verifies computer vision detections for accuracy.
159
 
@@ -179,7 +177,7 @@ class LLMEnhancer:
179
  Verification Results:
180
  """
181
 
182
- # 無檢測處理提示
183
  self.no_detection_template = """
184
  Task: You are an advanced scene understanding system analyzing an image where standard object detection failed to identify specific objects.
185
 
@@ -232,6 +230,7 @@ class LLMEnhancer:
232
 
233
  return response
234
 
 
235
  def _detect_scene_type(self, detected_objects: List[Dict]) -> str:
236
  """
237
  Detect scene type based on object distribution and patterns
@@ -291,26 +290,6 @@ class LLMEnhancer:
291
 
292
  return response.strip()
293
 
294
- def _validate_scene_facts(self, enhanced_desc: str, original_desc: str, people_count: int) -> str:
295
- """Validate key facts in enhanced description"""
296
- # Check if people count is preserved
297
- if people_count > 0:
298
- people_pattern = re.compile(r'(\d+)\s+(?:people|persons|pedestrians|individuals)', re.IGNORECASE)
299
- people_match = people_pattern.search(enhanced_desc)
300
-
301
- if not people_match or int(people_match.group(1)) != people_count:
302
- # Replace incorrect count or add if missing
303
- if people_match:
304
- enhanced_desc = people_pattern.sub(f"{people_count} people", enhanced_desc)
305
- else:
306
- enhanced_desc = f"The scene shows {people_count} people. " + enhanced_desc
307
-
308
- # Ensure aerial perspective is mentioned
309
- if "aerial" in original_desc.lower() and "aerial" not in enhanced_desc.lower():
310
- enhanced_desc = "From an aerial perspective, " + enhanced_desc[0].lower() + enhanced_desc[1:]
311
-
312
- return enhanced_desc
313
-
314
  def reset_context(self):
315
  """在處理新圖像前重置模型上下文"""
316
  if self._model_loaded:
@@ -357,11 +336,11 @@ class LLMEnhancer:
357
  scene_type = scene_data.get("scene_type", "unknown scene")
358
  scene_type = self._clean_scene_type(scene_type)
359
 
360
- # 提取檢測到的物件並過濾低置信度物件
361
  detected_objects = scene_data.get("detected_objects", [])
362
  filtered_objects = []
363
 
364
- # 高置信度閾值,嚴格過濾物件
365
  high_confidence_threshold = 0.65
366
 
367
  for obj in detected_objects:
@@ -374,11 +353,11 @@ class LLMEnhancer:
374
  if confidence < 0.75: # 為這些類別設置更高閾值
375
  continue
376
 
377
- # 僅保留高置信度物件
378
  if confidence >= high_confidence_threshold:
379
  filtered_objects.append(obj)
380
 
381
- # 計算物件列表和數量 - 僅使用過濾後的高置信度物件
382
  object_counts = {}
383
  for obj in filtered_objects:
384
  class_name = obj.get("class_name", "")
@@ -389,7 +368,7 @@ class LLMEnhancer:
389
  # 將高置信度物件格式化為清單
390
  high_confidence_objects = ", ".join([f"{count} {obj}" for obj, count in object_counts.items()])
391
 
392
- # 如果沒有高置信度物件,回退到使用原始描述中的關鍵詞
393
  if not high_confidence_objects:
394
  # 從原始描述中提取物件提及
395
  object_keywords = self._extract_objects_from_description(original_desc)
@@ -406,7 +385,7 @@ class LLMEnhancer:
406
  is_indoor = lighting_info.get("is_indoor", False)
407
  lighting_description = f"{'indoor' if is_indoor else 'outdoor'} {time_of_day} lighting"
408
 
409
- # 構建提示詞,整合所有關鍵資訊
410
  prompt = self.enhance_description_template.format(
411
  scene_type=scene_type,
412
  object_list=high_confidence_objects,
@@ -421,8 +400,8 @@ class LLMEnhancer:
421
 
422
  # 檢查回應完整性的更嚴格標準
423
  is_incomplete = (
424
- len(response) < 100 or # 太短
425
- (len(response) < 200 and "." not in response[-30:]) or # 結尾沒有適當標點
426
  any(response.endswith(phrase) for phrase in ["in the", "with the", "and the"]) # 以不完整短語結尾
427
  )
428
 
@@ -439,24 +418,23 @@ class LLMEnhancer:
439
  (len(response) < 200 and "." not in response[-30:]) or
440
  any(response.endswith(phrase) for phrase in ["in the", "with the", "and the"]))
441
 
442
- # 確保響應不為空
443
  if not response or len(response.strip()) < 10:
444
  self.logger.warning("Generated response was empty or too short, returning original description")
445
  return original_desc
446
 
447
- # 清理響應 - 使用與模型相符的清理方法
448
  if "llama" in self.model_path.lower():
449
  result = self._clean_llama_response(response)
450
  else:
451
  result = self._clean_model_response(response)
452
 
453
- # 移除介紹性句子
454
  result = self._remove_introduction_sentences(result)
455
 
456
- # 移除解釋性注釋
457
  result = self._remove_explanatory_notes(result)
458
 
459
- # 進行事實準確性檢查
460
  result = self._verify_factual_accuracy(original_desc, result, high_confidence_objects)
461
 
462
  # 確保場景類型和視角一致性
@@ -547,36 +525,6 @@ class LLMEnhancer:
547
 
548
  return result
549
 
550
- def _validate_content_consistency(self, original_desc: str, enhanced_desc: str) -> str:
551
- """驗證增強描述的內容與原始描述一致"""
552
- # 提取原始描述中的關鍵數值
553
- people_count_match = re.search(r'(\d+)\s+people', original_desc, re.IGNORECASE)
554
- people_count = int(people_count_match.group(1)) if people_count_match else None
555
-
556
- # 驗證人數一致性
557
- if people_count:
558
- enhanced_count_match = re.search(r'(\d+)\s+people', enhanced_desc, re.IGNORECASE)
559
- if not enhanced_count_match or int(enhanced_count_match.group(1)) != people_count:
560
- # 保留原始人數
561
- if enhanced_count_match:
562
- enhanced_desc = re.sub(r'\b\d+\s+people\b', f"{people_count} people", enhanced_desc, flags=re.IGNORECASE)
563
- elif "people" in enhanced_desc.lower():
564
- enhanced_desc = re.sub(r'\bpeople\b', f"{people_count} people", enhanced_desc, flags=re.IGNORECASE)
565
-
566
- # 驗證視角/透視一致性
567
- perspective_terms = ["aerial", "bird's-eye", "overhead", "ground level", "eye level"]
568
-
569
- for term in perspective_terms:
570
- if term in original_desc.lower() and term not in enhanced_desc.lower():
571
- # 添加缺失的視角信息
572
- if enhanced_desc[0].isupper():
573
- enhanced_desc = f"From {term} view, {enhanced_desc[0].lower()}{enhanced_desc[1:]}"
574
- else:
575
- enhanced_desc = f"From {term} view, {enhanced_desc}"
576
- break
577
-
578
- return enhanced_desc
579
-
580
  def _remove_explanatory_notes(self, response: str) -> str:
581
  """移除解釋性注釋、說明和其他非描述性內容"""
582
 
@@ -669,7 +617,7 @@ class LLMEnhancer:
669
  # 1. 處理連續標點符號問題
670
  text = re.sub(r'([.,;:!?])\1+', r'\1', text)
671
 
672
- # 2. 修復不完整句子的標點(如 "Something," 後沒有繼續句子)
673
  text = re.sub(r',\s*$', '.', text)
674
 
675
  # 3. 修復如 "word." 後未加空格即接下一句的問題
@@ -686,7 +634,7 @@ class LLMEnhancer:
686
 
687
  def _fact_check_description(self, original_desc: str, enhanced_desc: str, scene_type: str, detected_objects: List[str]) -> str:
688
  """
689
- 驗證並可能修正增強後的描述,確保其保持事實準確性,針對普遍事實而非特定場景。
690
 
691
  Args:
692
  original_desc: 原始場景描述
@@ -772,7 +720,7 @@ class LLMEnhancer:
772
 
773
  # 3. 檢查場景類型一致性
774
  if scene_type and scene_type.lower() != "unknown" and scene_type.lower() not in enhanced_desc.lower():
775
- # 優雅地添加場景類型
776
  if enhanced_desc.startswith("This ") or enhanced_desc.startswith("The "):
777
  # 避免產生 "This scene" 和 "This intersection" 的重複
778
  if "scene" in enhanced_desc[:15].lower():
@@ -895,12 +843,12 @@ class LLMEnhancer:
895
  if "llama" in self.model_path.lower():
896
  generation_params.update({
897
  "temperature": 0.4, # 不要太高, 否則模型可能會太有主觀意見
898
- "max_new_tokens": 600,
899
- "do_sample": True,
900
- "top_p": 0.8,
901
  "repetition_penalty": 1.2, # 重複的懲罰權重,可避免掉重複字
902
- "num_beams": 4 ,
903
- "length_penalty": 1.2,
904
  })
905
 
906
  else:
@@ -926,7 +874,7 @@ class LLMEnhancer:
926
  if assistant_tag in full_response:
927
  response = full_response.split(assistant_tag)[-1].strip()
928
 
929
- # 檢查是否有未閉合的 <|assistant|>
930
  user_tag = "<|user|>"
931
  if user_tag in response:
932
  response = response.split(user_tag)[0].strip()
@@ -1118,7 +1066,7 @@ class LLMEnhancer:
1118
 
1119
  # 14. 統一格式 - 確保輸出始終是單一段落
1120
  response = re.sub(r'\s*\n\s*', ' ', response) # 將所有換行符替換為空格
1121
- response = ' '.join(response.split())
1122
 
1123
  return response.strip()
1124
 
@@ -1133,44 +1081,6 @@ class LLMEnhancer:
1133
 
1134
  return "\n- " + "\n- ".join(formatted)
1135
 
1136
- def _format_lighting(self, lighting_info: Dict) -> str:
1137
- """格式化光照信息以用於提示"""
1138
- if not lighting_info:
1139
- return "Unknown lighting conditions"
1140
-
1141
- time = lighting_info.get("time_of_day", "unknown")
1142
- conf = lighting_info.get("confidence", 0)
1143
- is_indoor = lighting_info.get("is_indoor", False)
1144
-
1145
- base_info = f"{'Indoor' if is_indoor else 'Outdoor'} {time} (confidence: {conf:.2f})"
1146
-
1147
- # 添加更詳細的診斷信息
1148
- diagnostics = lighting_info.get("diagnostics", {})
1149
- if diagnostics:
1150
- diag_str = "\nAdditional lighting diagnostics:"
1151
- for key, value in diagnostics.items():
1152
- diag_str += f"\n- {key}: {value}"
1153
- base_info += diag_str
1154
-
1155
- return base_info
1156
-
1157
- def _format_zones(self, zones: Dict) -> str:
1158
- """格式化功能區域以用於提示"""
1159
- if not zones:
1160
- return "No distinct functional zones identified"
1161
-
1162
- formatted = ["Identified functional zones:"]
1163
- for zone_name, zone_data in zones.items():
1164
- desc = zone_data.get("description", "")
1165
- objects = zone_data.get("objects", [])
1166
-
1167
- zone_str = f"- {zone_name}: {desc}"
1168
- if objects:
1169
- zone_str += f" (Contains: {', '.join(objects)})"
1170
-
1171
- formatted.append(zone_str)
1172
-
1173
- return "\n".join(formatted)
1174
 
1175
  def _format_clip_results(self, clip_analysis: Dict) -> str:
1176
  """格式化CLIP分析結果以用於提示"""
 
34
  handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
35
  self.logger.addHandler(handler)
36
 
37
+ # 默認用 Llama3.2
38
  self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
39
  self.tokenizer_path = tokenizer_path or self.model_path
40
 
 
55
 
56
  self._initialize_prompts()
57
 
58
+ # only if need to load the model
59
  self._model_loaded = False
60
 
61
  try:
 
70
  self.logger.error(f"Error during Hugging Face login: {e}")
71
 
72
  def _load_model(self):
73
+ """只在首次需要時加載,使用 8 位量化以節省記憶體"""
74
  if self._model_loaded:
75
  return
76
 
 
80
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
81
  torch.cuda.empty_cache()
82
 
 
83
  if torch.cuda.is_available():
84
  free_in_GB = torch.cuda.get_device_properties(0).total_memory / 1024**3
85
  print(f"Total GPU memory: {free_in_GB:.2f} GB")
86
 
87
+ # 設置 8 位元配置(節省記憶體空間)
88
  quantization_config = BitsAndBytesConfig(
89
  load_in_8bit=True,
90
  llm_int8_enable_fp32_cpu_offload=True
91
  )
92
 
 
93
  self.tokenizer = AutoTokenizer.from_pretrained(
94
  self.tokenizer_path,
95
  padding_side="left",
 
97
  token=self.hf_token
98
  )
99
 
100
+ # 特殊標記
101
  self.tokenizer.pad_token = self.tokenizer.eos_token
102
 
103
  # 加載 8 位量化模型
104
  self.model = AutoModelForCausalLM.from_pretrained(
105
  self.model_path,
106
  quantization_config=quantization_config,
107
+ device_map="auto",
108
  low_cpu_mem_usage=True,
109
  token=self.hf_token
110
  )
 
120
 
121
  def _initialize_prompts(self):
122
  """Return an optimized prompt template specifically for Zephyr model"""
123
+ # the critical prompt for the model
124
  self.enhance_description_template = """
125
  <|system|>
126
  You are an expert visual analyst. Your task is to improve the readability and fluency of scene descriptions using STRICT factual accuracy.
 
151
  """
152
 
153
 
154
+ # 錯誤檢測的prompt
155
  self.verify_detection_template = """
156
  Task: You are an advanced vision system that verifies computer vision detections for accuracy.
157
 
 
177
  Verification Results:
178
  """
179
 
180
+ # 無檢測處理的prompt
181
  self.no_detection_template = """
182
  Task: You are an advanced scene understanding system analyzing an image where standard object detection failed to identify specific objects.
183
 
 
230
 
231
  return response
232
 
233
+ # For Future Usage
234
  def _detect_scene_type(self, detected_objects: List[Dict]) -> str:
235
  """
236
  Detect scene type based on object distribution and patterns
 
290
 
291
  return response.strip()
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  def reset_context(self):
294
  """在處理新圖像前重置模型上下文"""
295
  if self._model_loaded:
 
336
  scene_type = scene_data.get("scene_type", "unknown scene")
337
  scene_type = self._clean_scene_type(scene_type)
338
 
339
+ # 提取檢測到的物件並過濾低信心度物件
340
  detected_objects = scene_data.get("detected_objects", [])
341
  filtered_objects = []
342
 
343
+ # 高信心度閾值,嚴格過濾物件
344
  high_confidence_threshold = 0.65
345
 
346
  for obj in detected_objects:
 
353
  if confidence < 0.75: # 為這些類別設置更高閾值
354
  continue
355
 
356
+ # 只保留高信心度物件
357
  if confidence >= high_confidence_threshold:
358
  filtered_objects.append(obj)
359
 
360
+ # 計算物件列表和數量 - 僅使用過濾後的高信心度物件
361
  object_counts = {}
362
  for obj in filtered_objects:
363
  class_name = obj.get("class_name", "")
 
368
  # 將高置信度物件格式化為清單
369
  high_confidence_objects = ", ".join([f"{count} {obj}" for obj, count in object_counts.items()])
370
 
371
+ # 如果沒有高信心度物件,回退到使用原始描述中的關鍵詞
372
  if not high_confidence_objects:
373
  # 從原始描述中提取物件提及
374
  object_keywords = self._extract_objects_from_description(original_desc)
 
385
  is_indoor = lighting_info.get("is_indoor", False)
386
  lighting_description = f"{'indoor' if is_indoor else 'outdoor'} {time_of_day} lighting"
387
 
388
+ # 創建prompt,整合所有關鍵資訊
389
  prompt = self.enhance_description_template.format(
390
  scene_type=scene_type,
391
  object_list=high_confidence_objects,
 
400
 
401
  # 檢查回應完整性的更嚴格標準
402
  is_incomplete = (
403
+ len(response) < 100 or # too short
404
+ (len(response) < 200 and "." not in response[-30:]) or # 結尾沒有適當的標點符號
405
  any(response.endswith(phrase) for phrase in ["in the", "with the", "and the"]) # 以不完整短語結尾
406
  )
407
 
 
418
  (len(response) < 200 and "." not in response[-30:]) or
419
  any(response.endswith(phrase) for phrase in ["in the", "with the", "and the"]))
420
 
 
421
  if not response or len(response.strip()) < 10:
422
  self.logger.warning("Generated response was empty or too short, returning original description")
423
  return original_desc
424
 
425
+ # 使用與模型相符的清理方法
426
  if "llama" in self.model_path.lower():
427
  result = self._clean_llama_response(response)
428
  else:
429
  result = self._clean_model_response(response)
430
 
431
+ # 移除介紹性type句子
432
  result = self._remove_introduction_sentences(result)
433
 
434
+ # 移除explanation
435
  result = self._remove_explanatory_notes(result)
436
 
437
+ # fact check
438
  result = self._verify_factual_accuracy(original_desc, result, high_confidence_objects)
439
 
440
  # 確保場景類型和視角一致性
 
525
 
526
  return result
527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
  def _remove_explanatory_notes(self, response: str) -> str:
529
  """移除解釋性注釋、說明和其他非描述性內容"""
530
 
 
617
  # 1. 處理連續標點符號問題
618
  text = re.sub(r'([.,;:!?])\1+', r'\1', text)
619
 
620
+ # 2. 修復不完整句子的標點(如 "Something," 後沒有繼續接續下去)
621
  text = re.sub(r',\s*$', '.', text)
622
 
623
  # 3. 修復如 "word." 後未加空格即接下一句的問題
 
634
 
635
  def _fact_check_description(self, original_desc: str, enhanced_desc: str, scene_type: str, detected_objects: List[str]) -> str:
636
  """
637
+ 驗證並可能修正增強後的描述,確保有保持事實準確性。
638
 
639
  Args:
640
  original_desc: 原始場景描述
 
720
 
721
  # 3. 檢查場景類型一致性
722
  if scene_type and scene_type.lower() != "unknown" and scene_type.lower() not in enhanced_desc.lower():
723
+ # 添加場景類型
724
  if enhanced_desc.startswith("This ") or enhanced_desc.startswith("The "):
725
  # 避免產生 "This scene" 和 "This intersection" 的重複
726
  if "scene" in enhanced_desc[:15].lower():
 
843
  if "llama" in self.model_path.lower():
844
  generation_params.update({
845
  "temperature": 0.4, # 不要太高, 否則模型可能會太有主觀意見
846
+ "max_new_tokens": 600,
847
+ "do_sample": True,
848
+ "top_p": 0.8,
849
  "repetition_penalty": 1.2, # 重複的懲罰權重,可避免掉重複字
850
+ "num_beams": 4 ,
851
+ "length_penalty": 1.2,
852
  })
853
 
854
  else:
 
874
  if assistant_tag in full_response:
875
  response = full_response.split(assistant_tag)[-1].strip()
876
 
877
+ # 檢查是否有未閉合的 <|assistant|>
878
  user_tag = "<|user|>"
879
  if user_tag in response:
880
  response = response.split(user_tag)[0].strip()
 
1066
 
1067
  # 14. 統一格式 - 確保輸出始終是單一段落
1068
  response = re.sub(r'\s*\n\s*', ' ', response) # 將所有換行符替換為空格
1069
+ response = ' '.join(response.split())
1070
 
1071
  return response.strip()
1072
 
 
1081
 
1082
  return "\n- " + "\n- ".join(formatted)
1083
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1084
 
1085
  def _format_clip_results(self, clip_analysis: Dict) -> str:
1086
  """格式化CLIP分析結果以用於提示"""