from pathlib import Path

import torch
from transformers import AutoModel, AutoProcessor

from image_adapter import ImageAdapter

CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("Adieee5/Image-captioning")


def initialize_models():
    """Initialize and load all models."""
    print("Loading CLIP")
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    clip_model = clip_model.vision_model  # only the vision tower is needed

    if (CHECKPOINT_PATH / "clip_model.pt").exists():
        print("Loading VLM's custom vision model")
        checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location="cpu")
        # Strip the "_orig_mod.module." prefix left by torch.compile / DataParallel wrapping during training
        checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
        clip_model.load_state_dict(checkpoint)
        del checkpoint
    else:
        print("Custom CLIP weights not found, using default weights")

    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cpu")

    image_adapter = None
    if (CHECKPOINT_PATH / "image_presenter.pt").exists():
        print("Loading image adapter")
        image_adapter = ImageAdapter(clip_model.config.hidden_size, 4096, False, False, 38, False)
        image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_presenter.pt", map_location="cpu"))
        image_adapter.eval()
        image_adapter.to("cpu")
    else:
        print("Image adapter not found, will use CLIP features directly")

    return clip_model, image_adapter
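

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the "example.jpg"
    # path, the choice of hidden layer, and the ImageAdapter call signature
    # below are assumptions for illustration.
    from PIL import Image

    clip_model, image_adapter = initialize_models()
    # The processor is created inside initialize_models() but not returned,
    # so it is re-created here.
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)

    image = Image.open("example.jpg").convert("RGB")
    pixel_values = clip_processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        # Assumption: use the penultimate hidden state, a common choice for VLM adapters.
        features = vision_outputs.hidden_states[-2]
        if image_adapter is not None:
            # Assumed to project CLIP features into the LLM embedding space.
            features = image_adapter(features)

    print("Image feature shape:", tuple(features.shape))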