from pathlib import Path

import torch
from transformers import AutoModel, AutoProcessor

from image_adapter import ImageAdapter

CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("Adieee5/Image-captioning")
# CHECKPOINT_PATH = Path("cheackpoints")


def initialize_models():
    """Initialize and load all models"""
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    # Keep only the vision tower; the text tower is not needed for captioning
    clip_model = clip_model.vision_model

    if (CHECKPOINT_PATH / "clip_model.pt").exists():
        print("Loading VLM's custom vision model")
        checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location="cpu")
        # Strip torch.compile / DataParallel prefixes so keys match the bare vision model
        checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
        clip_model.load_state_dict(checkpoint)
        del checkpoint
    else:
        print("Custom CLIP weights not found, using default weights")

    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cpu")

    image_adapter = None
    if (CHECKPOINT_PATH / "image_presenter.pt").exists():
        print("Loading image adapter")
        image_adapter = ImageAdapter(clip_model.config.hidden_size, 4096, False, False, 38, False)
        image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_presenter.pt", map_location="cpu"))
        image_adapter.eval()
        image_adapter.to("cpu")
    else:
        print("Image adapter not found, will use CLIP features directly")

    return clip_model, image_adapter
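

# --- Minimal usage sketch (illustrative, not part of the loader itself) ---
# Assumptions: ImageAdapter.forward accepts the tuple of vision hidden states
# (as in JoyCaption-style adapters), and "example.jpg" is a placeholder path.
# The processor is re-created here because initialize_models() does not return it.
if __name__ == "__main__":
    from PIL import Image

    clip_model, image_adapter = initialize_models()
    processor = AutoProcessor.from_pretrained(CLIP_PATH)

    image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        if image_adapter is not None:
            # Assumption: the adapter projects the vision hidden states for the language model
            features = image_adapter(vision_outputs.hidden_states)
        else:
            # Fall back to raw CLIP features (penultimate hidden layer is a common choice)
            features = vision_outputs.hidden_states[-2]

    print("Image feature shape:", tuple(features.shape))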