Upload llm_enhancer.py

llm_enhancer.py  (+32 -122)  CHANGED
@@ -34,7 +34,7 @@ class LLMEnhancer:
         handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
         self.logger.addHandler(handler)
 
-        #
+        # Default to Llama 3.2
         self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
         self.tokenizer_path = tokenizer_path or self.model_path
 
@@ -55,7 +55,7 @@ class LLMEnhancer:
 
         self._initialize_prompts()
 
-        #
+        # Only load the model when it is actually needed
         self._model_loaded = False
 
         try:
@@ -70,7 +70,7 @@ class LLMEnhancer:
             self.logger.error(f"Error during Hugging Face login: {e}")
 
     def _load_model(self):
-        """
+        """Load only on first use, with 8-bit quantization to save memory."""
         if self._model_loaded:
             return
 
@@ -80,18 +80,16 @@ class LLMEnhancer:
             from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
             torch.cuda.empty_cache()
 
-            # Print available GPU memory
             if torch.cuda.is_available():
                 free_in_GB = torch.cuda.get_device_properties(0).total_memory / 1024**3
                 print(f"Total GPU memory: {free_in_GB:.2f} GB")
 
-            # Set up 8
+            # Set up the 8-bit configuration (to save memory)
             quantization_config = BitsAndBytesConfig(
                 load_in_8bit=True,
                 llm_int8_enable_fp32_cpu_offload=True
             )
 
-            # Load the tokenizer
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.tokenizer_path,
                 padding_side="left",
@@ -99,14 +97,14 @@ class LLMEnhancer:
                 token=self.hf_token
             )
 
-            #
+            # Special tokens
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
             # Load the 8-bit quantized model
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.model_path,
                 quantization_config=quantization_config,
-                device_map="auto",
+                device_map="auto",
                 low_cpu_mem_usage=True,
                 token=self.hf_token
             )
@@ -122,7 +120,7 @@ class LLMEnhancer:
 
     def _initialize_prompts(self):
         """Return an optimized prompt template specifically for Zephyr model"""
-        # the prompt for the model
+        # the critical prompt for the model
         self.enhance_description_template = """
 <|system|>
 You are an expert visual analyst. Your task is to improve the readability and fluency of scene descriptions using STRICT factual accuracy.
@@ -153,7 +151,7 @@ class LLMEnhancer:
         """
 
 
-        #
+        # Prompt for error detection
         self.verify_detection_template = """
 Task: You are an advanced vision system that verifies computer vision detections for accuracy.
 
@@ -179,7 +177,7 @@ class LLMEnhancer:
 Verification Results:
         """
 
-        #
+        # Prompt for handling images with no detections
         self.no_detection_template = """
 Task: You are an advanced scene understanding system analyzing an image where standard object detection failed to identify specific objects.
 
@@ -232,6 +230,7 @@ class LLMEnhancer:
 
         return response
 
