|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utility to convert a Cosmos-Embed1 ``.pt`` state-dict checkpoint into sharded safetensors weights."""
|
|
|
import argparse |
|
|
|
import torch |
|
|
|
from .configuration_embed1 import CosmosEmbed1Config |
|
from .modeling_embed1 import CosmosEmbed1 |
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments for the weight conversion.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the script's CLI behavior; passing
            an explicit list lets the parser be driven from code or tests.

    Returns:
        argparse.Namespace with string attributes ``input_weights`` and
        ``output_weights``.

    Raises:
        SystemExit: if a required argument is missing or unknown (argparse
            prints usage and exits with status 2).
    """
    parser = argparse.ArgumentParser(
        # Both the safetensors conversion and the sharding always happen; the
        # previous description called them "optional", which no flag supports.
        description="Convert a .pt state-dict checkpoint into sharded safetensors weights."
    )
    parser.add_argument("--input_weights", type=str, required=True, help="Path to the input .pt weights file")
    parser.add_argument(
        "--output_weights",
        type=str,
        required=True,
        help="Path to the output directory where safetensors weights will be saved",
    )
    return parser.parse_args(argv)
|
|
|
|
|
def main():
    """Convert a ``.pt`` state-dict checkpoint into sharded safetensors files.

    Builds a ``CosmosEmbed1`` model from a default ``CosmosEmbed1Config`` in
    bfloat16, and loads the checkpoint named by ``--input_weights`` into it with
    ``strict=True``, so missing or unexpected keys raise. It then writes the weights
    with ``save_pretrained`` to the ``--output_weights`` directory as
    safetensors shards of at most 500MB each.
    """
    args = parse_args()

    # Prefer the GPU (the original behavior), but fall back to CPU so the
    # conversion also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CosmosEmbed1(CosmosEmbed1Config()).to(device, dtype=torch.bfloat16)

    # Give each of these parameters its own storage via clone(). Presumably
    # they are tied (shared tensors) inside the model, which safetensors does
    # not serialize as-is -- TODO confirm against CosmosEmbed1. The strict
    # load below fills each of them from the checkpoint independently.
    predictions = model.qformer.cls.predictions
    embeddings = model.qformer.bert.embeddings
    predictions.decoder.weight = torch.nn.Parameter(predictions.decoder.weight.clone())
    embeddings.word_embeddings.weight = torch.nn.Parameter(embeddings.word_embeddings.weight.clone())
    predictions.decoder.bias = torch.nn.Parameter(predictions.decoder.bias.clone())
    predictions.bias = torch.nn.Parameter(predictions.bias.clone())

    # map_location="cpu" materializes the checkpoint on the host no matter
    # which device it was saved from. That device may not exist here, and
    # loading onto the GPU would otherwise hold a second copy of the weights
    # there. load_state_dict copies the values into the model's parameters.
    # NOTE(review): torch.load unpickles the file -- only convert trusted
    # checkpoints.
    state_dict = torch.load(args.input_weights, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    # The values now live in the model; release the host copy before saving.
    del state_dict

    model.save_pretrained(
        args.output_weights,
        safe_serialization=True,
        max_shard_size="500MB",
    )
|
|
|
|
|
if __name__ == "__main__":
    # Entry point. The relative imports at the top of this file mean it must
    # be launched as a module of its package (``python -m ...``), not as a
    # standalone script path.
    main()
|
|