|
import os |
|
import numpy as np |
|
import torch |
|
from typing import Dict, List, Tuple |
|
from devmacs_core.devmacs_core import DevMACSCore |
|
|
|
|
|
from devmacs_core.utils.common.cal import loose_similarity |
|
from utils.parser import load_config, PromptManager |
|
import json |
|
import pandas as pd |
|
from tqdm import tqdm |
|
import logging |
|
from datetime import datetime |
|
from utils.except_dir import cust_listdir |
|
|
|
class EventDetector:
    """Convert stored per-video vectors into frame-level event alarm tables.

    The prompt sentences from the configuration are embedded once at
    construction time. Each stored video vector is then scored against those
    text vectors, turned into a 0/1 alarm per configured event, and the sparse
    per-vector alarms are expanded to one row per video frame and written to
    CSV.
    """

    # Vector number ``i`` in a ``.npy`` file is mapped to frame
    # ``i * FRAME_INTERVAL`` of the video.
    FRAME_INTERVAL = 15

    def __init__(self, config_path: str, model_name: str = None, token: str = None):
        """Load the configuration, the model and the prompt text vectors.

        Args:
            config_path: Path passed to ``load_config`` and ``PromptManager``.
            model_name: Repository name under the ``PIA-SPACE-LAB`` namespace.
            token: Token forwarded to ``DevMACSCore.from_huggingface``.
        """
        self.config = load_config(config_path)
        # NOTE(review): with the default ``model_name=None`` the repo id
        # becomes the literal "PIA-SPACE-LAB/None"; callers presumably always
        # pass a name - confirm.
        self.macs = DevMACSCore.from_huggingface(
            token=token, repo_id=f"PIA-SPACE-LAB/{model_name}"
        )

        self.prompt_manager = PromptManager(config_path)
        self.sentences = self.prompt_manager.sentences
        self.text_vectors = self.macs.get_text_vector(self.sentences)

    def process_and_save_predictions(self, vector_base_dir: str, label_base_dir: str, save_base_dir: str):
        """Process every video vector file and save the results as CSV.

        Expects ``vector_base_dir/<category>/<video>.npy`` and a matching
        ``label_base_dir/<category>/<video>.json`` whose
        ``["video_info"]["total_frame"]`` gives the number of frames. Writes
        ``save_base_dir/<category>/<video>.csv``.

        Args:
            vector_base_dir: Root directory holding one sub-directory per
                category with ``.npy`` vector files.
            label_base_dir: Root directory holding the label JSON files.
            save_base_dir: Root directory for the generated CSV files; the
                category sub-directories are created on demand.

        Raises:
            FileNotFoundError: If the label JSON of a video is missing.
            KeyError: If the label JSON lacks ``video_info.total_frame``.
        """
        # List the tree once so the progress total and the loop below always
        # agree on which files are processed.
        work = []
        for category in cust_listdir(vector_base_dir):
            category_path = os.path.join(vector_base_dir, category)
            if not os.path.isdir(category_path):
                continue
            vector_files = [
                name for name in cust_listdir(category_path) if name.endswith('.npy')
            ]
            work.append((category, category_path, vector_files))

        total_videos = sum(len(vector_files) for _, _, vector_files in work)

        # The context manager closes the bar even if a video raises.
        with tqdm(total=total_videos, desc="Processing videos") as pbar:
            for category, category_path, vector_files in work:
                save_category_dir = os.path.join(save_base_dir, category)
                os.makedirs(save_category_dir, exist_ok=True)

                for file in vector_files:
                    video_name = os.path.splitext(file)[0]
                    vector_path = os.path.join(category_path, file)

                    # The label file is only needed for the frame count.
                    label_path = os.path.join(label_base_dir, category, f"{video_name}.json")
                    # JSON is UTF-8 by specification; do not depend on the
                    # platform default encoding.
                    with open(label_path, 'r', encoding='utf-8') as f:
                        label_data = json.load(f)
                    total_frames = label_data['video_info']['total_frame']

                    self._process_and_save_single_video(
                        vector_path=vector_path,
                        total_frames=total_frames,
                        save_path=os.path.join(save_category_dir, f"{video_name}.csv"),
                    )
                    pbar.update(1)

    def _process_and_save_single_video(self, vector_path: str, total_frames: int, save_path: str):
        """Process a single video and write its frame-level alarms to CSV.

        Args:
            vector_path: Path of the ``.npy`` file holding the video vectors.
            total_frames: Number of rows (frames) in the resulting table.
            save_path: Destination CSV path.
        """
        sparse_predictions = self._process_single_vector(vector_path)
        df = self._expand_predictions(sparse_predictions, total_frames)
        df.to_csv(save_path, index=False)

    def _process_single_vector(self, vector_path: str) -> Dict:
        """Score every stored vector of one video against the text vectors.

        Args:
            vector_path: Path of the ``.npy`` file to load.

        Returns:
            Mapping ``frame number -> result of _calculate_alarms`` where the
            frame number of vector ``i`` is ``i * FRAME_INTERVAL``. Empty when
            the file holds no vectors.
        """
        video_vector = np.load(vector_path)
        if len(video_vector) == 0:
            return {}

        # Loop-invariant: move the text vectors to the GPU once per video
        # instead of once per frame.
        text_vectors = self.text_vectors.cuda()

        frame_results = {}
        for vector_idx, vector in enumerate(video_vector):
            # ``squeeze(0)`` requires a leading axis of length 1 on each entry.
            v = torch.from_numpy(vector.squeeze(0)).unsqueeze(0).cuda()
            sim_scores = loose_similarity(
                sequence_output=text_vectors,
                visual_output=v.unsqueeze(1),
            )
            actual_frame = vector_idx * self.FRAME_INTERVAL
            frame_results[actual_frame] = self._calculate_alarms(sim_scores)

        return frame_results

    def _expand_predictions(self, sparse_predictions: Dict, total_frames: int) -> pd.DataFrame:
        """Expand sparse per-vector predictions to one row per frame.

        Each prediction is held from its own frame up to (but excluding) the
        next predicted frame; the last prediction is held until
        ``total_frames``.

        Args:
            sparse_predictions: Mapping ``frame -> {category: {'alarm': v}}``.
            total_frames: Number of rows in the returned table.

        Returns:
            DataFrame with a ``frame`` column (``0 .. total_frames - 1``) and
            one column per category. When ``sparse_predictions`` is empty only
            the ``frame`` column is present.
        """
        df = pd.DataFrame({'frame': range(total_frames)})
        if not sparse_predictions:
            # Nothing was predicted, so there are no category columns to add.
            return df

        # Column names and order come from the first inserted prediction.
        first_frame = next(iter(sparse_predictions))
        categories = list(sparse_predictions[first_frame].keys())

        for category in categories:
            df[category] = 0

        frame_keys = sorted(sparse_predictions.keys())
        segment_ends = frame_keys[1:] + [total_frames]
        for current_frame, next_frame in zip(frame_keys, segment_ends):
            for category in categories:
                alarm_value = sparse_predictions[current_frame][category]['alarm']
                # ``.loc`` slices are inclusive on both ends, hence the ``- 1``.
                df.loc[current_frame:next_frame - 1, category] = alarm_value

        return df

    def _calculate_alarms(self, sim_scores: torch.Tensor) -> Dict:
        """Derive the alarm state of every event from the similarity scores.

        For each configured event the ``top_candidates`` best scoring prompts
        are selected; the alarm is raised when at least ``alert_threshold`` of
        them are ``abnormal`` prompts.

        Logging is deliberately not configured here: this method runs once per
        scored vector, and configuring or shutting down the logging system
        belongs to the application entry point.

        Args:
            sim_scores: Similarity scores indexed by prompt along the first
                axis (assumes a trailing axis of length 1 that ``squeeze(-1)``
                removes - TODO confirm against ``loose_similarity``).

        Returns:
            Mapping ``event -> {'alarm': 0 | 1, 'scores': [...],
            'top_k_types': [...]}``.
        """
        event_alarms = {}

        for event_config in self.config['PROMPT_CFG']:
            event = event_config['event']
            top_k = event_config['top_candidates']
            threshold = event_config['alert_threshold']

            event_prompts = self._get_event_prompts(event)
            event_scores = sim_scores[event_prompts['indices']].squeeze(-1)

            # Never ask for more candidates than the event has prompts.
            top_k_values, top_k_indices = torch.topk(
                event_scores, min(top_k, len(event_scores))
            )

            top_k_types = [event_prompts['types'][i] for i in top_k_indices.tolist()]
            abnormal_count = top_k_types.count('abnormal')

            event_alarms[event] = {
                'alarm': 1 if abnormal_count >= threshold else 0,
                'scores': top_k_values.tolist(),
                'top_k_types': top_k_types,
            }

        return event_alarms

    def _get_event_prompts(self, event: str) -> Dict:
        """Return the prompt indices and prompt types of one event.

        Args:
            event: Event name to look up in ``config['PROMPT_CFG']``.

        Returns:
            ``{'indices': [...], 'types': [...]}`` listing the ``normal``
            prompts first and the ``abnormal`` prompts after them. Both lists
            are empty when the event is not configured.
        """
        indices = []
        types = []
        # NOTE(review): the counter only advances for the matching event, so
        # indices start at 0 for every event. If the text vectors hold the
        # prompts of all events back to back, events after the first one would
        # need an offset - confirm against PromptManager.sentences.
        current_idx = 0

        for event_config in self.config['PROMPT_CFG']:
            if event_config['event'] != event:
                continue
            for status in ('normal', 'abnormal'):
                count = len(event_config['prompts'][status])
                indices.extend(range(current_idx, current_idx + count))
                types.extend([status] * count)
                current_idx += count

        return {'indices': indices, 'types': types}
|
|
|
|
|
|