+    # For Future Usage
     def _detect_scene_type(self, detected_objects: List[Dict]) -> str:
         """
         Detect scene type based on object distribution and patterns
@@ -291,26 +290,6 @@ class LLMEnhancer:
 
         return response.strip()
 
-    def _validate_scene_facts(self, enhanced_desc: str, original_desc: str, people_count: int) -> str:
-        """Validate key facts in enhanced description"""
-        # Check if people count is preserved
-        if people_count > 0:
-            people_pattern = re.compile(r'(\d+)\s+(?:people|persons|pedestrians|individuals)', re.IGNORECASE)
-            people_match = people_pattern.search(enhanced_desc)
-
-            if not people_match or int(people_match.group(1)) != people_count:
-                # Replace incorrect count or add if missing
-                if people_match:
-                    enhanced_desc = people_pattern.sub(f"{people_count} people", enhanced_desc)
-                else:
-                    enhanced_desc = f"The scene shows {people_count} people. " + enhanced_desc
-
-        # Ensure aerial perspective is mentioned
-        if "aerial" in original_desc.lower() and "aerial" not in enhanced_desc.lower():
-            enhanced_desc = "From an aerial perspective, " + enhanced_desc[0].lower() + enhanced_desc[1:]
-
-        return enhanced_desc
-
     def reset_context(self):
         """Reset the model context before processing a new image"""
         if self._model_loaded:
@@ -357,11 +336,11 @@ class LLMEnhancer:
         scene_type = scene_data.get("scene_type", "unknown scene")
         scene_type = self._clean_scene_type(scene_type)
 
-        #
+        # Extract detected objects and filter out low-confidence ones
         detected_objects = scene_data.get("detected_objects", [])
         filtered_objects = []
 
-        #
+        # High-confidence threshold for strict object filtering
        high_confidence_threshold = 0.65
 
        for obj in detected_objects:
@@ -374,11 +353,11 @@ class LLMEnhancer:
                 if confidence < 0.75:  # use a higher threshold for these classes
                     continue
 
-            #
+            # Keep only high-confidence objects
             if confidence >= high_confidence_threshold:
                 filtered_objects.append(obj)
 
-        # Compute the object list and counts -
+        # Compute the object list and counts - use only the filtered high-confidence objects
         object_counts = {}
         for obj in filtered_objects:
             class_name = obj.get("class_name", "")
@@ -389,7 +368,7 @@ class LLMEnhancer:
         # Format the high-confidence objects as a list
         high_confidence_objects = ", ".join([f"{count} {obj}" for obj, count in object_counts.items()])
 
-        #
+        # If there are no high-confidence objects, fall back to keywords from the original description
         if not high_confidence_objects:
             # Extract object mentions from the original description
             object_keywords = self._extract_objects_from_description(original_desc)
@@ -406,7 +385,7 @@ class LLMEnhancer:
         is_indoor = lighting_info.get("is_indoor", False)
         lighting_description = f"{'indoor' if is_indoor else 'outdoor'} {time_of_day} lighting"
 
-        #
+        # Build the prompt, combining all key information
         prompt = self.enhance_description_template.format(
             scene_type=scene_type,
             object_list=high_confidence_objects,
@@ -421,8 +400,8 @@ class LLMEnhancer:
 
         # Stricter criteria for checking response completeness
         is_incomplete = (
-            len(response) < 100 or  #
-            (len(response) < 200 and "." not in response[-30:]) or  #
+            len(response) < 100 or  # too short
+            (len(response) < 200 and "." not in response[-30:]) or  # no proper punctuation at the end
             any(response.endswith(phrase) for phrase in ["in the", "with the", "and the"])  # ends with an incomplete phrase
         )
 
@@ -439,24 +418,23 @@ class LLMEnhancer:
             (len(response) < 200 and "." not in response[-30:]) or
             any(response.endswith(phrase) for phrase in ["in the", "with the", "and the"]))
 
-        # Make sure the response is not empty
         if not response or len(response.strip()) < 10:
             self.logger.warning("Generated response was empty or too short, returning original description")
             return original_desc
 
-        #
+        # Use the cleaning method that matches the model
         if "llama" in self.model_path.lower():
             result = self._clean_llama_response(response)
         else:
             result = self._clean_model_response(response)
 
-        #
+        # Remove introduction-type sentences
         result = self._remove_introduction_sentences(result)
 
-        #
+        # Remove explanations
         result = self._remove_explanatory_notes(result)
 
-        #
+        # fact check
         result = self._verify_factual_accuracy(original_desc, result, high_confidence_objects)
 
         # Ensure scene type and perspective consistency
@@ -547,36 +525,6 @@ class LLMEnhancer:
 
         return result
 
-    def _validate_content_consistency(self, original_desc: str, enhanced_desc: str) -> str:
-        """Verify that the enhanced description is consistent with the original description"""
-        # Extract key numbers from the original description
-        people_count_match = re.search(r'(\d+)\s+people', original_desc, re.IGNORECASE)
-        people_count = int(people_count_match.group(1)) if people_count_match else None
-
-        # Verify people-count consistency
-        if people_count:
-            enhanced_count_match = re.search(r'(\d+)\s+people', enhanced_desc, re.IGNORECASE)
-            if not enhanced_count_match or int(enhanced_count_match.group(1)) != people_count:
-                # Preserve the original people count
-                if enhanced_count_match:
-                    enhanced_desc = re.sub(r'\b\d+\s+people\b', f"{people_count} people", enhanced_desc, flags=re.IGNORECASE)
-                elif "people" in enhanced_desc.lower():
-                    enhanced_desc = re.sub(r'\bpeople\b', f"{people_count} people", enhanced_desc, flags=re.IGNORECASE)
-
-        # Verify viewpoint/perspective consistency
-        perspective_terms = ["aerial", "bird's-eye", "overhead", "ground level", "eye level"]
-
-        for term in perspective_terms:
-            if term in original_desc.lower() and term not in enhanced_desc.lower():
-                # Add the missing perspective information
-                if enhanced_desc[0].isupper():
-                    enhanced_desc = f"From {term} view, {enhanced_desc[0].lower()}{enhanced_desc[1:]}"
-                else:
-                    enhanced_desc = f"From {term} view, {enhanced_desc}"
-                break
-
-        return enhanced_desc
-
     def _remove_explanatory_notes(self, response: str) -> str:
         """Remove explanatory notes, clarifications, and other non-descriptive content"""
 
@@ -669,7 +617,7 @@ class LLMEnhancer:
         # 1. Handle runs of repeated punctuation
         text = re.sub(r'([.,;:!?])\1+', r'\1', text)
 
-        # 2. Fix punctuation of incomplete sentences (e.g. "Something,"
+        # 2. Fix punctuation of incomplete sentences (e.g. "Something," with nothing following)
         text = re.sub(r',\s*$', '.', text)
 
         # 3. Fix cases where the next sentence follows "word." without a space
@@ -686,7 +634,7 @@ class LLMEnhancer:
 
     def _fact_check_description(self, original_desc: str, enhanced_desc: str, scene_type: str, detected_objects: List[str]) -> str:
         """
