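"""Benchmark pipeline runner.

Checks the Google Sheet and the local benchmark directory for an existing
result, then runs only the PiaBenchMark / EventDetector / MetricsEvaluator
stages that are still missing for the requested model and benchmark.
"""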
import logging
import os
from dataclasses import dataclass
from typing import Dict, Tuple

from dotenv import load_dotenv

from pia_bench.bench import PiaBenchMark
from pia_bench.checker.bench_checker import BenchChecker
from pia_bench.event_alarm import EventDetector
from pia_bench.metric import MetricsEvaluator
from sheet_manager.sheet_checker.sheet_check import SheetChecker
from sheet_manager.sheet_crud.sheet_crud import SheetManager

load_dotenv()
logging.basicConfig(level=logging.INFO)

@dataclass
class PipelineConfig:
    """Configuration for a benchmark pipeline run."""
    model_name: str
    benchmark_name: str
    cfg_target_path: str
    base_path: str = "/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench"

class BenchmarkPipelineStatus:
    """Tracks pipeline progress and results."""

    def __init__(self):
        self.sheet_status: Tuple[bool, bool] = (False, False)  # (model_added, benchmark_exists)
        self.bench_status: Dict[str, bool] = {}
        self.bench_result: str = ""
        self.current_stage: str = "not_started"

    def is_success(self) -> bool:
        """Whether the whole pipeline succeeded."""
        return (not self.sheet_status[0]                 # model already existed on the sheet
                and self.sheet_status[1]                 # benchmark entry exists
                and self.bench_result == "all_passed")   # benchmark checks all passed

    def __str__(self) -> str:
        return (f"Current Stage: {self.current_stage}\n"
                f"Sheet Status: {self.sheet_status}\n"
                f"Bench Status: {self.bench_status}\n"
                f"Bench Result: {self.bench_result}")

class BenchmarkPipeline:
    """Pipeline for running a benchmark end to end."""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.status = BenchmarkPipelineStatus()
        self.access_token = os.getenv("ACCESS_TOKEN")
        # The prompt name is the config file name without its extension.
        self.cfg_prompt = os.path.splitext(os.path.basename(self.config.cfg_target_path))[0]
        # Initialize checkers
        self.sheet_manager = SheetManager()
        self.sheet_checker = SheetChecker(self.sheet_manager)
        self.bench_checker = BenchChecker(self.config.base_path)
        self.bench_result_dict = None

    def run(self) -> BenchmarkPipelineStatus:
        """Run the full pipeline."""
        try:
            self.status.current_stage = "sheet_check"
            proceed = self._check_sheet()
            if not proceed:
                self.status.current_stage = "completed_no_action_needed"
                self.logger.info("Benchmark already exists; no further work is needed.")
                return self.status

            self.status.current_stage = "bench_check"
            if not self._check_bench():
                return self.status

            self.status.current_stage = "execution"
            self._execute_based_on_status()

            self.status.current_stage = "completed"
            return self.status
        except Exception as e:
            self.logger.error(f"Error while running the pipeline: {str(e)}")
            self.status.current_stage = "error"
            return self.status

    def _check_sheet(self) -> bool:
        """Check the Google Sheet for the model and benchmark."""
        self.logger.info("Starting sheet status check")
        model_added, benchmark_exists = self.sheet_checker.check_model_and_benchmark(
            self.config.model_name,
            self.config.benchmark_name
        )
        self.status.sheet_status = (model_added, benchmark_exists)

        if model_added:
            self.logger.info("A new model was added to the sheet")
        if not benchmark_exists:
            self.logger.info("Benchmark measurement is required")
            return True  # Proceed only when the benchmark still needs to be measured
        self.logger.info("Benchmark already exists. Stopping the pipeline.")
        return False  # Stop here if the benchmark already exists

    def _check_bench(self) -> bool:
        """Check the local benchmark environment."""
        self.logger.info("Starting benchmark environment check")
        self.status.bench_status = self.bench_checker.check_benchmark(
            self.config.benchmark_name,
            self.config.model_name,
            self.cfg_prompt
        )
        self.status.bench_result = self.bench_checker.get_benchmark_status(
            self.status.bench_status
        )
        # "no bench": the benchmark has never been run and the folder structure does not exist.
        if self.status.bench_result == "no bench":
            self.logger.error("The base folder structure required to run the benchmark is missing.")
            return True
        return True  # Always proceed; a missing structure is handled in _execute_based_on_status

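    # bench_result values handled below (as inferred from the dispatch branches):
    #   "all_passed", "no_vectors", "no_metrics", and anything else (e.g. "no bench").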
    def _execute_based_on_status(self):
        """Dispatch execution based on the benchmark check result."""
        if self.status.bench_result == "all_passed":
            self._execute_full_pipeline()
        elif self.status.bench_result == "no_vectors":
            self._execute_vector_generation()
        elif self.status.bench_result == "no_metrics":
            self._execute_metrics_generation()
        else:
            self.logger.warning("Folder structure is missing; regenerating vectors from scratch")
            self._execute_vector_generation()

    def _execute_full_pipeline(self):
        """Execution path when every precondition is already satisfied."""
        self.logger.info("Running the full pipeline...")
        pia_benchmark = PiaBenchMark(
            benchmark_path=os.path.join(self.config.base_path, self.config.benchmark_name),
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token)
        pia_benchmark.preprocess_structure()
        print("Categories identified:", pia_benchmark.categories)

        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()

    def _execute_vector_generation(self):
        """Execution path when visual vectors still need to be generated."""
        self.logger.info("Generating vectors...")
        pia_benchmark = PiaBenchMark(
            benchmark_path=os.path.join(self.config.base_path, self.config.benchmark_name),
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token)
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)
        pia_benchmark.extract_visual_vector()

        detector = EventDetector(config_path=self.config.cfg_target_path,
                                 model_name=self.config.model_name,
                                 token=pia_benchmark.token)
        detector.process_and_save_predictions(pia_benchmark.vector_video_path,
                                              pia_benchmark.dataset_path,
                                              pia_benchmark.alram_path)

        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()

    def _execute_metrics_generation(self):
        """Execution path when only the metrics still need to be generated."""
        self.logger.info("Generating metrics...")
        pia_benchmark = PiaBenchMark(
            benchmark_path=os.path.join(self.config.base_path, self.config.benchmark_name),
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token)
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        print("Categories identified:", pia_benchmark.categories)

        detector = EventDetector(config_path=self.config.cfg_target_path,
                                 model_name=self.config.model_name,
                                 token=pia_benchmark.token)
        detector.process_and_save_predictions(pia_benchmark.vector_video_path,
                                              pia_benchmark.dataset_path,
                                              pia_benchmark.alram_path)

        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()

if __name__ == "__main__":
    # Pipeline configuration
    config = PipelineConfig(
        model_name="T2V_CLIP4CLIP_MSRVTT",
        benchmark_name="PIA",
        cfg_target_path="topk.json",
        base_path="/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
    )

    # Run the pipeline
    pipeline = BenchmarkPipeline(config)
    result = pipeline.run()
    print("\nPipeline execution result:")
    print(str(result))