File size: 3,622 Bytes
f056744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn.functional as F
from diffusers.callbacks import PipelineCallback
from scipy.ndimage import binary_dilation
from skimage.filters import threshold_otsu

from ..attn_utils.mask_utils import get_mask


class CallbackLatentStore(PipelineCallback):
    """Pipeline callback that records the ``latents`` tensor seen at each step.

    Every invocation appends ``callback_kwargs['latents']`` to ``self.latents``
    and hands ``callback_kwargs`` back to the pipeline unchanged.

    NOTE(review): the tensor object itself is stored, not a copy, so any later
    in-place write to that same tensor (e.g. by another callback sharing it)
    would also be visible in the recorded entry.
    """

    tensor_inputs = ['latents']

    def __init__(self):
        # One entry per call to ``callback_fn``, in call order.
        self.latents = list()

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        """Remember this step's latents and return ``callback_kwargs`` as-is."""
        history = self.latents
        history.append(callback_kwargs['latents'])
        return callback_kwargs

class CallbackAll(PipelineCallback):
    """Per-step callback that injects and mask-blends stored reference latents.

    On every denoising step (``cur_step = step_index + step_start``):

    1. While ``cur_step < mid_step_index`` and reference ``latents`` are given,
       batch element 0 of ``callback_kwargs['latents']`` is overwritten in place
       with ``latents[mid_step_index]``.
    2. If ``use_mask`` is set, a mask is obtained, either from cross-attention
       via ``get_mask`` or from the fixed ``mask``. While
       ``cur_step < mask_steps``, batch elements ``1:`` are then blended with
       the reference latent ``latents[cur_step + 1]``. Where the mask is 1 the
       current latent is kept; where it is 0 the reference latent is used.

    NOTE(review): treating batch element 0 as the "source" branch and elements
    ``1:`` as the "target" branch is inferred from the slicing and the
    ``source_*``/``target_*`` names. Confirm against the calling pipeline.
    """

    tensor_inputs = ['latents']

    def __init__(
        self,
        latents,
        attn_collector,
        feature_collector,
        feature_inject_steps,
        mid_step_index=0,
        step_start=0,
        use_mask=False,
        use_ca_mask=False,
        source_ca_index=None,
        target_ca_index=None,
        mask_steps=18,
        mask_kwargs=None,
        mask=None,
        height=1024,
        width=1024,
        target_mask_start_step=5,
    ):
        """
        Args:
            latents: Indexable collection of reference latents, indexed by
                absolute step, or ``None`` to disable injection. Presumably the
                trajectory recorded by ``CallbackLatentStore``; confirm with
                the caller.
            attn_collector: Object exposing ``controller.source_ca`` and
                ``controller.target_ca``. Only read when ``use_ca_mask`` is set.
            feature_collector: Stored on the instance; not read by this class.
            feature_inject_steps: Stored on the instance; not read by this class.
            mid_step_index: Steps with ``cur_step < mid_step_index`` get
                ``latents[mid_step_index]`` written into batch element 0.
            step_start: Offset added to the pipeline's ``step_index``.
            use_mask: Enable mask blending of batch elements ``1:``.
            use_ca_mask: Recompute the mask every step from attention maps via
                ``get_mask`` instead of using the fixed ``mask``.
            source_ca_index: Index forwarded to ``get_mask`` together with the
                source attention maps. Takes precedence over
                ``target_ca_index``.
            target_ca_index: Index forwarded to ``get_mask`` together with the
                target attention maps.
            mask_steps: Blending is applied only while ``cur_step < mask_steps``.
            mask_kwargs: Extra keyword arguments forwarded to ``get_mask``.
                ``None`` (the default) means no extra arguments.
            mask: Fixed mask. When ``use_ca_mask`` is set, it is overwritten
                with the most recently computed attention mask.
            height, width: Pixel resolution passed to the pipeline's
                ``_unpack_latents``. The 1024 default keeps the previously
                hard-coded value.
            target_mask_start_step: With a target-attention mask, mask
                computation and blending are skipped for
                ``cur_step < target_mask_start_step``. This was previously
                hard-coded to 5.
        """
        self.latents = latents

        self.attn_collector = attn_collector
        self.feature_collector = feature_collector
        self.feature_inject_steps = feature_inject_steps

        self.mid_step_index = mid_step_index
        self.step_start = step_start

        self.mask = mask
        self.mask_steps = mask_steps

        self.use_mask = use_mask
        self.use_ca_mask = use_ca_mask
        self.source_ca_index = source_ca_index
        self.target_ca_index = target_ca_index
        # Avoid a shared mutable default; the caller's dict is kept as-is.
        self.mask_kwargs = {} if mask_kwargs is None else mask_kwargs

        self.height = height
        self.width = width
        self.target_mask_start_step = target_mask_start_step

    def latent_blend(self, s, t, mask):
        """Return ``s * (1 - mask) + t * mask``.

        The result equals ``t`` where ``mask`` is 1 and ``s`` where it is 0.
        """
        return s * (1 - mask) + t * mask

    def _resolve_mask(self, cur_step):
        """Return the mask for this step, or ``None`` to skip blending.

        When ``use_ca_mask`` is set and an attention index is configured, the
        mask is recomputed and cached on ``self.mask``. Otherwise, including
        the previously crashing case where ``use_ca_mask`` is set with no
        index, this falls back to the stored ``self.mask``.
        """
        if self.use_ca_mask:
            if self.source_ca_index is not None:
                source_ca = self.attn_collector.controller.source_ca
                self.mask = get_mask(source_ca, self.source_ca_index, **self.mask_kwargs)
            elif self.target_ca_index is not None:
                # Target attention is not yet used during the earliest steps.
                if cur_step < self.target_mask_start_step:
                    return None
                target_ca = self.attn_collector.controller.target_ca
                self.mask = get_mask(target_ca, self.target_ca_index, **self.mask_kwargs)
        return self.mask

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        """Apply latent injection and mask blending for one denoising step.

        Mutates ``callback_kwargs['latents']`` in place and returns
        ``callback_kwargs``.

        Raises:
            ValueError: If blending is due on this step but no reference
                ``latents`` were provided.
        """
        cur_step = step_index + self.step_start
        latents = callback_kwargs['latents']

        if self.latents is not None and cur_step < self.mid_step_index:
            latents[:1] = self.latents[self.mid_step_index]

        if not self.use_mask:
            return callback_kwargs

        # Resolve (and cache) the mask before the step cut-off so that
        # self.mask keeps tracking the latest attention mask on every step.
        mask = self._resolve_mask(cur_step)
        if mask is None or cur_step >= self.mask_steps:
            return callback_kwargs

        if self.latents is None:
            raise ValueError(
                "CallbackAll: mask blending requires reference `latents`, "
                "but `latents` is None."
            )

        target_latent = latents[1:]
        blend_latent = self.latents[cur_step + 1]
        mask = mask.to(device=target_latent.device, dtype=pipeline.dtype)

        # Blend in the pipeline's unpacked (spatial) layout, then re-pack.
        # Assumes the pipeline provides Flux-style _pack/_unpack_latents.
        new_latent = self.latent_blend(
            pipeline._unpack_latents(blend_latent, self.height, self.width, pipeline.vae_scale_factor),
            pipeline._unpack_latents(target_latent, self.height, self.width, pipeline.vae_scale_factor),
            mask,
        )
        latents[1:] = pipeline._pack_latents(new_latent, *new_latent.shape)
        return callback_kwargs