How to load CONCH v1.5 weights?
Hi Mahmood Lab team,
Thank you for your excellent work; we've been using your models regularly in our pathology image research.
I have a question regarding the usage of CONCH v1.5. The weight file structure seems different from the previous version of CONCH, and since the model architecture and loading logic also appear to have changed slightly, I wanted to confirm the correct usage.
Could you kindly provide a brief guide on how to apply the v1.5 weights? If possible, a minimal example or reference to a sample script would be greatly appreciated.
Also, it looks like the v1.5 weights only contain parameters for the vision encoder. Are the previous CONCH text encoder parameters still compatible with v1.5?
Thank you in advance!
import os
import logging
import torch
import torch.nn as nn
from torchvision import transforms
import timm
from os.path import join as pjoin
import json
from torchvision.transforms.functional import to_pil_image
logging.basicConfig(level=logging.INFO)

def get_eval_transforms_conchv1_5(img_resize: int = 448):
    # Evaluation transform: bicubic resize, center crop to img_resize,
    # RGB conversion, and ImageNet mean/std normalization.
    transform = transforms.Compose(
        [
            transforms.Resize(
                img_resize, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(img_resize),
            transforms.Lambda(
                lambda img: img.convert("RGB") if img.mode != "RGB" else img
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )
    return transform


def get_encoder_conchv1_5():
    # Build a ViT-L/16 backbone at 448x448 and load the CONCH v1.5 vision weights.
    ckpt_path = "/data/ckpt/conchv1.5/pytorch_model_vision.bin"
    vision_tower = timm.create_model(
        "vit_large_patch16_224",
        img_size=448,
        patch_size=16,
        init_values=1.0,
        num_classes=0,
        dynamic_img_size=True,
    )
    # original_forward = vision_tower.forward
    # vision_tower.forward = lambda x: vision_tower.forward_features(x)
    state_dict = torch.load(ckpt_path, map_location="cpu")
    missing_keys, unexpected_keys = vision_tower.load_state_dict(
        state_dict, strict=False
    )
    print("missing_keys, unexpected_keys: ", missing_keys, unexpected_keys)
    print("ConchV1.5 parameters: ", sum(p.numel() for p in vision_tower.parameters()))
    vision_tower.eval()
    return vision_tower


if __name__ == "__main__":
    model = get_encoder_conchv1_5()
    transform = get_eval_transforms_conchv1_5()
    print(model)
    print(transform)
    # Dummy 3 x 256 x 512 image to sanity-check the transform and forward pass.
    input = torch.rand(3, 256, 512)
    print(input.shape)
    input = transform(to_pil_image(input))
    print(input.shape)
    output = model.forward_features(input.unsqueeze(0))  # token-level features
    print(output.shape)
You can give this a try.
We recommend using the Trident library for model loading; see https://github.com/mahmoodlab/TRIDENT/blob/main/trident/patch_encoder_models/load.py#L26.
from trident.patch_encoder_models import encoder_factory
model = encoder_factory(model_name='conch_v15')
This logic applies to all our models (UNI, UNI2, CONCH, TITAN) as well as other models from the community.
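For completeness, here is a minimal sketch of running a forward pass with the factory-loaded encoder. It assumes the returned object behaves as a standard torch.nn.Module whose forward() produces the patch embedding, and rebuilds the same eval transform shown earlier in this thread; the patch filename is a placeholder.

import torch
from PIL import Image
from torchvision import transforms
from trident.patch_encoder_models import encoder_factory

# Load the CONCH v1.5 patch encoder through Trident's factory.
model = encoder_factory(model_name='conch_v15')
model.eval()

# Same eval transform as in the snippet above (448 px, ImageNet normalization).
transform = transforms.Compose([
    transforms.Resize(448, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(448),
    transforms.Lambda(lambda im: im.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

img = Image.open("example_patch.png")  # hypothetical patch image
x = transform(img).unsqueeze(0)
with torch.no_grad():
    feats = model(x)  # assumption: forward() returns the patch-level embedding
print(feats.shape)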
Hope this helps,
Best, Guillaume
Wow, thank you very much! This code repository is incredibly helpful for my research!!
Thank you! I was able to successfully load the model thanks to your help.
I have one more question: is the text model compatible with the original CONCH?
I suspect it might be incompatible, since the two versions were trained differently.
Hi everyone,
Thank you for the wonderful discussion and thanks to the Mahmood group for all the amazing models. I am trying to use CONCHv1.5 to extract features from my patch images (.png) of size 512.
I find that TRIDENT supports input as patch coordinates. Is there a way to extract features from patch images instead of coordinates using CONCHv1.5?
Thanks!
If you want a quick fix, Trident can handle PIL-readable images off the shelf, so all you need to do is run the normal Trident pipeline on all images ("pretending" they are WSIs). You just need to provide the MPP (likely 0.5?) in a CSV file, since it is not embedded in the PNGs, e.g.:
python run_batch_of_slides.py --task all --wsi_dir ./pngs --job_dir ./png_processed --patch_encoder conch_v15 --patch_size 512 --mag 20 --custom_list_of_wsis my_pngs_to_encode.csv
where my_pngs_to_encode.csv is a CSV file that looks like:
wsi,mpp
1.png,0.5
2.png,0.5
3.png,0.5
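For convenience, a small sketch for generating this CSV from a folder of PNGs (assuming a single MPP of 0.5 for all patches; adjust the value and folder name as needed):

import csv
import os

png_dir = "./pngs"  # folder passed to --wsi_dir above
with open("my_pngs_to_encode.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["wsi", "mpp"])
    for name in sorted(os.listdir(png_dir)):
        if name.endswith(".png"):
            writer.writerow([name, 0.5])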
Hope this helps.
This is incredibly helpful—I'll give it a try. Thanks so much for your prompt response!
@guillaumejaume
Thanks again for the great tip—it’s a helpful workaround for handling images instead of coordinates.
However, when --task is set to all, it attempts to perform segmentation within the patches, which I don't need, as segmentation and patch extraction were already completed at the slide level beforehand.
My goal is to extract features from the patch images (.png) using the CONCHv1.5 model.
I also tried using the model.encode_image(image) function (available in the original CONCH model), but it seems it is not compatible with CONCH v1.5.
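For context, roughly what I'm trying to do is the following (a minimal sketch reusing the timm-based loader and transform from the first reply in this thread; the patch folder is a placeholder):

import os
import torch
from PIL import Image

# Reuse get_encoder_conchv1_5() and get_eval_transforms_conchv1_5()
# from the snippet earlier in this thread.
model = get_encoder_conchv1_5()
transform = get_eval_transforms_conchv1_5()

patch_dir = "./patches"  # hypothetical folder of 512x512 .png patches
features = {}
with torch.no_grad():
    for name in sorted(os.listdir(patch_dir)):
        if not name.endswith(".png"):
            continue
        img = Image.open(os.path.join(patch_dir, name))
        x = transform(img).unsqueeze(0)
        features[name] = model.forward_features(x)  # token-level features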
I’d really appreciate any suggestions you might have for addressing this.
Thanks in advance!
Thank you for your kind response. I understand that CONCH v1.5 was fine-tuned in the CoCa style based on UNI weights. Are the weights for the text encoder in CONCH v1.5 not publicly available?
Thanks in advance!
Hi @credit1
You are correct. Currently, we are not planning to release the text encoder for CONCH v1.5.
If there are any further updates on this matter, we will make sure to announce them through HF.
Thanks!