File size: 4,960 Bytes
8866644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

import gc
import json
from types import MethodType
import safetensors.torch
import torch
import torch.nn as nn
import safetensors


from torch import Tensor, nn
import copy


def Flux1PartialLoad_Patch(args=None):
    """Patch a Flux.1 diffusion model for partial GPU residency.

    Registers forward pre-hooks so that only a sliding window of
    transformer blocks lives on the GPU at any moment: before each group
    of ``double_blocks`` / ``single_blocks`` runs, the previous stage is
    evicted to the CPU and the next group is moved to CUDA.  This lets
    large models run with limited VRAM at the cost of transfer time.

    Args:
        args: dict with keys
            - "model": wrapper whose ``model.diffusion_model`` holds the
              Flux.1 network (``double_blocks``/``single_blocks`` stacks
              plus the input/embedding modules).
            - "double_blocks_cuda_size": how many double blocks to keep
              on the GPU at once (hook-group size).
            - "single_blocks_cuda_size": same for single blocks.

    Returns:
        A one-tuple ``(model,)`` containing the (mutated) input model.
    """
    # Avoid the mutable-default-argument pitfall; behavior is unchanged
    # for callers that pass a dict.
    args = {} if args is None else args

    model = args.get("model")
    double_blocks_cuda_size = args.get("double_blocks_cuda_size")
    single_blocks_cuda_size = args.get("single_blocks_cuda_size")

    dm = model.model.diffusion_model

    def _move_other(device):
        # Move the lightweight non-block modules (input projections and
        # embedders) as one group.
        for mod in (dm.img_in, dm.time_in, dm.guidance_in,
                    dm.vector_in, dm.txt_in, dm.pe_embedder):
            mod.to(device)
        if device == "cpu":
            # Release the freed VRAM immediately so the next group fits.
            torch.cuda.empty_cache()

    def _move_blocks(blocks, device, layer_start=0, layer_size=-1):
        # Move a contiguous slice of a block stack; layer_size == -1
        # means "the whole stack".  Slicing an nn.ModuleList yields a
        # view onto the same modules, so .to() moves the originals.
        if layer_size == -1:
            blocks.to(device)
        else:
            blocks[layer_start:layer_start + layer_size].to(device)
        if device == "cpu":
            torch.cuda.empty_cache()

    def _make_group_hook(blocks, layer_start, layer_size, evict):
        # Build the pre-forward hook for the first block of a group:
        # evict whatever the previous stage left on the GPU, drop the
        # groups of this stack already processed, then load this group.
        def _pre_hook(module, inp):
            evict()
            if layer_start > 0:
                _move_blocks(blocks, "cpu", layer_start=0,
                             layer_size=layer_start)
            _move_blocks(blocks, "cuda", layer_start=layer_start,
                         layer_size=layer_size)
            return inp
        return _pre_hook

    def _model_pre_hook(module, inp):
        # Start of a full forward pass: park every transformer block on
        # the CPU and bring the input/embedding modules onto the GPU.
        _move_blocks(dm.double_blocks, "cpu")
        _move_blocks(dm.single_blocks, "cpu")
        _move_other("cuda")
        return inp

    dm.register_forward_pre_hook(_model_pre_hook)

    def _register_group_hooks(blocks, group_size, evict):
        # Hook only the FIRST block of each group of `group_size`; the
        # hook loads the whole group, so the following blocks in the
        # group are already resident when they run.
        depth = len(blocks)
        for start in range(0, depth, group_size):
            size = min(group_size, depth - start)
            blocks[start].register_forward_pre_hook(
                _make_group_hook(blocks, start, size, evict))

    # Double-block groups evict the input/embedding modules (the stage
    # that ran just before them); single-block groups evict the entire
    # double-block stack.
    _register_group_hooks(dm.double_blocks, double_blocks_cuda_size,
                          lambda: _move_other("cpu"))
    _register_group_hooks(dm.single_blocks, single_blocks_cuda_size,
                          lambda: _move_blocks(dm.double_blocks, "cpu"))

    return (model,)