import argparse
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from loguru import logger
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from tqdm import tqdm
# use bfloat16 for the entire script
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2, build_sam2_video_predictor
def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # smooth the contours before drawing them
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)
    ax.imshow(img)
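
# A minimal usage sketch for show_anns (hypothetical image path and variable names,
# following the segment-anything demo convention): overlay every generated mask on the frame.
#   image = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)
#   anns = mask_generator.generate(image)
#   plt.imshow(image); show_anns(anns); plt.axis('off'); plt.show()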
def mask_nms(masks, scores, iou_thr=0.7, score_thr=0.1, inner_thr=0.2, **kwargs):
    """
    Perform mask non-maximum suppression (NMS) on a set of masks based on their scores.

    Args:
        masks (torch.Tensor): has shape (num_masks, H, W)
        scores (torch.Tensor): The scores of the masks, has shape (num_masks,)
        iou_thr (float, optional): The threshold for IoU.
        score_thr (float, optional): The threshold for the mask scores.
        inner_thr (float, optional): The threshold for the overlap rate.
        **kwargs: Additional keyword arguments.
    Returns:
        selected_idx (torch.Tensor): A tensor representing the selected indices of the masks after NMS.
    """
    scores, idx = scores.sort(0, descending=True)
    num_masks = idx.shape[0]
    masks_ord = masks[idx.view(-1), :]
    masks_area = torch.sum(masks_ord, dim=(1, 2), dtype=torch.float)

    # compute the pairwise IoU matrices in chunks to bound peak memory
    mask_chunk_size = 20
    mask_chunks = masks_ord.split(mask_chunk_size, dim=0)
    area_chunks = masks_area.split(mask_chunk_size, dim=0)
    iou_matrix = []
    inner_iou_matrix = []
    for i_areas, i_chunk in zip(area_chunks, mask_chunks):
        row_iou_matrix = []
        row_inner_iou_matrix = []
        for j_areas, j_chunk in zip(area_chunks, mask_chunks):
            intersection = torch.logical_and(i_chunk.unsqueeze(1), j_chunk.unsqueeze(0)).sum(dim=(-1, -2))
            union = torch.logical_or(i_chunk.unsqueeze(1), j_chunk.unsqueeze(0)).sum(dim=(-1, -2))
            local_iou_mat = intersection / union
            row_iou_matrix.append(local_iou_mat)
            # containment rates: fraction of mask i (rows) / mask j (cols) covered by the intersection
            row_inter_mat = intersection / i_areas[:, None]
            col_inter_mat = intersection / j_areas[None, :]
            inter = torch.logical_and(row_inter_mat < 0.5, col_inter_mat >= 0.85)
            local_inner_iou_mat = torch.zeros((len(i_areas), len(j_areas)))
            local_inner_iou_mat[inter] = 1 - row_inter_mat[inter] * col_inter_mat[inter]
            row_inner_iou_matrix.append(local_inner_iou_mat)
        row_iou_matrix = torch.cat(row_iou_matrix, dim=1)
        row_inner_iou_matrix = torch.cat(row_inner_iou_matrix, dim=1)
        iou_matrix.append(row_iou_matrix)
        inner_iou_matrix.append(row_inner_iou_matrix)
    iou_matrix = torch.cat(iou_matrix, dim=0)
    inner_iou_matrix = torch.cat(inner_iou_matrix, dim=0)

    iou_matrix.triu_(diagonal=1)
    iou_max, _ = iou_matrix.max(dim=0)
    inner_iou_matrix_u = torch.triu(inner_iou_matrix, diagonal=1)
    inner_iou_max_u, _ = inner_iou_matrix_u.max(dim=0)
    inner_iou_matrix_l = torch.tril(inner_iou_matrix, diagonal=1)
    inner_iou_max_l, _ = inner_iou_matrix_l.max(dim=0)

    keep = iou_max <= iou_thr
    keep_conf = scores > score_thr
    keep_inner_u = inner_iou_max_u <= 1 - inner_thr
    keep_inner_l = inner_iou_max_l <= 1 - inner_thr

    # if a filter would drop everything, fall back to keeping the top-3 scored masks
    if keep_conf.sum() == 0:
        index = scores.topk(3).indices
        keep_conf[index] = True
    if keep_inner_u.sum() == 0:
        index = scores.topk(3).indices
        keep_inner_u[index] = True
    if keep_inner_l.sum() == 0:
        index = scores.topk(3).indices
        keep_inner_l[index] = True
    keep *= keep_conf
    keep *= keep_inner_u
    keep *= keep_inner_l

    selected_idx = idx[keep]
    return selected_idx
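
# A small sanity check for mask_nms on synthetic masks (toy shapes, not part of the
# pipeline): two near-identical masks should collapse to one under the IoU threshold.
#   masks = torch.zeros(3, 64, 64, dtype=torch.bool)
#   masks[0, :32] = True
#   masks[1, :32] = True      # duplicate of mask 0 -> suppressed (IoU = 1 > iou_thr)
#   masks[2, 40:] = True      # disjoint mask -> kept
#   keep_idx = mask_nms(masks, torch.tensor([0.9, 0.8, 0.7]))  # expected: indices 0 and 2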
def filter_masks(keep: torch.Tensor, masks_result) -> list:
    # keep only the mask records whose index survived NMS
    keep = keep.int().cpu().numpy()
    result_keep = []
    for i, m in enumerate(masks_result):
        if i in keep:
            result_keep.append(m)
    return result_keep
def masks_update(*args, **kwargs):
    # remove redundant masks based on the scores and the overlap rate between masks
    masks_new = ()
    for masks_lvl in args:
        if isinstance(masks_lvl, tuple):
            masks_lvl = masks_lvl[0]  # if it's a tuple, take the first element
        if len(masks_lvl) == 0:
            masks_new += (masks_lvl,)
            continue
        if isinstance(masks_lvl[0], dict):
            # list of SAM-style mask records
            seg_pred = torch.from_numpy(np.stack([m['segmentation'] for m in masks_lvl], axis=0))
            iou_pred = torch.from_numpy(np.stack([m['predicted_iou'] for m in masks_lvl], axis=0))
            stability = torch.from_numpy(np.stack([m['stability_score'] for m in masks_lvl], axis=0))
        else:
            # plain list of masks: fall back to uniform iou/stability scores
            seg_pred = torch.from_numpy(np.stack(masks_lvl, axis=0))
            iou_pred = torch.ones(len(masks_lvl))
            stability = torch.ones(len(masks_lvl))
        scores = stability * iou_pred
        keep_mask_nms = mask_nms(seg_pred, scores, **kwargs)
        masks_lvl = filter_masks(keep_mask_nms, masks_lvl)
        masks_new += (masks_lvl,)
    return masks_new
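
# Hedged usage sketch (argument values taken from the call in __main__ below): pass one
# or more mask lists; the matching tuple entry holds the surviving masks for each list.
#   masks = masks_update(masks, iou_thr=0.8, score_thr=0.7, inner_thr=0.5)[0]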
def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab20")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def save_mask(mask, frame_idx, save_dir):
    image_array = (mask * 255).astype(np.uint8)
    # drop the leading channel axis and save as a grayscale PNG
    image = Image.fromarray(image_array[0])
    image.save(os.path.join(save_dir, f'{frame_idx:03}.png'))
def save_masks(mask_list, frame_idx, save_dir):
    # concatenate all per-object masks of one frame horizontally into a single image
    os.makedirs(save_dir, exist_ok=True)
    if len(mask_list[0].shape) == 3:
        total_width = mask_list[0].shape[2] * len(mask_list)
        max_height = mask_list[0].shape[1]
        final_image = Image.new('RGB', (total_width, max_height))
        for i, img in enumerate(mask_list):
            img = Image.fromarray((img[0] * 255).astype(np.uint8)).convert("RGB")
            final_image.paste(img, (i * img.width, 0))
        final_image.save(os.path.join(save_dir, f"mask_{frame_idx:03}.png"))
    else:
        total_width = mask_list[0].shape[1] * len(mask_list)
        max_height = mask_list[0].shape[0]
        final_image = Image.new('RGB', (total_width, max_height))
        for i, img in enumerate(mask_list):
            img = Image.fromarray((img * 255).astype(np.uint8)).convert("RGB")
            final_image.paste(img, (i * img.width, 0))
        final_image.save(os.path.join(save_dir, f"mask_{frame_idx:03}.png"))
def save_masks_npy(mask_list, frame_idx, save_dir):
    np.save(os.path.join(save_dir, f"mask_{frame_idx:03}.npy"), np.array(mask_list))
def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def make_enlarge_bbox(origin_bbox, max_width, max_height, ratio):
    # enlarge an (x, y, width, height) box around its center by `ratio`, clipped to the image
    width = origin_bbox[2]
    height = origin_bbox[3]
    new_box = [max(origin_bbox[0] - width * (ratio - 1) / 2, 0),
               max(origin_bbox[1] - height * (ratio - 1) / 2, 0)]
    new_box.append(min(width * ratio, max_width - new_box[0]))
    new_box.append(min(height * ratio, max_height - new_box[1]))
    return new_box
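
# Worked example (hypothetical numbers): doubling a 20x20 box at (10, 10) inside a
# 100x100 image shifts the origin by half the added size and clips at the border.
#   make_enlarge_bbox([10, 10, 20, 20], max_width=100, max_height=100, ratio=2)
#   -> [0.0, 0.0, 40.0, 40.0]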
def sample_points(masks, enlarge_bbox, positive_num=1, negative_num=40):
    # rejection-sample point prompts inside the enlarged box: label 1 on the mask, 0 off it
    ex, ey, ewidth, eheight = enlarge_bbox
    positive_count = positive_num
    negative_count = negative_num
    output_points = []
    while True:
        x = int(np.random.uniform(ex, ex + ewidth))
        y = int(np.random.uniform(ey, ey + eheight))
        if masks[y][x] and positive_count > 0:
            output_points.append((x, y, 1))
            positive_count -= 1
        elif not masks[y][x] and negative_count > 0:
            output_points.append((x, y, 0))
            negative_count -= 1
        if positive_count == 0 and negative_count == 0:
            break
    return output_points
def sample_points_from_mask(mask):
    # indices of all True pixels, as (row, col) pairs
    true_indices = np.argwhere(mask)
    if true_indices.size == 0:
        raise ValueError("The mask does not contain any True values.")
    # randomly select one foreground pixel
    random_index = np.random.choice(len(true_indices))
    sample_point = true_indices[random_index]
    return tuple(sample_point)  # note: (y, x) order
def search_new_obj(masks_from_prev, mask_list, other_masks_list=None, mask_ratio_thresh=0, ratio=0.5, area_thresh=5000):
    new_mask_list = []
    # mask_none marks the area not covered by any previously tracked mask
    mask_none = ~masks_from_prev[0].copy()[0]
    for prev_mask in masks_from_prev[1:]:
        mask_none &= ~prev_mask[0]
    # a detection counts as a new object if most of it lies in uncovered area and it is large enough
    for mask in mask_list:
        seg = mask['segmentation']
        if (mask_none & seg).sum() / seg.sum() > ratio and seg.sum() > area_thresh:
            new_mask_list.append(mask)
    for mask in new_mask_list:
        mask_none &= ~mask['segmentation']
    logger.info(f"new masks from current level: {len(new_mask_list)}")
    logger.info(f"now ratio: {mask_none.sum() / (mask_none.shape[0] * mask_none.shape[1])}")
    logger.info(f"expected ratio: {mask_ratio_thresh}")
    if other_masks_list is not None:
        for mask in other_masks_list:
            if mask_none.sum() / (mask_none.shape[0] * mask_none.shape[1]) > mask_ratio_thresh:  # still a lot of gaps, greater than the current threshold
                seg = mask['segmentation']
                if (mask_none & seg).sum() / seg.sum() > ratio and seg.sum() > area_thresh:
                    new_mask_list.append(mask)
                    mask_none &= ~seg
            else:
                break
    logger.info(f"new masks after other levels: {len(new_mask_list)}")
    return new_mask_list
def get_bbox_from_mask(mask):
    # rows/cols that contain at least one non-zero pixel
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    # first and last non-zero row and column
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    width = xmax - xmin + 1
    height = ymax - ymin + 1
    return xmin, ymin, width, height
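
# Worked example (toy mask): a 2x3 block of True pixels starting at row 1, column 2
# yields the box (xmin=2, ymin=1, width=3, height=2).
#   m = np.zeros((5, 6), dtype=bool); m[1:3, 2:5] = True
#   get_bbox_from_mask(m)  # -> (2, 1, 3, 2)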
def cal_no_mask_area_ratio(out_mask_list):
    # fraction of the frame covered by none of the object masks
    h = out_mask_list[0].shape[1]
    w = out_mask_list[0].shape[2]
    mask_none = ~out_mask_list[0].copy()
    for prev_mask in out_mask_list[1:]:
        mask_none &= ~prev_mask
    return mask_none.sum() / (h * w)
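
# Quick sanity check (toy arrays): two disjoint (1, H, W) masks that each cover a
# quarter of a 2x2 frame leave half the frame unmasked.
#   a = np.array([[[True, False], [False, False]]])
#   b = np.array([[[False, True], [False, False]]])
#   cal_no_mask_area_ratio([a, b])  # -> 0.5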
class Prompts:
    # batches mask prompts by object id, grouping objects that first appear on the same key frame
    def __init__(self, bs: int):
        self.batch_size = bs
        self.prompts = {}
        self.obj_list = []
        self.key_frame_list = []
        self.key_frame_obj_begin_list = []

    def add(self, obj_id, frame_id, mask):
        if obj_id not in self.obj_list:
            new_obj = True
            self.prompts[obj_id] = []
            self.obj_list.append(obj_id)
        else:
            new_obj = False
        self.prompts[obj_id].append((frame_id, mask))
        if frame_id not in self.key_frame_list and new_obj:
            self.key_frame_list.append(frame_id)
            self.key_frame_obj_begin_list.append(obj_id)
            logger.info(f"key_frame_obj_begin_list: {self.key_frame_obj_begin_list}")

    def get_obj_num(self):
        return len(self.obj_list)

    def __len__(self):
        if len(self.obj_list) % self.batch_size == 0:
            return len(self.obj_list) // self.batch_size
        else:
            return len(self.obj_list) // self.batch_size + 1

    def __iter__(self):
        self.start_idx = 0
        self.iter_frameindex = 0
        return self

    def __next__(self):
        if self.start_idx < len(self.obj_list):
            if self.iter_frameindex == len(self.key_frame_list) - 1:
                end_idx = min(self.start_idx + self.batch_size, len(self.obj_list))
            else:
                # never let a batch cross the boundary of the next key frame's first object
                if self.start_idx + self.batch_size < self.key_frame_obj_begin_list[self.iter_frameindex + 1]:
                    end_idx = self.start_idx + self.batch_size
                else:
                    end_idx = self.key_frame_obj_begin_list[self.iter_frameindex + 1]
                    self.iter_frameindex += 1
            batch_keys = self.obj_list[self.start_idx:end_idx]
            batch_prompts = {key: self.prompts[key] for key in batch_keys}
            self.start_idx = end_idx
            return batch_prompts
        else:
            raise StopIteration
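
# Hedged usage sketch (hypothetical mask variables): objects 0 and 1 are prompted on
# frame 0, object 2 first appears on frame 5; iteration never mixes objects across
# that key-frame split even when a batch has room.
#   loader = Prompts(bs=2)
#   loader.add(0, 0, mask0); loader.add(1, 0, mask1)   # masks are (H, W) bool arrays
#   loader.add(2, 5, mask2)
#   for batch in loader:   # yields {0: [...], 1: [...]}, then {2: [...]}
#       ...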
def get_video_segments(prompts_loader, predictor, inference_state, final_output=False):
    # per-frame segmentation results, keyed by frame index and then object id
    video_segments = {}
    for batch_prompts in tqdm(prompts_loader, desc="processing prompts\n"):
        predictor.reset_state(inference_state)
        for obj_id, prompt_list in batch_prompts.items():
            for prompt in prompt_list:
                _, out_obj_ids, out_mask_logits = predictor.add_new_mask(
                    inference_state=inference_state,
                    frame_idx=prompt[0],
                    obj_id=obj_id,
                    mask=prompt[1]
                )
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
            if out_frame_idx not in video_segments:
                video_segments[out_frame_idx] = {}
            for i, out_obj_id in enumerate(out_obj_ids):
                video_segments[out_frame_idx][out_obj_id] = (out_mask_logits[i] > 0.0).cpu().numpy()
        if final_output:
            # also propagate backwards so frames before each prompt frame get masks
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, reverse=True):
                for i, out_obj_id in enumerate(out_obj_ids):
                    video_segments[out_frame_idx][out_obj_id] = (out_mask_logits[i] > 0.0).cpu().numpy()
    return video_segments
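
# The returned structure is a nested dict:
#   video_segments[frame_idx][obj_id] -> (1, H, W) boolean mask array
# e.g. video_segments[0][3] is object 3's mask on frame 0, which is why the code
# below indexes individual masks with mask[0].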
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--level", choices=['default', 'small', 'middle', 'large'], default='default')
    parser.add_argument("--batch_size", type=int, default=20)
    parser.add_argument("--sam1_checkpoint", type=str, default="/home/lff/bigdata1/cjw/checkpoints/sam/sam_vit_h_4b8939.pth")
    parser.add_argument("--sam2_checkpoint", type=str, default="/home/lff/bigdata1/cjw/checkpoints/sam2/sam2_hiera_large.pt")
    parser.add_argument("--detect_stride", type=int, default=10)
    parser.add_argument("--use_other_level", type=int, default=1)
    parser.add_argument("--postnms", type=int, default=1)
    parser.add_argument("--pred_iou_thresh", type=float, default=0.7)
    parser.add_argument("--box_nms_thresh", type=float, default=0.7)
    parser.add_argument("--stability_score_thresh", type=float, default=0.85)
    parser.add_argument("--reverse", action="store_true")

    level_dict = {
        "default": 0,
        "small": 1,
        "middle": 2,
        "large": 3
    }
    args = parser.parse_args()
    logger.add(os.path.join(args.output_dir, f'{args.level}.log'), rotation="500 MB")
    logger.info(args)
    video_dir = args.video_path
    level = args.level
    base_dir = args.output_dir
    ##### load SAM 2 and SAM 1 models #####
    sam2_checkpoint = args.sam2_checkpoint
    model_cfg = "sam2_hiera_l.yaml"
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
    sam2 = build_sam2(model_cfg, sam2_checkpoint, device='cuda', apply_postprocessing=False)

    sam_ckpt_path = args.sam1_checkpoint
    sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt_path).to('cuda')
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=args.pred_iou_thresh,
        box_nms_thresh=args.box_nms_thresh,
        stability_score_thresh=args.stability_score_thresh,
        crop_n_layers=1,
        crop_n_points_downscale_factor=1,
        min_mask_region_area=100,
    )
    # scan all the frame names in this directory
    frame_names = [
        p for p in os.listdir(video_dir)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
    ]
    try:
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]), reverse=args.reverse)
    except ValueError:
        # fall back to lexicographic order when the file stems are not integers
        frame_names.sort(key=lambda p: os.path.splitext(p)[0], reverse=args.reverse)

    now_frame = 0
    inference_state = predictor.init_state(video_path=video_dir)
    masks_from_prev = []
    sum_id = 0  # running total of tracked objects
    prompts_loader = Prompts(bs=args.batch_size)  # holds all the mask prompts we add
    while True:
        logger.info(f"frame: {now_frame}")
        sum_id = prompts_loader.get_obj_num()
        image_path = os.path.join(video_dir, frame_names[now_frame])
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # resize if the input is too large
        orig_h, orig_w = image.shape[:2]
        if orig_h > 1080:
            logger.info("Resizing original image to 1080P...")
            scale = 1080 / orig_h
            h = int(orig_h * scale)
            w = int(orig_w * scale)
            image = cv2.resize(image, (w, h))

        # NOTE: this indexing assumes generate() returns one mask list per granularity level
        all_masks = mask_generator.generate(image)
        masks = all_masks[level_dict[args.level]]
        if args.postnms:
            masks = masks_update(masks, iou_thr=0.8, score_thr=0.7, inner_thr=0.5)[0]
        other_masks = None  # masks from other granularity levels are not collected in this script

        if now_frame == 0:  # first frame: every mask starts a new object
            ann_obj_id_list = range(len(masks))
            for ann_obj_id in tqdm(ann_obj_id_list):
                seg = masks[ann_obj_id]['segmentation']
                prompts_loader.add(ann_obj_id, 0, seg)
        else:
            # later frames: keep tracked objects and add masks covering newly exposed area
            new_mask_list = search_new_obj(masks_from_prev, masks, other_masks, mask_ratio_thresh)
            logger.info(f"number of new obj: {len(new_mask_list)}")
            for obj_id, mask in enumerate(masks_from_prev):
                if mask.sum() == 0:
                    continue
                prompts_loader.add(obj_id, now_frame, mask[0])
            for i in range(len(new_mask_list)):
                new_mask = new_mask_list[i]['segmentation']
                prompts_loader.add(sum_id + i, now_frame, new_mask)
        logger.info(f"obj num: {prompts_loader.get_obj_num()}")
        if now_frame == 0 or len(new_mask_list) != 0:
            video_segments = get_video_segments(prompts_loader, predictor, inference_state)

        # scan forward for the next frame whose unmasked area grows past the threshold
        vis_frame_stride = args.detect_stride
        max_area_no_mask = (0, -1)
        for out_frame_idx in tqdm(range(0, len(frame_names), vis_frame_stride)):
            if out_frame_idx < now_frame:
                continue
            out_mask_list = []
            for out_obj_id, out_mask in video_segments[out_frame_idx].items():
                out_mask_list.append(out_mask)
            no_mask_ratio = cal_no_mask_area_ratio(out_mask_list)
            if now_frame == out_frame_idx:
                mask_ratio_thresh = no_mask_ratio
                logger.info(f"mask_ratio_thresh: {mask_ratio_thresh}")
            if no_mask_ratio > mask_ratio_thresh + 0.01 and out_frame_idx > now_frame:
                masks_from_prev = out_mask_list
                max_area_no_mask = (no_mask_ratio, out_frame_idx)
                logger.info(max_area_no_mask)
                break
        if max_area_no_mask[1] == -1:
            # no frame exceeded the threshold: object detection is complete
            break
        logger.info(f"max_area_no_mask: {max_area_no_mask}")
        now_frame = max_area_no_mask[1]
    ###### Final output ######
    save_dir = os.path.join(base_dir, level, "final-output")
    os.makedirs(save_dir, exist_ok=True)
    video_segments = get_video_segments(prompts_loader, predictor, inference_state, final_output=True)
    for out_frame_idx in tqdm(range(0, len(frame_names), 1)):
        out_mask_list = []
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            out_mask_list.append(out_mask)
        no_mask_ratio = cal_no_mask_area_ratio(out_mask_list)
        logger.info(no_mask_ratio)
        save_masks(out_mask_list, out_frame_idx, save_dir)
        save_masks_npy(out_mask_list, out_frame_idx, save_dir)
    ###### Generate Visualization Frames ######
    logger.info("Start generating visualization frames...")
    vis_save_dir = os.path.join(base_dir, level, 'visualization', 'full-mask-npy')
    os.makedirs(vis_save_dir, exist_ok=True)
    frame_save_dir = os.path.join(base_dir, level, 'visualization', 'frames')
    os.makedirs(frame_save_dir, exist_ok=True)

    # read all the per-frame npy mask files saved above
    npy_name_list = [name for name in os.listdir(save_dir) if name.endswith('.npy')]
    npy_name_list.sort()
    logger.info(f"Found {len(npy_name_list)} npy files")
    npy_list = [np.load(os.path.join(save_dir, name)) for name in npy_name_list]
    image_list = [Image.open(os.path.join(video_dir, name)) for name in frame_names]
    assert len(npy_list) == len(image_list), "Number of npy files does not match number of images"
    logger.info(f"Processing {len(npy_list)} frames in total")
    # generate visually distinct random colors, one per object id
    def generate_random_colors(num_colors):
        colors = []
        for _ in range(num_colors):
            reroll = True
            iter_cnt = 0
            while reroll and iter_cnt < 100:
                iter_cnt += 1
                reroll = False
                color = tuple(random.randint(1, 255) for _ in range(3))
                for selected_color in colors:
                    if np.linalg.norm(np.array(color) - np.array(selected_color)) < 70:
                        reroll = True
                        break
            colors.append(color)
        return colors

    num_masks = max(len(masks) for masks in npy_list)
    colors = generate_random_colors(num_masks)
    post_colors = [(0, 0, 0)] + colors  # index 0 is reserved for background
    post_colors = np.array(post_colors)  # [num_masks + 1, 3]
    np.save(os.path.join(base_dir, "colors.npy"), post_colors)
    # colorize every frame with its per-object masks
    for frame_idx in range(len(frame_names)):
        masks = npy_list[frame_idx]
        image = image_list[frame_idx]
        image_np = np.array(image)
        mask_combined = np.zeros_like(image_np, dtype=np.uint8)
        for mask_id, mask in enumerate(masks):
            mask = mask.squeeze(0)
            mask_area = mask > 0
            mask_combined[mask_area, :] = colors[mask_id]
        mask_combined = np.clip(mask_combined, 0, 255)
        # the flat color map is saved directly; blending with the original image
        # (e.g. cv2.addWeighted(image_np, 0.7, mask_combined, 0.3, 0)) is also possible
        blended_image = mask_combined
        # save into the output root, keeping the original frame name
        frame_name = frame_names[frame_idx]
        output_path = os.path.join(base_dir, frame_name)
        Image.fromarray(blended_image).save(output_path)
        logger.info(f"Frame saved to: {output_path}")