import argparse
import colorsys

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from tqdm import tqdm
from torchvision.transforms import v2
from transformers import SamModel, SamProcessor

from rynnec import disable_torch_init, model_init, mm_infer, mm_infer_segmentation
from rynnec.mm_utils import annToMask, load_video, load_images


def get_hsv_palette(n_colors):
    """Build a palette of visually distinct colors; index 0 (background) stays black."""
    hues = np.linspace(0, 1, int(n_colors) + 1)[1:-1]
    s = 0.8
    v = 0.9
    palette = [(0.0, 0.0, 0.0)] + [colorsys.hsv_to_rgb(h_i, s, v) for h_i in hues]
    return (255 * np.asarray(palette)).astype("uint8")


def colorize_masks(images, index_masks, fac: float = 0.8, draw_contour=True, edge_thickness=20):
    """Overlay per-object colored masks (and optional contours) on each frame."""
    max_idx = max([m.max() for m in index_masks])
    palette = get_hsv_palette(max_idx + 1)
    color_masks = []
    out_frames = []
    for img, mask in tqdm(zip(images, index_masks), desc='Visualize masks ...'):
        clr_mask = palette[mask.astype("int")]
        blended_img = compose_img_mask(img, clr_mask, fac)
        if draw_contour:
            blended_img = draw_contours_on_image(
                blended_img, mask, clr_mask,
                brightness_factor=1.8, alpha=0.6, thickness=edge_thickness)
        out_frames.append(blended_img)
    return out_frames, color_masks


def compose_img_mask(img, color_mask, fac: float = 0.5):
    """Alpha-blend the color mask into the image only where the mask is non-zero."""
    mask_region = (color_mask.sum(axis=-1) > 0)[..., None]
    out_f = img.copy() / 255
    out_f[mask_region[:, :, 0]] = (
        fac * img[mask_region[:, :, 0]] / 255
        + (1 - fac) * color_mask[mask_region[:, :, 0]] / 255
    )
    out_u = (255 * out_f).astype("uint8")
    return out_u


def draw_contours_on_image(img, index_mask, color_mask, brightness_factor=1.6,
                           alpha=0.5, thickness=2, ignore_index=0):
    """Draw brightened object contours on top of the blended image."""
    img = img.astype("float32")
    overlay = img.copy()
    unique_indices = np.unique(index_mask)
    if ignore_index is not None:
        unique_indices = [idx for idx in unique_indices if idx != ignore_index]
    for i in unique_indices:
        bin_mask = (index_mask == i).astype("uint8") * 255
        if bin_mask.sum() == 0:
            continue
        contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        color = color_mask[index_mask == i][0].astype("float32")
        bright_color = np.clip(color * brightness_factor, 0, 255).tolist()
        cv2.drawContours(overlay, contours, -1, bright_color, thickness)
    blended = (1 - alpha) * img + alpha * overlay
    return np.clip(blended, 0, 255).astype("uint8")


def extract_first_frame_from_video(video):
    cap = cv2.VideoCapture(video)
    success, frame = cap.read()
    cap.release()
    if success:
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return None


def extract_points_from_mask(mask_pil):
    # Return (x, y) coordinates of all non-zero pixels in the mask.
    mask = np.asarray(mask_pil)[..., 0]
    coords = np.nonzero(mask)
    coords = np.stack((coords[1], coords[0]), axis=1)
    return coords


def add_contour(img, mask, color=(1., 1., 1.)):
    img = img.copy()
    mask = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img, contours, -1, color, thickness=8)
    return img


def load_first_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise gr.Error("Could not read the video file.")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(frame)
    return image


def clear_masks():
    return [], [], [], []


def clear_all():
    return [], [], [], [], None, "", ""


@spaces.GPU(duration=120)
def apply_sam(image, input_points):
    # Run SAM on the clicked points and keep the highest-scoring predicted mask.
    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )[0][0]
    scores = outputs.iou_scores[0, 0]
    mask_selection_index = scores.argmax()
    mask_np = masks[mask_selection_index].numpy()
    return mask_np


@spaces.GPU(duration=120)
def run(mode, images, timestamps, masks, mask_ids, instruction, mask_output_video):
    if mode == "QA":
        response = run_text_inference(images, timestamps, masks, mask_ids, instruction)
    else:
        response, mask_output_video = run_seg_inference(images, timestamps, instruction)
    return response, mask_output_video


def run_text_inference(images, timestamps, masks, mask_ids, instruction):
    masks = torch.from_numpy(np.stack(masks, axis=0))
    if "