# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

HF_MODEL_ID_TO_FILENAMES = {
    "facebook/sam2-hiera-tiny": (
        "configs/sam2/sam2_hiera_t.yaml",
        "sam2_hiera_tiny.pt",
    ),
    "facebook/sam2-hiera-small": (
        "configs/sam2/sam2_hiera_s.yaml",
        "sam2_hiera_small.pt",
    ),
    "facebook/sam2-hiera-base-plus": (
        "configs/sam2/sam2_hiera_b+.yaml",
        "sam2_hiera_base_plus.pt",
    ),
    "facebook/sam2-hiera-large": (
        "configs/sam2/sam2_hiera_l.yaml",
        "sam2_hiera_large.pt",
    ),
    "facebook/sam2.1-hiera-tiny": (
        "configs/sam2.1/sam2.1_hiera_t.yaml",
        "sam2.1_hiera_tiny.pt",
    ),
    "facebook/sam2.1-hiera-small": (
        "configs/sam2.1/sam2.1_hiera_s.yaml",
        "sam2.1_hiera_small.pt",
    ),
    "facebook/sam2.1-hiera-base-plus": (
        "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_base_plus.pt",
    ),
    "facebook/sam2.1-hiera-large": (
        "configs/sam2.1/sam2.1_hiera_l.yaml",
        "sam2.1_hiera_large.pt",
    ),
}


def get_best_available_device():
    """
    Get the best available device in the order: CUDA, MPS, CPU

    Returns:
        device string for torch.device
    """
    if torch.cuda.is_available():
        return "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


def build_sam2(
    config_file,
    ckpt_path=None,
    device=None,
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):
    """Build a SAM 2 image model from a Hydra config file, optionally loading a checkpoint."""
    # Use the provided device or get the best available one
    device = device or get_best_available_device()
    logging.info(f"Using device: {device}")

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device=None,
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):
    """Build a SAM2VideoPredictor for video inference, optionally loading a checkpoint."""
    # Use the provided device or get the best available one
    device = device or get_best_available_device()
    logging.info(f"Using device: {device}")

    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
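    # `++key=value` Hydra overrides force-add a key even if it is absent from the
    # YAML config; the extra overrides are appended after the `_target_` override
    # so they take effect when the config is composed below.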
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_video_predictor_npz(
    config_file,
    ckpt_path=None,
    device=None,
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):
    """Same as `build_sam2_video_predictor`, but instantiates the SAM2VideoPredictorNPZ class."""
    # Use the provided device or get the best available one
    device = device or get_best_available_device()
    logging.info(f"Using device: {device}")

    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor_npz.SAM2VideoPredictorNPZ",
    ]
    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def _hf_download(model_id):
    """Resolve a Hugging Face model id to its config name and download its checkpoint from the Hub."""
    from huggingface_hub import hf_hub_download

    config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return config_name, ckpt_path


def build_sam2_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2_video_predictor(
        config_file=config_name, ckpt_path=ckpt_path, **kwargs
    )


def _load_checkpoint(model, ckpt_path):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
        missing_keys, unexpected_keys = model.load_state_dict(sd)
        if missing_keys:
            logging.error(missing_keys)
            raise RuntimeError("Missing keys in checkpoint state_dict")
        if unexpected_keys:
            logging.error(unexpected_keys)
            raise RuntimeError("Unexpected keys in checkpoint state_dict")
        logging.info("Loaded checkpoint successfully")
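

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the library API): build an image
    # model and a video predictor from Hugging Face Hub weights. This assumes the
    # `huggingface_hub` package is installed, network access is available, and the
    # sam2 Hydra config module has already been registered (the sam2 package
    # __init__ normally does this via hydra.initialize_config_module).
    logging.basicConfig(level=logging.INFO)
    image_model = build_sam2_hf("facebook/sam2.1-hiera-tiny")
    video_predictor = build_sam2_video_predictor_hf("facebook/sam2.1-hiera-tiny")
    logging.info(
        f"Built {type(image_model).__name__} and {type(video_predictor).__name__}"
    )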