from transformers import AutoModel, AutoProcessor
from pathlib import Path
from image_adapter import ImageAdapter
import torch

CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("Adieee5/Image-captioning")
# CHECKPOINT_PATH = Path("cheackpoints")


def initialize_models():
    """Initialize and load all models"""
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model

    if (CHECKPOINT_PATH / "clip_model.pt").exists():
        print("Loading VLM's custom vision model")
        checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
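        # Strip prefixes left over from training: "_orig_mod." is added by
        # torch.compile and "module." by DataParallel/DDP wrapping.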
        checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
        clip_model.load_state_dict(checkpoint)
        del checkpoint
    else:
        print("Custom CLIP weights not found, using default weights")

    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cpu")

    image_adapter = None
    if (CHECKPOINT_PATH / "image_presenter.pt").exists():
        print("Loading image adapter")
        image_adapter = ImageAdapter(clip_model.config.hidden_size, 4096, False, False, 38, False)
        image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_presenter.pt", map_location="cpu"))
        image_adapter.eval()
        image_adapter.to("cpu")
    else:
        print("Image adapter not found, will use CLIP features directly")

    return clip_model, image_adapter
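

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how initialize_models() might be driven end to end.
# "example.jpg" is a hypothetical input file, and the calls to the vision model
# and to image_adapter assume a forward pass that accepts SigLIP hidden states;
# check image_adapter.py for the exact interface.
if __name__ == "__main__":
    from PIL import Image

    clip_model, image_adapter = initialize_models()
    processor = AutoProcessor.from_pretrained(CLIP_PATH)

    image = Image.open("example.jpg").convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        # The penultimate hidden layer is a common choice for captioning features.
        features = vision_outputs.hidden_states[-2]
        if image_adapter is not None:
            embeds = image_adapter(features)  # project into the LLM embedding space
        else:
            embeds = features
    print("Image embedding shape:", tuple(embeds.shape))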