# Image_Prompting-and-Captioning/model_initial.py
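"""Model initialization for the Image_Prompting-and-Captioning app.

Loads the SigLIP vision tower (optionally overriding it with fine-tuned
weights from CHECKPOINT_PATH) and, if its checkpoint is present, the custom
ImageAdapter defined in image_adapter.py.
"""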
from pathlib import Path

import torch
from transformers import AutoModel, AutoProcessor

from image_adapter import ImageAdapter

CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("Adieee5/Image-captioning")
# CHECKPOINT_PATH = Path("cheackpoints")
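
# NOTE: CHECKPOINT_PATH is used below as a *local* directory (".exists()" checks),
# so the repo-id-like string only resolves if the checkpoints were downloaded into
# a folder of that name. A hedged sketch of fetching them from the Hugging Face Hub
# (assumes the Adieee5/Image-captioning repo hosts clip_model.pt / image_presenter.pt):
#
#     from huggingface_hub import snapshot_download
#     CHECKPOINT_PATH = Path(snapshot_download("Adieee5/Image-captioning"))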
def initialize_models():
    """Initialize and load all models."""
    print("Loading CLIP")
    # NOTE: clip_processor is created here but not returned by this function.
    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH)
    # Keep only the vision tower; the text tower is not used for captioning.
    clip_model = clip_model.vision_model

    if (CHECKPOINT_PATH / "clip_model.pt").exists():
        print("Loading VLM's custom vision model")
        checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location="cpu")
        # Strip the "_orig_mod.module." prefix that torch.compile + DataParallel
        # wrapping prepends to parameter names in checkpoints saved during training.
        checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
        clip_model.load_state_dict(checkpoint)
        del checkpoint
    else:
        print("Custom CLIP weights not found, using default weights")

    # Freeze the (possibly fine-tuned) vision tower and keep it on CPU.
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to("cpu")

    image_adapter = None
    if (CHECKPOINT_PATH / "image_presenter.pt").exists():
        print("Loading image adapter")
        # Positional args follow the custom ImageAdapter signature in
        # image_adapter.py: input width = CLIP hidden size (1152 for so400m),
        # output width 4096 (presumably the downstream LLM's embedding size).
        image_adapter = ImageAdapter(clip_model.config.hidden_size, 4096, False, False, 38, False)
        image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_presenter.pt", map_location="cpu"))
        image_adapter.eval()
        image_adapter.to("cpu")
    else:
        print("Image adapter not found, will use CLIP features directly")

    return clip_model, image_adapter
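

# Minimal usage sketch (hypothetical, not part of the original module): runs a
# dummy forward pass through whatever initialize_models() returns. Falls back
# to stock SigLIP features when no custom checkpoints are found locally.
if __name__ == "__main__":
    clip_model, image_adapter = initialize_models()

    # SigLIP so400m-patch14-384 expects 384x384 RGB input.
    pixel_values = torch.zeros(1, 3, 384, 384)
    with torch.no_grad():
        vision_out = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        if image_adapter is not None:
            # Assumption: the adapter consumes a vision hidden state; the exact
            # layer (e.g. hidden_states[-2]) depends on how ImageAdapter was trained.
            embeds = image_adapter(vision_out.hidden_states[-2])
        else:
            embeds = vision_out.last_hidden_state
    print("Feature shape:", tuple(embeds.shape))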