File size: 12,057 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import random
from typing import Any, Dict, Optional

import torch
import torch.distributed.checkpoint.stateful
from diffusers.video_processor import VideoProcessor

import finetrainers.functional as FF
from finetrainers.logging import get_logger
from finetrainers.processors import CannyProcessor, CopyProcessor

from .config import ControlType, FrameConditioningType


logger = get_logger()


class IterableControlDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    def __init__(
        self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
    ):
        super().__init__()

        self.dataset = dataset
        self.control_type = control_type

        self.control_processors = []
        if control_type == ControlType.CANNY:
            self.control_processors.append(
                CannyProcessor(
                    output_names=["control_output"], input_names={"image": "input", "video": "input"}, device=device
                )
            )
        elif control_type == ControlType.NONE:
            self.control_processors.append(
                CopyProcessor(output_names=["control_output"], input_names={"image": "input", "video": "input"})
            )

        logger.info("Initialized IterableControlDataset")

    def __iter__(self):
        logger.info("Starting IterableControlDataset")
        for data in iter(self.dataset):
            control_augmented_data = self._run_control_processors(data)
            yield control_augmented_data

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)

    def state_dict(self):
        return self.dataset.state_dict()

    def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
        if "control_image" in data:
            if "image" in data:
                data["control_image"] = FF.resize_to_nearest_bucket_image(
                    data["control_image"], [data["image"].shape[-2:]], resize_mode="bicubic"
                )
            if "video" in data:
                batch_size, num_frames, num_channels, height, width = data["video"].shape
                data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
                    data["control_video"], [[num_frames, height, width]], resize_mode="bicubic"
                )
                if _first_frame_only:
                    msg = (
                        "The number of frames in the control video is less than the minimum bucket size "
                        "specified. The first frame is being used as a single frame video. This "
                        "message is logged at the first occurence and for every 128th occurence "
                        "after that."
                    )
                    logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128)
                    data["control_video"] = data["control_video"][0]
            return data

        if "control_video" in data:
            if "image" in data:
                data["control_image"] = FF.resize_to_nearest_bucket_image(
                    data["control_video"][0], [data["image"].shape[-2:]], resize_mode="bicubic"
                )
            if "video" in data:
                batch_size, num_frames, num_channels, height, width = data["video"].shape
                data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
                    data["control_video"], [[num_frames, height, width]], resize_mode="bicubic"
                )
                if _first_frame_only:
                    msg = (
                        "The number of frames in the control video is less than the minimum bucket size "
                        "specified. The first frame is being used as a single frame video. This "
                        "message is logged at the first occurence and for every 128th occurence "
                        "after that."
                    )
                    logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128)
                    data["control_video"] = data["control_video"][0]
            return data

        if self.control_type == ControlType.CUSTOM:
            return data

        shallow_copy_data = dict(data.items())
        is_image_control = "image" in shallow_copy_data
        is_video_control = "video" in shallow_copy_data
        if (is_image_control + is_video_control) != 1:
            raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")
        for processor in self.control_processors:
            result = processor(**shallow_copy_data)
            result_keys = set(result.keys())
            repeat_keys = result_keys.intersection(shallow_copy_data.keys())
            if repeat_keys:
                logger.warning(
                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
                    f"be intended. Please rename the keys in the processor to avoid conflicts."
                )
            shallow_copy_data.update(result)
        if "control_output" in shallow_copy_data:
            # Normalize to [-1, 1] range
            control_output = shallow_copy_data.pop("control_output")
            # TODO(aryan): need to specify a dim for normalize here across channels
            control_output = FF.normalize(control_output, min=-1.0, max=1.0)
            key = "control_image" if is_image_control else "control_video"
            shallow_copy_data[key] = control_output
        return shallow_copy_data


