File size: 2,210 Bytes
1b65314
 
 
 
 
 
 
 
 
 
 
 
 
 
3b40174
 
1b65314
 
 
3b40174
8c164d9
1b65314
 
 
 
 
3b40174
 
1b65314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import sys
sys.path.append(os.path.abspath('./modules/ultralytics'))

from transformers import AutoTokenizer, AutoModel, AutoProcessor
from modules.mast3r.model import AsymmetricMASt3R

from modules.sam2.build_sam import build_sam2_video_predictor
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
from modules.mobilesamv2 import sam_model_registry

class Models:
    """Build the model set once and expose each model as an attribute.

    Attributes assigned by ``__init__``:
        mast3r: ``AsymmetricMASt3R`` obtained through ``from_pretrained`` and
            moved to ``device``.
        sam2: predictor returned by ``build_sam2_video_predictor``, switched
            to eval mode.
        mobilesamv2: the ``'sam_vit_h'`` registry model whose image encoder,
            prompt encoder and mask decoder are replaced by separately built
            components; moved to ``device`` and switched to eval mode.
        yolov8: ``ObjectAwareModel`` built from its checkpoint path (no
            device argument is passed to it here).
        siglip, siglip_tokenizer, siglip_processor: the ``AutoModel``,
            ``AutoTokenizer`` and ``AutoProcessor`` for the same SigLIP
            repository id.
    """

    def __init__(self, device):
        """Instantiate all models.

        Args:
            device: forwarded to ``.to(...)``, to the SAM 2 builder's
                ``device`` keyword and to ``device_map`` of the SigLIP
                loaders.
        """
        # ---- MASt3R ----
        # Local checkpoint alternative:
        #   './checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
        mast3r_ckpt = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
        mast3r_model = AsymmetricMASt3R.from_pretrained(mast3r_ckpt)
        self.mast3r = mast3r_model.to(device)

        # ---- SAM 2 ----
        # Local checkpoint alternative: "./checkpoints/sam2.1_hiera_large.pt"
        sam2_config = "./configs/sam2.1/sam2.1_hiera_l.yaml"
        sam2_ckpt = 'hujiecpp/sam2-1-hiera-large'
        self.sam2 = build_sam2_video_predictor(
            sam2_config,
            sam2_ckpt,
            device=device,
            apply_postprocessing=False,
        )
        self.sam2.eval()

        # ---- MobileSAMv2 & SAM1 ----
        # Local encoder checkpoint alternative: './checkpoints/sam_vit_h.pt'
        sam1_encoder_ckpt = 'facebook/sam-vit-huge'
        sam1_decoder_ckpt = './checkpoints/Prompt_guided_Mask_Decoder.pt'

        # The base model is created without a checkpoint (None); its three
        # sub-modules are then swapped for the separately built ones below.
        self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
        vit_h_encoder = sam_model_registry['sam_vit_h_encoder'](sam1_encoder_ckpt)
        guided_prompt_encoder, guided_mask_decoder = (
            sam_model_registry['prompt_guided_decoder'](sam1_decoder_ckpt)
        )

        mobilesam = self.mobilesamv2
        mobilesam.image_encoder = vit_h_encoder
        mobilesam.prompt_encoder = guided_prompt_encoder
        mobilesam.mask_decoder = guided_mask_decoder
        mobilesam.to(device=device)
        mobilesam.eval()

        # ---- YOLOv8 (object-aware model) ----
        yolo8_ckpt = './checkpoints/ObjectAwareModel.pt'
        self.yolov8 = ObjectAwareModel(yolo8_ckpt)

        # ---- SigLIP ----
        # Model, tokenizer and processor all come from the same repository id.
        siglip_repo = "google/siglip-large-patch16-256"
        self.siglip = AutoModel.from_pretrained(siglip_repo, device_map=device)
        self.siglip_tokenizer = AutoTokenizer.from_pretrained(siglip_repo, device_map=device)
        self.siglip_processor = AutoProcessor.from_pretrained(siglip_repo, device_map=device)