-
+        Verify and, if needed, correct the enhanced description to preserve factual accuracy.
 
         Args:
             original_desc: the original scene description
@@ -772,7 +720,7 @@ class LLMEnhancer:
 
         # 3. Check scene-type consistency
         if scene_type and scene_type.lower() != "unknown" and scene_type.lower() not in enhanced_desc.lower():
-            #
+            # Add the scene type
             if enhanced_desc.startswith("This ") or enhanced_desc.startswith("The "):
                 # Avoid producing both "This scene" and "This intersection"
                 if "scene" in enhanced_desc[:15].lower():
@@ -895,12 +843,12 @@ class LLMEnhancer:
         if "llama" in self.model_path.lower():
             generation_params.update({
                 "temperature": 0.4,  # not too high, or the model may become too opinionated
-                "max_new_tokens": 600,
-                "do_sample": True,
-                "top_p": 0.8,
+                "max_new_tokens": 600,
+                "do_sample": True,
+                "top_p": 0.8,
                 "repetition_penalty": 1.2,  # repetition penalty weight, helps avoid repeated words
-                "num_beams": 4 ,
-                "length_penalty": 1.2,
+                "num_beams": 4 ,
+                "length_penalty": 1.2,
             })
 
         else:
@@ -926,7 +874,7 @@ class LLMEnhancer:
         if assistant_tag in full_response:
             response = full_response.split(assistant_tag)[-1].strip()
 
-            # Check for an unclosed <|assistant|>
+            # Check for an unclosed <|assistant|>
             user_tag = "<|user|>"
             if user_tag in response:
                 response = response.split(user_tag)[0].strip()
@@ -1118,7 +1066,7 @@ class LLMEnhancer:
 
         # 14. Normalize formatting - ensure the output is always a single paragraph
         response = re.sub(r'\s*\n\s*', ' ', response)  # replace all newlines with spaces
-        response = ' '.join(response.split())
+        response = ' '.join(response.split())
 
         return response.strip()
 
@@ -1133,44 +1081,6 @@ class LLMEnhancer:
 
         return "\n- " + "\n- ".join(formatted)
 
-    def _format_lighting(self, lighting_info: Dict) -> str:
-        """Format lighting information for the prompt"""
-        if not lighting_info:
-            return "Unknown lighting conditions"
-
-        time = lighting_info.get("time_of_day", "unknown")
-        conf = lighting_info.get("confidence", 0)
-        is_indoor = lighting_info.get("is_indoor", False)
-
-        base_info = f"{'Indoor' if is_indoor else 'Outdoor'} {time} (confidence: {conf:.2f})"
-
-        # Add more detailed diagnostic information
-        diagnostics = lighting_info.get("diagnostics", {})
-        if diagnostics:
-            diag_str = "\nAdditional lighting diagnostics:"
-            for key, value in diagnostics.items():
-                diag_str += f"\n- {key}: {value}"
-            base_info += diag_str
-
-        return base_info
-
-    def _format_zones(self, zones: Dict) -> str:
-        """Format functional zones for the prompt"""
-        if not zones:
-            return "No distinct functional zones identified"
-
-        formatted = ["Identified functional zones:"]
-        for zone_name, zone_data in zones.items():
-            desc = zone_data.get("description", "")
-            objects = zone_data.get("objects", [])
-
-            zone_str = f"- {zone_name}: {desc}"
-            if objects:
-                zone_str += f" (Contains: {', '.join(objects)})"
-
-            formatted.append(zone_str)
-
-        return "\n".join(formatted)
 
     def _format_clip_results(self, clip_analysis: Dict) -> str:
         """Format CLIP analysis results for the prompt"""