## Base Configurations

In [1]:
import os
import torch
from transformers import SegformerForSemanticSegmentation
from dataclasses import dataclass


@dataclass
class Configs:
    """Static settings shared by the cells of this notebook.

    Both attributes are annotated so that ``@dataclass`` treats each of them as a
    real field (an un-annotated class attribute is ignored by the decorator).
    They remain readable directly on the class, e.g. ``Configs.NUM_CLASSES``.
    """

    # Checkpoint identifier / path handed to `from_pretrained`.
    # Declared first so a positional `Configs(<path>)` keeps setting MODEL_PATH.
    MODEL_PATH: str = "nvidia/segformer-b4-finetuned-ade-512-512"
    # Number of labels the segmentation model is instantiated with.
    NUM_CLASSES: int = 4

## Load Model To Inspect Parameter Names

In [2]:


def get_model(*, model_path, num_classes):
    """Instantiate a SegFormer semantic-segmentation model from a checkpoint.

    Args:
        model_path: Passed as the first argument to ``from_pretrained``.
        num_classes: Forwarded as ``num_labels``.

    Returns:
        Whatever ``SegformerForSemanticSegmentation.from_pretrained`` returns.
    """
    load_options = {
        "num_labels": num_classes,
        # Tolerate checkpoint tensors whose shapes differ from the requested model.
        "ignore_mismatched_sizes": True,
    }
    return SegformerForSemanticSegmentation.from_pretrained(model_path, **load_options)

In [3]:
# Build the model from the configured checkpoint and grab its parameter mapping.
model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
model_state_dict = model.state_dict()

# Emit an empty line, then list only the first six parameter names.
print()
for param_name in list(model_state_dict)[:6]:
    print(param_name)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



segformer.encoder.patch_embeddings.0.proj.weight
segformer.encoder.patch_embeddings.0.proj.bias
segformer.encoder.patch_embeddings.0.layer_norm.weight
segformer.encoder.patch_embeddings.0.layer_norm.bias
segformer.encoder.patch_embeddings.1.proj.weight
segformer.encoder.patch_embeddings.1.proj.bias


## Download & Load PyTorch Lightning Checkpoint and Inspect Parameter Names

In [4]:
import wandb

# Start a W&B run only to fetch the trained-model artifact to local disk.
run = wandb.init()
try:
    artifact = run.use_artifact("veb-101/UM_medical_segmentation/model-fpgquxev:v0", type="model")
    # `download()` returns the local directory that now holds the artifact files.
    artifact_dir = artifact.download()
finally:
    # Close the run even if the download fails, so it is not left open.
    run.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mveb-101[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model-fpgquxev:v0, 977.89MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:1:5.3


In [5]:
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt_path = os.path.join(artifact_dir, "model.ckpt")
CKPT = torch.load(ckpt_path, map_location="cpu")  # keep tensors on the CPU
# Show which top-level entries the checkpoint contains.
print(CKPT.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecisionPlugin', 'hparams_name', 'hyper_parameters'])


In [6]:
# The model weights live under the checkpoint's "state_dict" entry.
TRAINED_CKPT_state_dict = CKPT["state_dict"]

# Print at most the first six parameter names for comparison with the model's own keys.
for _, param_name in zip(range(6), TRAINED_CKPT_state_dict):
    print(param_name)

model.segformer.encoder.patch_embeddings.0.proj.weight
model.segformer.encoder.patch_embeddings.0.proj.bias
model.segformer.encoder.patch_embeddings.0.layer_norm.weight
model.segformer.encoder.patch_embeddings.0.layer_norm.bias
model.segformer.encoder.patch_embeddings.1.proj.weight
model.segformer.encoder.patch_embeddings.1.proj.bias


**The keys in the PyTorch Lightning checkpoint's `state_dict` carry an extra `model.` prefix. It is the name of the attribute that held the model inside the `LightningModule` class.**

We can simply iterate over the parameters and change the parameter key name. We'll create a new `OrderedDict` for it.

In [7]:
from collections import OrderedDict

# Prefix found in front of the parameter names stored in the checkpoint.
LIGHTNING_PREFIX = "model."

new_state_dict = OrderedDict()

for key_name, value in CKPT["state_dict"].items():
    # Remove only a *leading* "model."; str.replace would also rewrite any
    # later occurrence of that text inside a parameter name.
    if key_name.startswith(LIGHTNING_PREFIX):
        key_name = key_name[len(LIGHTNING_PREFIX):]
    new_state_dict[key_name] = value

In [8]:
# Copy the renamed checkpoint weights into the model; the call's result is shown as the cell output.
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [9]:
# torch.save(model.state_dict(), "Segformer_best_state_dict.ckpt")

In [10]:
# Save the model into the local "segformer_trained_weights" directory so it can be reloaded by path.
model.save_pretrained("segformer_trained_weights")

To load the saved model, pass the path of the "segformer_trained_weights" directory as `model_path` to `get_model`, which forwards it to `from_pretrained`.

In [None]:
# model = get_model(model_path=os.path.join(os.getcwd(), "segformer_trained_weights"), num_classes=Configs.NUM_CLASSES)