diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2eeea76528d765d73b1b34c503831ee7d7d555
--- /dev/null
+++ b/app.py
@@ -0,0 +1,387 @@
+import argparse
+import cv2
+import torch
+import gradio as gr
+from transformers import SamModel, SamProcessor
+
+import spaces
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from torchvision.transforms import v2
+
+from rynnec import disable_torch_init, model_init, mm_infer, mm_infer_segmentation
+from rynnec.mm_utils import annToMask, load_video, load_images
+
+import colorsys
+
+
+def get_hsv_palette(n_colors):
+    # Evenly spaced hues; index 0 stays black for the background.
+    hues = np.linspace(0, 1, int(n_colors) + 1)[1:-1]
+    s = 0.8
+    v = 0.9
+    palette = [(0.0, 0.0, 0.0)] + [
+        colorsys.hsv_to_rgb(h_i, s, v) for h_i in hues
+    ]
+    return (255 * np.asarray(palette)).astype("uint8")
+
+
+def colorize_masks(images, index_masks, fac: float = 0.8, draw_contour=True, edge_thickness=20):
+    # Blend each frame with its colorized index mask and optionally draw contours.
+    max_idx = max([m.max() for m in index_masks])
+    palette = get_hsv_palette(max_idx + 1)
+    color_masks = []
+    out_frames = []
+    for img, mask in tqdm(zip(images, index_masks), desc='Visualize masks ...'):
+        clr_mask = palette[mask.astype("int")]
+        blended_img = img
+
+        blended_img = compose_img_mask(blended_img, clr_mask, fac)
+
+        if draw_contour:
+            blended_img = draw_contours_on_image(blended_img, mask, clr_mask,
+                                                 brightness_factor=1.8,
+                                                 alpha=0.6,
+                                                 thickness=edge_thickness)
+        out_frames.append(blended_img)
+
+    return out_frames, color_masks
+
+
+def compose_img_mask(img, color_mask, fac: float = 0.5):
+    mask_region = (color_mask.sum(axis=-1) > 0)[..., None]
+    out_f = img.copy() / 255
+    out_f[mask_region[:, :, 0]] = fac * img[mask_region[:, :, 0]] / 255 + (1 - fac) * color_mask[mask_region[:, :, 0]] / 255
+    out_u = (255 * out_f).astype("uint8")
+    return out_u
+
+
+def draw_contours_on_image(img, index_mask, color_mask, brightness_factor=1.6, alpha=0.5, thickness=2, ignore_index=0):
+    img = img.astype("float32")
+    overlay = img.copy()
+
+    unique_indices = np.unique(index_mask)
+    if ignore_index is not None:
+        unique_indices = [idx for idx in unique_indices if idx != ignore_index]
+
+    for i in unique_indices:
+        bin_mask = (index_mask == i).astype("uint8") * 255
+        if bin_mask.sum() == 0:
+            continue
+
+        contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+        color = color_mask[index_mask == i][0].astype("float32")
+        bright_color = np.clip(color * brightness_factor, 0, 255).tolist()
+
+        cv2.drawContours(overlay, contours, -1, bright_color, thickness)
+
+    blended = (1 - alpha) * img + alpha * overlay
+    return np.clip(blended, 0, 255).astype("uint8")
+
+
+def extract_first_frame_from_video(video):
+    cap = cv2.VideoCapture(video)
+    success, frame = cap.read()
+    cap.release()
+    if success:
+        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+    return None
+
+
+def extract_points_from_mask(mask_pil):
+    mask = np.asarray(mask_pil)[..., 0]
+    coords = np.nonzero(mask)
+    coords = np.stack((coords[1], coords[0]), axis=1)
+
+    return coords
+
+def add_contour(img, mask, color=(1., 1., 1.)):
+    img = img.copy()
+
+    mask = mask.astype(np.uint8) * 255
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    cv2.drawContours(img, contours, -1, color, thickness=8)
+
+    return img
+
+
+def load_first_frame(video_path):
+    cap = cv2.VideoCapture(video_path)
+    ret, frame = cap.read()
+    cap.release()
+    if not ret:
+        raise gr.Error("Could not read the video file.")
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+    image = Image.fromarray(frame)
+    return image
+
+
+def clear_masks():
+    return [], [], [], []
+
+def clear_all():
+    return [], [], [], [], None, "", ""
+
+
+@spaces.GPU(duration=120)
+def apply_sam(image, input_points):
+    # Prompt SAM with the clicked points and keep the mask with the highest IoU score.
+    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = sam_model(**inputs)
+
+    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
+    scores = outputs.iou_scores[0, 0]
+
+    mask_selection_index = scores.argmax()
+
+    mask_np = masks[mask_selection_index].numpy()
+
+    return mask_np
+
+
+@spaces.GPU(duration=120)
+def run(mode, images, timestamps, masks, mask_ids, instruction, mask_output_video):
+    # Dispatch between region-level QA and referring segmentation.
+    if mode == "QA":
+        response = run_text_inference(images, timestamps, masks, mask_ids, instruction)
+    else:
+        response, mask_output_video = run_seg_inference(images, timestamps, instruction)
+    return response, mask_output_video
+
+
+def run_text_inference(images, timestamps, masks, mask_ids, instruction):
+    masks = torch.from_numpy(np.stack(masks, axis=0))
+
+    if "