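"""Benchmark pipeline runner.

Checks the Google Sheet for the model/benchmark entry, verifies the local
benchmark folder structure, and runs whichever PiaBenchMark stages are still
missing (visual-vector extraction, event detection, metric evaluation).
"""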
import logging
import os
from dataclasses import dataclass
from typing import Dict, Tuple

from dotenv import load_dotenv

from pia_bench.bench import PiaBenchMark
from pia_bench.checker.bench_checker import BenchChecker
from pia_bench.event_alarm import EventDetector
from pia_bench.metric import MetricsEvaluator
from sheet_manager.sheet_checker.sheet_check import SheetChecker
from sheet_manager.sheet_crud.sheet_crud import SheetManager

load_dotenv()

logging.basicConfig(level=logging.INFO)
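# ACCESS_TOKEN is read from the environment in BenchmarkPipeline.__init__; with
# python-dotenv it can be supplied via a local .env file, e.g. (value hypothetical):
#   ACCESS_TOKEN=<your_huggingface_token>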
@dataclass
class PipelineConfig:
    """Configuration for the benchmark pipeline."""
    model_name: str
    benchmark_name: str
    cfg_target_path: str
    base_path: str = "/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
class BenchmarkPipelineStatus:
    """Tracks pipeline state and results."""

    def __init__(self):
        self.sheet_status: Tuple[bool, bool] = (False, False)  # (model_added, benchmark_exists)
        self.bench_status: Dict[str, bool] = {}
        self.bench_result: str = ""
        self.current_stage: str = "not_started"

    def is_success(self) -> bool:
        """Whether the whole pipeline succeeded."""
        return (not self.sheet_status[0]  # the model already existed,
                and self.sheet_status[1]  # the benchmark entry exists,
                and self.bench_result == "all_passed")  # and every bench check passed

    def __str__(self) -> str:
        return (f"Current Stage: {self.current_stage}\n"
                f"Sheet Status: {self.sheet_status}\n"
                f"Bench Status: {self.bench_status}\n"
                f"Bench Result: {self.bench_result}")
class BenchmarkPipeline:
    """Pipeline that drives a benchmark run end to end."""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.status = BenchmarkPipelineStatus()
        self.access_token = os.getenv("ACCESS_TOKEN")
        # Prompt name is the cfg file name without its extension, e.g. "topk.json" -> "topk".
        self.cfg_prompt = os.path.splitext(os.path.basename(self.config.cfg_target_path))[0]
        # Initialize checkers
        self.sheet_manager = SheetManager()
        self.sheet_checker = SheetChecker(self.sheet_manager)
        self.bench_checker = BenchChecker(self.config.base_path)
        self.bench_result_dict = None
    def run(self) -> BenchmarkPipelineStatus:
        """Run the full pipeline."""
        try:
            self.status.current_stage = "sheet_check"
            proceed = self._check_sheet()
            if not proceed:
                self.status.current_stage = "completed_no_action_needed"
                self.logger.info("Benchmark already exists; no further work is needed.")
                return self.status

            self.status.current_stage = "bench_check"
            if not self._check_bench():
                return self.status

            self.status.current_stage = "execution"
            self._execute_based_on_status()
            self.status.current_stage = "completed"
            return self.status
        except Exception as e:
            self.logger.error(f"Error while running the pipeline: {e}")
            self.status.current_stage = "error"
            return self.status
    def _check_sheet(self) -> bool:
        """Check the Google Sheet state."""
        self.logger.info("Starting sheet check")
        model_added, benchmark_exists = self.sheet_checker.check_model_and_benchmark(
            self.config.model_name,
            self.config.benchmark_name
        )
        self.status.sheet_status = (model_added, benchmark_exists)

        if model_added:
            self.logger.info("A new model was added")
        if not benchmark_exists:
            self.logger.info("Benchmark measurement is required")
            return True  # proceed only when the benchmark still has to be measured

        self.logger.info("Benchmark already exists. Stopping the pipeline.")
        return False  # stop here if the benchmark is already recorded
    def _check_bench(self) -> bool:
        """Check the local benchmark environment."""
        self.logger.info("Starting benchmark environment check")
        self.status.bench_status = self.bench_checker.check_benchmark(
            self.config.benchmark_name,
            self.config.model_name,
            self.cfg_prompt
        )
        self.status.bench_result = self.bench_checker.get_benchmark_status(
            self.status.bench_status
        )
        # "no bench": this benchmark has never been run, so not even the base
        # folder structure exists. We still proceed, because
        # _execute_based_on_status falls back to vector generation, which
        # creates the structure from scratch.
        if self.status.bench_result == "no bench":
            self.logger.warning("The base folder structure for the benchmark is missing.")
        return True
    def _execute_based_on_status(self):
        """Dispatch execution based on the bench check result."""
        if self.status.bench_result == "all_passed":
            self._execute_full_pipeline()
        elif self.status.bench_result == "no_vectors":
            self._execute_vector_generation()
        elif self.status.bench_result == "no_metrics":
            self._execute_metrics_generation()
        else:
            # Missing folder structure ("no bench"): the vector-generation
            # path rebuilds everything from scratch.
            self.logger.warning("Folder structure is missing; running vector generation")
            self._execute_vector_generation()
    def _create_benchmark(self) -> PiaBenchMark:
        """Build a PiaBenchMark rooted at <base_path>/<benchmark_name>."""
        return PiaBenchMark(
            benchmark_path=os.path.join(self.config.base_path, self.config.benchmark_name),
            model_name=self.config.model_name,
            cfg_target_path=self.config.cfg_target_path,
            token=self.access_token
        )

    def _execute_full_pipeline(self):
        """Execution logic when all checks passed: re-evaluate metrics only."""
        self.logger.info("Running the full pipeline...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        self.logger.info(f"Categories identified: {pia_benchmark.categories}")

        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()
    def _execute_vector_generation(self):
        """Execution logic when visual vectors still need to be generated."""
        self.logger.info("Generating vectors...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        self.logger.info(f"Categories identified: {pia_benchmark.categories}")
        pia_benchmark.extract_visual_vector()

        detector = EventDetector(config_path=self.config.cfg_target_path,
                                 model_name=self.config.model_name,
                                 token=pia_benchmark.token)
        detector.process_and_save_predictions(pia_benchmark.vector_video_path,
                                              pia_benchmark.dataset_path,
                                              pia_benchmark.alram_path)

        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()
    def _execute_metrics_generation(self):
        """Execution logic when metrics still need to be generated."""
        self.logger.info("Generating metrics...")
        pia_benchmark = self._create_benchmark()
        pia_benchmark.preprocess_structure()
        pia_benchmark.preprocess_label_to_csv()
        self.logger.info(f"Categories identified: {pia_benchmark.categories}")

        detector = EventDetector(config_path=self.config.cfg_target_path,
                                 model_name=self.config.model_name,
                                 token=pia_benchmark.token)
        detector.process_and_save_predictions(pia_benchmark.vector_video_path,
                                              pia_benchmark.dataset_path,
                                              pia_benchmark.alram_path)

        metric = MetricsEvaluator(pred_dir=pia_benchmark.alram_path,
                                  label_dir=pia_benchmark.dataset_path,
                                  save_dir=pia_benchmark.metric_path)
        self.bench_result_dict = metric.evaluate()
if __name__ == "__main__":
    # Pipeline configuration
    config = PipelineConfig(
        model_name="T2V_CLIP4CLIP_MSRVTT",
        benchmark_name="PIA",
        cfg_target_path="topk.json",
        base_path="/mnt/nas_192tb/videos/huggingface_benchmarks_dataset/Leaderboard_bench"
    )

    # Run the pipeline
    pipeline = BenchmarkPipeline(config)
    result = pipeline.run()
    print("\nPipeline result:")
    print(result)
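    # Overall success can also be checked via the status object, e.g.:
    # if result.is_success():
    #     print("Benchmark recorded and all checks passed")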