import os

import requests
import torch
import torch.nn.functional as F
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# Release any cached GPU memory before loading large checkpoints (no-op if CUDA is not initialized).
torch.cuda.empty_cache()

def download_file(url, dest_path):
    # Stream the download to disk in chunks so large checkpoints never have to fit in memory.
    print(f"Downloading {url} to {dest_path}")
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        raise Exception(f"Failed to download file from {url} (HTTP {response.status_code})")
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)

def load_model(file_path):
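    # Thin wrapper around safetensors.load_file: returns the checkpoint as a flat dict of tensors.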
    return load_file(file_path)

def save_model(merged_model, output_file):
    print(f"Saving merged model to {output_file}")
    save_file(merged_model, output_file)

def resize_tensor_shapes(tensor1, tensor2):
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2

    # Zero-pad both tensors up to their element-wise maximum shape so they can be blended.
    # F.pad expects padding amounts ordered from the last dimension backwards.
    max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]

    def pad_to(tensor):
        pad = []
        for current, target in zip(reversed(tensor.shape), reversed(max_shape)):
            pad.extend([0, target - current])
        return F.pad(tensor, pad)

    return pad_to(tensor1), pad_to(tensor2)

def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.75):
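    # Weighted average: blend_ratio * ckpt1 + (1 - blend_ratio) * ckpt2 for keys present in both;
    # keys unique to either checkpoint are copied through unchanged.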
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))

    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is not None and t2 is not None:
            t1, t2 = resize_tensor_shapes(t1, t2)
            merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2
        elif t1 is not None:
            merged[key] = t1
        else:
            merged[key] = t2

    # Control the final size to be approximately 26 GB
    control_output_size(merged, target_size_gb=26)

    return merged

def control_output_size(merged, target_size_gb):
    # Estimate the merged state dict's size in bytes and compare it to the target.
    target_size_bytes = target_size_gb * 1024**3  # Convert GB to bytes
    current_size_bytes = sum(tensor.numel() * tensor.element_size() for tensor in merged.values())

    if current_size_bytes <= target_size_bytes:
        return

    excess_size = current_size_bytes - target_size_bytes
    print(f"Current size exceeds target by {excess_size / (1024**2):.2f} MB. Adjusting...")

    # Trim columns from the last dimension of tensors until the total size fits.
    # Note: dropping weight columns is lossy and can degrade the merged model.
    for key, tensor in merged.items():
        if excess_size <= 0:
            break
        if tensor.dim() == 0 or tensor.size(-1) <= 1:
            continue
        bytes_per_column = (tensor.numel() // tensor.size(-1)) * tensor.element_size()
        cols_to_drop = min(tensor.size(-1) - 1, -(-excess_size // bytes_per_column))
        merged[key] = tensor[..., : tensor.size(-1) - cols_to_drop].contiguous()
        excess_size -= cols_to_drop * bytes_per_column

def cleanup_files(*file_paths):
    for file_path in file_paths:
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"Deleted {file_path}")

if __name__ == "__main__":
    try:
        model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
        model2_path = "output_checkpoint.safetensors"
        blend_ratio = 0.75  # Adjust ratio based on requirement
        output_file = "output_checkpoints.safetensors"

        # Loading models
        model1 = load_model(model1_path)
        model2 = load_model(model2_path)

        # Merging models
        merged_model = merge_checkpoints(model1, model2, blend_ratio)

        # Saving merged model
        save_model(merged_model, output_file)

        # Cleaning up the downloaded base model file to free disk space
        cleanup_files(model1_path)
        
    except Exception as e:
        print(f"An error occurred: {e}")