Spaces:
No application file
No application file
import argparse | |
import os | |
import torch | |
def save_only_ema_weights(checkpoint_file): | |
"""Extract and save only the EMA weights.""" | |
checkpoint = torch.load(checkpoint_file, map_location="cpu") | |
weights = {} | |
if "ema" in checkpoint: | |
weights["model"] = checkpoint["ema"]["module"] | |
else: | |
raise ValueError("The checkpoint does not contain 'ema'.") | |
dir_name, base_name = os.path.split(checkpoint_file) | |
name, ext = os.path.splitext(base_name) | |
output_file = os.path.join(dir_name, f"{name}_converted{ext}") | |
torch.save(weights, output_file) | |
print(f"EMA weights saved to {output_file}") | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="Extract and save only EMA weights.") | |
parser.add_argument("checkpoint_file", type=str, help="Path to the input checkpoint file.") | |
args = parser.parse_args() | |
save_only_ema_weights(args.checkpoint_file) | |