class ValidationControlDataset(torch.utils.data.IterableDataset):
    """Iterable dataset wrapper that derives PIL control conditions for validation samples.

    Samples that already contain ``control_image``/``control_video`` are passed through unchanged,
    as are all samples when ``control_type`` is ``ControlType.CUSTOM``. Otherwise the configured
    processors derive a control output from the sample's ``image`` or ``video``. A tensor result is
    converted to PIL through diffusers' ``VideoProcessor``.

    Checkpointing (``state_dict`` / ``load_state_dict``) is delegated to the wrapped dataset.
    """

    def __init__(
        self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
    ):
        super().__init__()

        self.dataset = dataset
        self.control_type = control_type
        self.device = device
        self._video_processor = VideoProcessor()

        # Each processor reads the sample's "image" or "video" as "input" and emits "control_output".
        self.control_processors = []
        if control_type == ControlType.CANNY:
            self.control_processors.append(
                CannyProcessor(
                    output_names=["control_output"], input_names={"image": "input", "video": "input"}, device=device
                )
            )
        elif control_type == ControlType.NONE:
            self.control_processors.append(
                CopyProcessor(output_names=["control_output"], input_names={"image": "input", "video": "input"})
            )

        logger.info("Initialized ValidationControlDataset")

    def __iter__(self):
        logger.info("Starting ValidationControlDataset")
        for data in iter(self.dataset):
            yield self._run_control_processors(data)

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)

    def state_dict(self):
        return self.dataset.state_dict()

    def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the sample with a derived ``control_image``/``control_video``, or unchanged.

        Raises:
            ValueError: if the control must be derived and the sample does not hold exactly one of
                ``image``/``video``, or if a tensor control output has an unsupported rank.
        """
        # These are already expected to be tensors
        if "control_image" in data or "control_video" in data:
            return data
        if self.control_type == ControlType.CUSTOM:
            return data

        shallow_copy_data = dict(data.items())
        is_image_control = "image" in shallow_copy_data
        is_video_control = "video" in shallow_copy_data
        if (is_image_control + is_video_control) != 1:
            raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")

        for processor in self.control_processors:
            result = processor(**shallow_copy_data)
            repeat_keys = set(result.keys()).intersection(shallow_copy_data.keys())
            if repeat_keys:
                logger.warning(
                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
                    f"be intended. Please rename the keys in the processor to avoid conflicts."
                )
            shallow_copy_data.update(result)

        if "control_output" in shallow_copy_data:
            control_output = shallow_copy_data.pop("control_output")
            # Non-tensor outputs are stored as-is.
            if torch.is_tensor(control_output):
                control_output = self._control_output_to_pil(control_output)
            key = "control_image" if is_image_control else "control_video"
            shallow_copy_data[key] = control_output
        return shallow_copy_data

    def _control_output_to_pil(self, control_output: torch.Tensor) -> Any:
        """Normalize a 3-D to 5-D control tensor to [-1, 1] and convert it to PIL output.

        Conversion by rank:
            - 5-D tensors go through ``postprocess_video``.
            - 4-D tensors go through ``postprocess``, and the first PIL image is kept.
            - 3-D tensors get a leading batch axis first, then follow the 4-D path.

        Raises:
            ValueError: if ``control_output`` is not 3-, 4- or 5-dimensional.
        """
        # Explicit check (not `assert`) so validation survives `python -O`.
        if not 3 <= control_output.ndim <= 5:
            raise ValueError("Control output should be at least ndim=3 and less than or equal to ndim=5")
        # Normalize to [-1, 1] range
        # TODO(aryan): need to specify a dim for normalize here across channels
        control_output = FF.normalize(control_output, min=-1.0, max=1.0)
        ndim = control_output.ndim
        if ndim == 5:
            # NOTE(review): confirm the 5-D layout of control_output matches what
            # VideoProcessor.postprocess_video expects.
            return self._video_processor.postprocess_video(control_output, output_type="pil")
        if ndim == 3:
            control_output = control_output.unsqueeze(0)
        return self._video_processor.postprocess(control_output, output_type="pil")[0]


# TODO(aryan): write a test for this function
def apply_frame_conditioning_on_latents(
    latents: torch.Tensor,
    expected_num_frames: int,
    channel_dim: int,
    frame_dim: int,
    frame_conditioning_type: FrameConditioningType,
    frame_conditioning_index: Optional[int] = None,
    concatenate_mask: bool = False,
) -> torch.Tensor:
    """Mask latent frames by a conditioning strategy and fit them to ``expected_num_frames``.

    A 0/1 mask with the same shape as ``latents`` is built along ``frame_dim``. Latents outside
    the kept frames are zeroed.
      - ``INDEX``: keep frame ``frame_conditioning_index``, clamped to the last frame.
      - ``PREFIX``: keep a random prefix of 1..num_frames frames.
      - ``RANDOM``: keep a random subset of 1..num_frames frames.
      - ``FIRST_AND_LAST``: keep the first and last frames.
      - ``FULL``: keep every frame. The mask is all ones and latents are not multiplied.
    Any other value leaves ``latents`` unchanged and the mask all zeros.

    Latents and mask are then truncated, or zero-padded at the end of ``frame_dim``, to exactly
    ``expected_num_frames`` frames.

    Args:
        latents: Latent tensor; frames live on ``frame_dim`` and channels on ``channel_dim``.
        expected_num_frames: Frame count of the returned tensor along ``frame_dim``.
        channel_dim: Dimension along which the mask is concatenated when ``concatenate_mask``.
        frame_dim: Dimension indexing frames.
        frame_conditioning_type: Strategy selecting which frames to keep.
        frame_conditioning_index: Frame to keep for ``INDEX``; required for that strategy.
        concatenate_mask: If True, append the mask along ``channel_dim``. This doubles that
            dimension's size.

    Returns:
        The masked, resized latents, optionally with the mask concatenated along ``channel_dim``.

    Raises:
        ValueError: if ``frame_conditioning_type`` is ``INDEX`` and ``frame_conditioning_index`` is None.
    """
    num_frames = latents.size(frame_dim)
    ndim = latents.ndim
    mask = torch.zeros_like(latents)

    def _along_frames(index: Any) -> tuple:
        # Index tuple that applies `index` on frame_dim and selects everything elsewhere.
        indexing = [slice(None)] * ndim
        indexing[frame_dim] = index
        return tuple(indexing)

    if frame_conditioning_type == FrameConditioningType.INDEX:
        if frame_conditioning_index is None:
            raise ValueError(
                "`frame_conditioning_index` must be provided when using `FrameConditioningType.INDEX`."
            )
        mask[_along_frames(min(frame_conditioning_index, num_frames - 1))] = 1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.PREFIX:
        frame_index = random.randint(1, num_frames)
        mask[_along_frames(slice(0, frame_index))] = 1  # Keep frames 0 to frame_index-1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.RANDOM:
        # One or more random frames to keep (randint's lower bound is 1)
        num_frames_to_keep = random.randint(1, num_frames)
        frame_indices = random.sample(range(num_frames), num_frames_to_keep)
        mask[_along_frames(frame_indices)] = 1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.FIRST_AND_LAST:
        mask[_along_frames(0)] = 1
        mask[_along_frames(num_frames - 1)] = 1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.FULL:
        mask[_along_frames(slice(0, num_frames))] = 1

    # Masking never changes the shape, so `num_frames` is still the current frame count.
    if num_frames >= expected_num_frames:
        keep = _along_frames(slice(expected_num_frames))
        latents = latents[keep]
        mask = mask[keep]
    else:
        pad_shape = list(latents.shape)
        pad_shape[frame_dim] = expected_num_frames - num_frames
        padding = latents.new_zeros(pad_shape)
        latents = torch.cat([latents, padding], dim=frame_dim)
        mask = torch.cat([mask, padding], dim=frame_dim)

    if concatenate_mask:
        # NOTE(review): the full mask (same channel count as latents) is concatenated. An earlier
        # version built an unused channel-0 slice here, hinting a single-channel mask may have been
        # intended. Confirm against the model's expected input channels before changing.
        latents = torch.cat([latents, mask], dim=channel_dim)

    return latents