Upload 25 files
Browse files
- src/builder.py +4 -0
- src/datasets/collate_functions.py +90 -0
- src/datasets/samplers/multi_source_sampler.py +202 -0
- src/datasets/text2image/__init__.py +0 -0
- src/datasets/text2image/text2image.py +649 -0
- src/datasets/understanding/caption_datasets.py +342 -0
- src/datasets/understanding/caption_prompts.py +28 -0
- src/datasets/understanding/vlm_datasets_sig.py +168 -0
- src/datasets/utils.py +303 -0
- src/models/mar/decoder.py +102 -0
- src/models/mar/diffloss.py +249 -0
- src/models/mar/diffusion/__init__.py +47 -0
- src/models/mar/diffusion/diffusion_utils.py +73 -0
- src/models/mar/diffusion/gaussian_diffusion.py +884 -0
- src/models/mar/diffusion/respace.py +129 -0
- src/models/mar/engine_mar.py +99 -0
- src/models/mar/mar.py +477 -0
- src/models/mar/misc.py +340 -0
- src/models/mar/vae.py +525 -0
- src/models/skywork_unipic_dev.py +645 -0
- src/models/skywork_unipic_ori.py +350 -0
- src/models/skywork_unipic_siglip.py +342 -0
- src/optimisers/constructor.py +64 -0
- src/optimisers/custom_adamw.py +32 -0
- src/runners/custom_runner.py +202 -0
src/builder.py
ADDED
@@ -0,0 +1,4 @@
from mmengine.registry import Registry

__all__ = ['BUILDER']

BUILDER = Registry('builder')
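For context, a minimal sketch of how this registry gets consumed downstream (the `MyTokenizer` class is a hypothetical stand-in, not part of this upload): config dicts carry a `type` key, and `BUILDER.build(cfg)` instantiates the registered class with the remaining keys, which is how the datasets below turn their `tokenizer=...` config dicts into objects.

from src.builder import BUILDER

@BUILDER.register_module()
class MyTokenizer:  # hypothetical example class
    def __init__(self, vocab_size: int = 100):
        self.vocab_size = vocab_size

# 'type' selects the registered class; the remaining keys become kwargs
tok = BUILDER.build(dict(type='MyTokenizer', vocab_size=32))
assert tok.vocab_size == 32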
src/datasets/collate_functions.py
ADDED
@@ -0,0 +1,90 @@
import torch
from xtuner.utils import IGNORE_INDEX
from typing import Dict, Sequence
from torch.nn.utils.rnn import pad_sequence
from functools import partial
from dataclasses import dataclass


def collate_func_gen(instances: Sequence[Dict],
                     pad_index: int = 151645):
    pixel_values_src, pixel_values, input_ids, input_lengths = [], [], [], []

    for example in instances:
        # extract image data
        if 'pixel_values_src' in example:
            pixel_values_src.append(example.pop('pixel_values_src'))
        if 'pixel_values' in example:
            pixel_values.append(example.pop('pixel_values'))

        input_lengths.append(len(example['input_ids']))
        input_ids.append(example.pop('input_ids'))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_index)
    attention_mask = torch.zeros_like(input_ids).bool()
    for i in range(len(input_ids)):
        attention_mask[i, :input_lengths[i]] = True

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
    }

    if pixel_values:
        data_dict['pixel_values'] = torch.stack(pixel_values)
    if pixel_values_src:
        data_dict['pixel_values_src'] = torch.stack(pixel_values_src)

    return {'data': data_dict, 'data_samples': None}


def collate_func_und(instances, pad_index=151645):
    input_ids_list, labels_list, pixel_values_list = [], [], []

    for sample in instances:
        input_ids_list.append(torch.LongTensor(sample['input_ids']))
        labels_list.append(torch.LongTensor(sample['labels']))

        if 'pixel_values' in sample:
            pixel_values_list.append(sample['pixel_values'])

    ori_length = [len(input_ids_) for input_ids_ in input_ids_list]
    # right padding
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids_list, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels_list, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids_list)
        labels = torch.stack(labels_list)

    attention_mask = torch.zeros_like(input_ids).bool()
    for i, length in enumerate(ori_length):
        attention_mask[i, :length] = True  # right padding

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'pixel_values': torch.stack(pixel_values_list) if len(pixel_values_list) > 0 else None
    }

    return {'data': data_dict, 'data_samples': None}


class CollateConcat(object):
    def __init__(self, collate_fns, keys):
        self.keys = keys
        self.collate_fns = {}
        for key, collate_fn in zip(keys, collate_fns):
            func = collate_fn.pop('type')
            self.collate_fns[key] = partial(func, **collate_fn)

    def __call__(self, data_samples):
        data_samples = [data_sample for data_sample in data_samples if len(data_sample) > 0]
        data_dict = {}
        key = data_samples[0]['type']
        data_dict[key] = self.collate_fns[key](data_samples)['data']

        return {'data': data_dict, 'data_samples': None}
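A minimal usage sketch of collate_func_gen, assuming toy tensors of illustrative shapes:

import torch
from src.datasets.collate_functions import collate_func_gen

instances = [
    dict(input_ids=torch.tensor([1, 2, 3]), pixel_values=torch.zeros(3, 512, 512)),
    dict(input_ids=torch.tensor([4, 5]), pixel_values=torch.zeros(3, 512, 512)),
]
batch = collate_func_gen(instances)['data']
print(batch['input_ids'].shape)     # torch.Size([2, 3]); shorter rows right-padded with 151645
print(batch['attention_mask'][1])   # tensor([ True,  True, False])
print(batch['pixel_values'].shape)  # torch.Size([2, 3, 512, 512])

Note that CollateConcat dispatches on the 'type' key of the first sample only, so it assumes every sample in a batch comes from the same source; that holds exactly when it is paired with the per-source samplers in the next file.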
src/datasets/samplers/multi_source_sampler.py
ADDED
@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Iterator, List, Optional, Sized, Union
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler


class FixedBatchMultiSourceSampler(Sampler):
    r"""Multi-Source Infinite Sampler.

    According to the sampling ratio, sample data from different
    datasets to form batches.

    Args:
        repeat (tuple): repeat factor
        dataset (Sized): The dataset.
        batch_size (int): Size of mini-batch.
        shuffle (bool): Whether to shuffle the dataset or not. Defaults to True.
        seed (int, optional): Random seed. If None, set a random seed.
            Defaults to None.
    """

    def __init__(self,
                 repeat,
                 dataset: Sized,
                 batch_size: int,
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:

        assert hasattr(dataset, 'cumulative_sizes'), \
            f'The dataset must be a ConcatDataset, but got {dataset}'
        assert isinstance(batch_size, int) and batch_size > 0, \
            'batch_size must be a positive integer value, ' \
            f'but got batch_size={batch_size}'
        assert len(repeat) == len(dataset.cumulative_sizes), \
            'The length of repeat must be equal to ' \
            f'the number of datasets, but got repeat={repeat}'

        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.repeat = repeat
        self.cumulative_sizes = [0] + dataset.cumulative_sizes
        self.batch_size = batch_size

        self.seed = sync_random_seed() if seed is None else seed
        self.shuffle = shuffle
        self.source2inds = {
            source: self._indices_of_rank(len(ds))
            for source, ds in enumerate(dataset.datasets)
        }

    def _infinite_indices(self, sample_size: int) -> Iterator[int]:
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(sample_size, generator=g).tolist()
            else:
                yield from torch.arange(sample_size).tolist()

    def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
        """Slice the infinite indices by rank."""
        yield from itertools.islice(
            self._infinite_indices(sample_size), self.rank, None,
            self.world_size)

    def __len__(self) -> int:
        return len(self.dataset)

    def set_epoch(self, epoch: int) -> None:
        """Not supported in epoch-based runners."""
        pass

    def __iter__(self) -> Iterator[int]:
        while True:
            for source, repeat in enumerate(self.repeat):
                for _ in range(repeat):
                    batch_buffer_per_source = []
                    while len(batch_buffer_per_source) < self.batch_size:
                        idx = next(self.source2inds[source])
                        idx += self.cumulative_sizes[source]
                        batch_buffer_per_source.append(idx)

                    yield from batch_buffer_per_source


class MultiSourceSampler(Sampler):
    def __init__(self,
                 repeats,
                 dataset: Sized,
                 batch_sizes: list[int],
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:

        assert hasattr(dataset, 'cumulative_sizes'), \
            f'The dataset must be a ConcatDataset, but got {dataset}'

        assert isinstance(batch_sizes, list), \
            f'batch_sizes must be a list, but got batch_sizes={batch_sizes}'
        assert len(batch_sizes) == len(dataset.cumulative_sizes), \
            'The length of batch_sizes must be equal to ' \
            f'the number of datasets, but got batch_sizes={batch_sizes}'

        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.cumulative_sizes = [0] + dataset.cumulative_sizes
        self.batch_sizes = batch_sizes

        self.seed = sync_random_seed() if seed is None else seed
        self.shuffle = shuffle
        self.source2inds = {
            source: self._indices_of_rank(len(ds))
            for source, ds in enumerate(dataset.datasets)
        }

        self.repeats = repeats
        assert len(self.repeats) == len(self.batch_sizes)

    def _infinite_indices(self, sample_size: int) -> Iterator[int]:
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(sample_size, generator=g).tolist()
            else:
                yield from torch.arange(sample_size).tolist()

    def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
        """Slice the infinite indices by rank."""
        yield from itertools.islice(
            self._infinite_indices(sample_size), self.rank, None,
            self.world_size)

    def __len__(self) -> int:
        return len(self.dataset)

    def set_epoch(self, epoch: int) -> None:
        """Not supported in epoch-based runners."""
        pass

    def __iter__(self) -> Iterator[int]:
        while True:
            for source, (batch_size, repeat) in enumerate(zip(self.batch_sizes, self.repeats)):
                for _ in range(repeat):
                    batch_buffer_per_source = []
                    while len(batch_buffer_per_source) < batch_size:
                        idx = next(self.source2inds[source])
                        idx += self.cumulative_sizes[source]
                        batch_buffer_per_source.append(idx)

                    yield from batch_buffer_per_source

    @property
    def batch_size(self):
        batch_size_sum = sum(batch_size * repeat for batch_size, repeat in zip(self.batch_sizes, self.repeats))
        batch_size_ave = batch_size_sum // sum(self.repeats)

        return batch_size_ave


class MultiSourceBatchSampler(Sampler[list[int]]):
    def __init__(
        self,
        sampler: Union[FixedBatchMultiSourceSampler, MultiSourceSampler],
        batch_sizes: list[int],
        repeats: list[int],
        **kwargs
    ) -> None:
        self.sampler = sampler
        self.batch_sizes = batch_sizes
        self.repeats = repeats

    def __iter__(self) -> Iterator[list[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        sampler_iter = iter(self.sampler)

        while True:
            for source, (batch_size, repeat) in enumerate(zip(self.batch_sizes, self.repeats)):
                for _ in range(repeat):
                    batch = [*itertools.islice(sampler_iter, batch_size)]
                    yield batch

    @property
    def batch_size(self):
        batch_size_sum = sum(batch_size * repeat for batch_size, repeat in zip(self.batch_sizes, self.repeats))
        batch_size_ave = batch_size_sum // sum(self.repeats)

        return batch_size_ave

    def __len__(self) -> int:
        return len(self.sampler) // self.batch_size
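A sketch of how these samplers are typically wired up (the toy datasets and batch settings are made-up assumptions): MultiSourceSampler yields an infinite index stream that alternates sources according to batch_sizes and repeats, and MultiSourceBatchSampler regroups that stream into per-source batches for a DataLoader.

import itertools
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from src.datasets.samplers.multi_source_sampler import (
    MultiSourceSampler, MultiSourceBatchSampler)

# two toy sources; ConcatDataset provides .datasets and .cumulative_sizes
ds_a = TensorDataset(torch.arange(100))
ds_b = TensorDataset(torch.arange(1000))
concat = ConcatDataset([ds_a, ds_b])

# per cycle: 2 batches of 4 from source 0, then 1 batch of 8 from source 1
sampler = MultiSourceSampler(repeats=[2, 1], dataset=concat,
                             batch_sizes=[4, 8], seed=0)
batch_sampler = MultiSourceBatchSampler(sampler, batch_sizes=[4, 8], repeats=[2, 1])
loader = DataLoader(concat, batch_sampler=batch_sampler)

# the stream is infinite, so take a fixed number of batches
for batch in itertools.islice(loader, 3):
    print(batch[0].shape)  # torch.Size([4]), torch.Size([4]), torch.Size([8])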
src/datasets/text2image/__init__.py
ADDED
File without changes
src/datasets/text2image/text2image.py
ADDED
@@ -0,0 +1,649 @@
from torch.utils.data import Dataset
from PIL import Image
import os
import json
import random
import torch
import numpy as np
from einops import rearrange
from xtuner.registry import BUILDER
from mmengine.registry import DATASETS
from src.datasets.utils import crop2square, encode_fn
from glob import glob
from typing import List, Dict, Any, Optional
import mmap
import struct
from xtuner.utils import DEFAULT_IMAGE_TOKEN


@BUILDER.register_module()
class Text2ImageDataset(Dataset):
    def __init__(self,
                 data_path,
                 local_folder,
                 image_size,
                 unconditional=0.1,
                 tokenizer=None,
                 prompt_template=None,
                 max_length=1024,
                 crop_image=True,
                 cap_source='caption',
                 ):
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.unconditional = unconditional
        self.local_folder = local_folder
        self.cap_source = cap_source
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)

        self.prompt_template = prompt_template
        self.max_length = max_length
        self.crop_image = crop_image
        self.metainfo = {'task': 'unified'}
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)

    def _load_data(self, data_path):
        with open(data_path, 'r') as f:
            self.data_list = json.load(f)

        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        image = Image.open(os.path.join(self.local_folder, image_file))
        assert image.width > 8 and image.height > 8, f"Image: {image.size}"
        assert image.width / image.height > 0.1, f"Image: {image.size}"
        assert image.width / image.height < 10, f"Image: {image.size}"
        return image

    def _process_text(self, text):
        if random.uniform(0, 1) < self.unconditional:
            prompt = "Generate an image."
        else:
            prompt = f"Generate an image: {text.strip()}"
        prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]

        return dict(input_ids=input_ids[:self.max_length])

    def _process_image(self, image):
        data = dict()

        if self.crop_image:
            image = crop2square(image)
        else:
            target_size = max(image.size)
            image = image.resize(size=(target_size, target_size))

        image = image.resize(size=(self.image_size, self.image_size))
        pixel_values = torch.from_numpy(np.array(image)).float()
        pixel_values = pixel_values / 255
        pixel_values = 2 * pixel_values - 1
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')

        data.update(pixel_values=pixel_values)

        return data

    def _retry(self):
        return self.__getitem__(random.choice(range(self.__len__())))

    def __getitem__(self, idx):
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample['image']).convert('RGB')

            caption = data_sample[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')

            return data

        except Exception as e:
            print(f"Error when reading {self.data_path}:{self.data_list[idx]}: {e}", flush=True)
            return self._retry()


@DATASETS.register_module()
@BUILDER.register_module()
class LargeText2ImageDataset(Text2ImageDataset):
    # self.data_list only contains paths of images and captions

    def __init__(self, cap_folder=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cap_folder = self.local_folder if cap_folder is None else cap_folder

    def _load_data(self, data_path):  # image path and annotation path are saved in a json file
        if data_path.endswith(".json"):
            with open(data_path, 'r') as f:
                self.data_list = json.load(f)
        else:
            self.data_list = []
            json_files = glob(f'{data_path}/*.json')
            for json_file in json_files:
                with open(json_file, 'r') as f:
                    self.data_list += json.load(f)

        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def __getitem__(self, idx):
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample['image']).convert('RGB')
            with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r') as f:
                caption = json.load(f)[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data

        except Exception as e:
            print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True)
            return self._retry()


@DATASETS.register_module()
@BUILDER.register_module()
class MMapT2IDataset(Dataset):
    """
    Map-style Text2Image Dataset with mmap-based random access.
    The mmap is opened once in __init__; __getitem__ reads any line in O(1).
    """
    def __init__(
        self,
        jsonl_path: str,
        idx_path: str,
        image_size: int,
        tokenizer: Optional[Dict] = None,
        template_map_fn: Optional[Dict] = None,
        cap_source: str = "prompt",
        max_length: int = 2048,
        image_length: int = 512,
        unconditional: float = 0.01,
        crop_image: bool = False,
    ):
        super().__init__()

        # ---------- basic parameters ----------
        self.jsonl_path = jsonl_path
        self.image_size = image_size
        self.cap_source = cap_source
        self.max_length = max_length
        self.unconditional = unconditional
        self.crop_image = crop_image

        # ---------- tokenizer / template ----------
        self.tokenizer = BUILDER.build(tokenizer)
        self.template_map_fn = template_map_fn

        # ---------- mmap loading ----------
        self._open_mmap(jsonl_path, idx_path)
        self.metainfo = {'task': 'unified'}

    # ===== mmap & index =====
    def _open_mmap(self, jsonl_path: str, idx_path: str):
        # memory-map the jsonl file (read-only access is all we need)
        self._jsonl_fp = open(jsonl_path, "rb")
        self._mm = mmap.mmap(self._jsonl_fp.fileno(), 0, access=mmap.ACCESS_READ)

        # read the offset index
        with open(idx_path, "rb") as f:
            nlines = struct.unpack("<Q", f.read(8))[0]
            self._offsets = np.frombuffer(f.read(8 * nlines), dtype=np.uint64)
        print(f"[MMapT2IDataset] {jsonl_path}: {nlines} lines indexed")

    def __len__(self) -> int:
        return self._offsets.size

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def _read_line(self, idx: int) -> str:
        off = int(self._offsets[idx])
        self._mm.seek(off)
        return self._mm.readline().decode("utf-8")

    # ===== core processing =====
    def _load_image(self, path: str) -> torch.Tensor:
        img = Image.open(path).convert("RGB")

        # preprocess: crop to square / pad
        if self.crop_image:
            img = crop2square(img)
        else:
            target_size = max(img.size)
            img = img.resize((target_size, target_size))

        img = img.resize((self.image_size, self.image_size))
        arr = np.asarray(img, dtype=np.uint8)       # HWC uint8
        px = torch.as_tensor(arr).float() / 255.0   # 0-1
        px = 2 * px - 1                             # -1 ~ 1
        return rearrange(px, "h w c -> c h w")      # CHW

    def _build_prompt(self, caption: str) -> torch.Tensor:
        if random.random() < self.unconditional:
            caption = "Generate an image."
        else:
            caption = f"Generate an image: {caption.strip()}"

        instr = self.template_map_fn["INSTRUCTION"].format(input=caption)
        ids = self.tokenizer.encode(
            instr, add_special_tokens=True, return_tensors="pt"
        )[0][: self.max_length]
        return ids

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # 1) fetch the jsonl line
        sample = json.loads(self._read_line(idx))

        # 2) load & process the image
        pixel_values = self._load_image(sample["image"])

        # 3) process the text
        caption = sample.get(self.cap_source, "")
        input_ids = self._build_prompt(caption)

        # 4) pack
        data = dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            type="text2image",
            image_file=sample["image"],
            idx=idx,
        )
        return data


@DATASETS.register_module()
@BUILDER.register_module()
class ReconstructDataset(Dataset):
    def __init__(self,
                 data_path: str,
                 image_size: int,
                 tokenizer=None,
                 prompt_template=None,
                 cap_source: str = "prompt",
                 max_length: int = 8192,
                 crop_image: bool = True,
                 img_prefix: str = ""):
        super().__init__()
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.prompt_template = prompt_template
        self.cap_source = cap_source
        self.max_length = max_length
        self.crop_image = crop_image
        self.img_prefix = img_prefix
        self._load_data(data_path)

        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64
        self.metainfo = {'task': 'unified'}

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def _load_data(self, path):
        with open(path) as f:
            self.data_list = [json.loads(l) for l in f]
        print(f"[I2ICaptionReconstructDataset] Loaded {len(self.data_list)} samples from {path}")

    def _add_prefix(self, rel):
        return os.path.join(self.img_prefix, rel.lstrip("/")) if self.img_prefix else rel

    def _read_image(self, path):
        img = Image.open(path).convert("RGB")
        assert img.width > 8 and img.height > 8 and 0.1 < img.width / img.height < 10
        return img

    # ---------- preprocess ----------
    def _process_image(self, img):
        img = crop2square(img) if self.crop_image else img.resize((max(img.size),) * 2)
        img = img.resize((self.image_size, self.image_size))
        px = torch.from_numpy(np.array(img)).float() / 255.
        px = 2 * px - 1
        return rearrange(px, "h w c -> c h w")

    def _encode_prompt(self, text):
        # for bad_token in ["[IMAGE]", "<image_placeholder>", "<image_plaeholder>"]:
        #     text = text.replace(bad_token, "")
        text = "Repeat this image."
        prompt_in = f"<image>\n{text.strip()}"
        prompt = self.prompt_template["INSTRUCTION"].format(input=prompt_in)
        prompt = prompt.replace("<image>", "<image>" * self.image_token_repeat)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")[0]
        mask = (input_ids != self.tokenizer.pad_token_id).long()
        return input_ids[:self.max_length], mask[:self.max_length]

    def __len__(self):
        return len(self.data_list)

    def _retry(self):
        return self.__getitem__(random.randrange(len(self)))

    def __getitem__(self, idx):
        try:
            sample = self.data_list[idx]
            src_img = self._read_image(self._add_prefix(sample["image"]))
            tgt_img = src_img
            caption = sample[self.cap_source]

            px_src = self._process_image(src_img)
            px_tgt = self._process_image(tgt_img)
            input_ids, mask = self._encode_prompt(caption)

            return {
                "pixel_values_src": px_src,
                "pixel_values": px_tgt,
                "input_ids": input_ids,
                "attention_mask": mask,
                "type": "image_edit"
            }
        except Exception as e:
            print(f"[I2ICaptionReconstructDataset] Error @ {idx}: {e}")
            return self._retry()


@DATASETS.register_module()
@BUILDER.register_module()
class UncondReconstructDataset(Dataset):
    def __init__(self,
                 data_path: str,
                 image_size: int,
                 tokenizer=None,
                 prompt_template=None,
                 cap_source: str = "prompt",
                 max_length: int = 8192,
                 crop_image: bool = True,
                 img_prefix: str = ""):
        super().__init__()
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.crop_image = crop_image
        self.img_prefix = img_prefix
        self.cap_source = cap_source

        self._load_data(data_path)

        # number of expanded image tokens
        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64
        self.metainfo = {'task': 'unified'}

    def _load_data(self, path):
        with open(path) as f:
            self.data_list = [json.loads(l) for l in f]
        print(f"[I2IUncondReconstructDataset] Loaded {len(self.data_list)} samples from {path}")

    def _add_prefix(self, rel_path):
        return os.path.join(self.img_prefix, rel_path.lstrip("/")) if self.img_prefix else rel_path

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def _read_image(self, path):
        image = Image.open(path).convert("RGB")
        assert image.width > 8 and image.height > 8 and 0.1 < image.width / image.height < 10
        return image

    # ---------- preprocess ----------
    def _process_image(self, img):
        img = crop2square(img) if self.crop_image else img.resize((max(img.size),) * 2)
        img = img.resize((self.image_size, self.image_size))
        px = torch.from_numpy(np.array(img)).float() / 255.
        px = 2 * px - 1
        return rearrange(px, "h w c -> c h w")

    def __len__(self):
        return len(self.data_list)

    def _retry(self, max_tries=5):
        for _ in range(max_tries):
            try:
                return self.__getitem__(random.randrange(len(self)))
            except Exception:
                continue
        raise RuntimeError("Exceeded max retries in I2IUncondReconstructDataset")

    def __getitem__(self, idx):
        try:
            sample = self.data_list[idx]
            path = self._add_prefix(sample["image"])
            img = self._read_image(path)
            px = self._process_image(img)

            # ==== fill in empty text ====
            input_ids = torch.zeros(0, dtype=torch.long)
            attention_mask = torch.zeros(0, dtype=torch.long)

            return {
                "pixel_values_src": px,
                "pixel_values": px.clone(),
                "type": "image_edit",
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                # the reconstruction task no longer emits meaningful input_ids / attention_mask
            }
        except Exception as e:
            print(f"[I2IUncondReconstructDataset] Error @ {idx}: {e}")
            return self._retry()


@DATASETS.register_module()
@BUILDER.register_module()
class Text2ImageJSONLDataset(Dataset):
    def __init__(self,
                 data_path,
                 image_size,
                 tokenizer=None,
                 prompt_template=None,
                 cap_source='prompt',
                 max_length=1024,
                 unconditional=0.1,
                 crop_image=True,
                 ):
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.prompt_template = prompt_template
        self.cap_source = cap_source
        self.max_length = max_length
        self.unconditional = unconditional
        self.crop_image = crop_image
        self.metainfo = {'task': 'unified'}

    def _load_data(self, data_path):
        self.data_list = []
        with open(data_path, 'r') as f:
            for line in f:
                self.data_list.append(json.loads(line.strip()))
        print(f"Loaded {len(self.data_list)} samples from {data_path}")

    def full_init(self):
        """Dummy full_init for MMEngine ConcatDataset compatibility."""
        pass

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        image = Image.open(image_file).convert('RGB')
        assert image.width > 8 and image.height > 8
        assert 0.1 < image.width / image.height < 10
        return image

    def _process_image(self, image):
        if self.crop_image:
            image = crop2square(image)
        else:
            target_size = max(image.size)
            image = image.resize((target_size, target_size))

        image = image.resize((self.image_size, self.image_size))
        pixel_values = torch.from_numpy(np.array(image)).float() / 255.0
        pixel_values = 2 * pixel_values - 1  # [-1, 1]
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')
        return dict(pixel_values=pixel_values)

    def _process_text(self, text):
        if random.uniform(0, 1) < self.unconditional:
            text = "Generate an image."
        else:
            text = f"Generate an image: {text.strip()}"
        prompt = self.prompt_template['INSTRUCTION'].format(input=text)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]
        return dict(input_ids=input_ids[:self.max_length])

    def _retry(self):
        return self.__getitem__(random.randint(0, len(self.data_list) - 1))

    def __getitem__(self, idx):
        try:
            sample = self.data_list[idx]
            image = self._read_image(sample['image'])
            caption = sample[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"[JSONLDataset] Error reading sample #{idx}: {e}")
            return self._retry()


# Pure text-to-image has no placeholder issue; the editing dataset below must handle placeholders.
@DATASETS.register_module()
@BUILDER.register_module()
class ImageEditJSONLDataset(Dataset):
    """
    Dataset for <src, tgt, prompt> image editing, now decoupled from tokenization logic.
    """
    def __init__(self,
                 data_path: str,
                 image_size: int,
                 tokenizer=None,
                 prompt_template=None,
                 max_length: int = 8192,
                 cap_source: str = "prompt",
                 unconditional: float = 0,
                 crop_image: bool = False,
                 img_prefix: str = ""):
        super().__init__()
        self.data_path = data_path
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.cap_source = cap_source
        self.unconditional = unconditional
        self.crop_image = crop_image
        self.img_prefix = img_prefix
        self._load_data(data_path)
        # Calculate image token repetition length, consistent with inference.
        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64
        self.metainfo = {'task': 'unified'}

        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
        print(f"Registered <image> token at index {self.image_token_idx}")

    def _load_data(self, path):
        with open(path) as f:
            self.data_list = [json.loads(l) for l in f]
        print(f"[ImageEditJSONLDataset] Loaded {len(self.data_list)} samples from {path}")

    def full_init(self):
        """Dummy full_init for MMEngine ConcatDataset compatibility."""
        pass

    def _add_prefix(self, rel_path):
        return os.path.join(self.img_prefix, rel_path.lstrip("/")) if self.img_prefix else rel_path

    def _read_image(self, path):
        path = path.replace("datasets_vlm02", "datasets_vlm")
        img = Image.open(path).convert("RGB")
        assert img.width > 8 and img.height > 8 and 0.1 < img.width / img.height < 10
        return img

    def _process_image(self, img):
        img = crop2square(img) if self.crop_image else img.resize((max(img.size),) * 2)
        img = img.resize((self.image_size, self.image_size))
        px = torch.from_numpy(np.array(img)).float() / 255.
        px = 2 * px - 1
        return rearrange(px, "h w c -> c h w")

    # --- REFACTORED: This method now only prepares the raw prompt text ---
    def _prepare_prompt_text(self, raw_text: str):
        """Cleans text and handles unconditional generation."""

        # strip placeholders cumulatively (the original reset txt from raw_text each
        # iteration, so only the last placeholder was actually removed)
        txt = raw_text
        for bad_token in ["[IMAGE]", "<image_placeholder>", "<image_plaeholder>", "<image>"]:
            txt = txt.replace(bad_token, "")
        txt = txt.strip()

        if random.random() < self.unconditional:
            txt = "Edit this image."
        return txt

    def _retry(self):
        return self.__getitem__(random.randrange(len(self)))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        try:
            sample = self.data_list[idx]
            src_path, tgt_path = map(self._add_prefix, [sample["images"][0], sample["image"]])
            src_img, tgt_img = map(self._read_image, [src_path, tgt_path])

            px_src, px_tgt = map(self._process_image, [src_img, tgt_img])

            # --- MODIFIED: Call the unified encode_fn ---
            # 1. Prepare the raw prompt string
            prompt_text = self._prepare_prompt_text(sample[self.cap_source])

            # 2. Delegate all encoding and formatting to encode_fn
            encoded_text = encode_fn(
                example=prompt_text,
                tokenizer=self.tokenizer,
                prompt_template=self.prompt_template,
                max_length=self.max_length,
                image_length=self.image_token_repeat,
                image_token_idx=self.image_token_idx
            )

            return {
                "pixel_values_src": px_src,
                "pixel_values": px_tgt,
                "input_ids": torch.tensor(encoded_text["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(encoded_text["attention_mask"], dtype=torch.long),
                "type": "image_edit",
            }
        except Exception as e:
            print(f"[ImageEditJSONLDataset] Error @ {idx}: {e} from {self.data_path}")
            return self._retry()
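MMapT2IDataset expects idx_path to contain a little-endian uint64 line count followed by one uint64 byte offset per jsonl line (the format `_open_mmap` reads back with `struct.unpack("<Q", ...)` and `np.frombuffer`). A minimal sketch of building that index; the helper name is hypothetical and not part of this upload:

import struct

def build_jsonl_index(jsonl_path: str, idx_path: str) -> int:
    """Write <Q line count, then one uint64 byte offset per line."""
    offsets = []
    off = 0
    with open(jsonl_path, "rb") as f:
        for line in f:
            offsets.append(off)   # byte offset where this line starts
            off += len(line)
    with open(idx_path, "wb") as f:
        f.write(struct.pack("<Q", len(offsets)))
        f.write(struct.pack(f"<{len(offsets)}Q", *offsets))
    return len(offsets)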
src/datasets/understanding/caption_datasets.py
ADDED
@@ -0,0 +1,342 @@
from torch.utils.data import Dataset
from PIL import Image
import os
import io
import json
import random
import torch
import numpy as np
from einops import rearrange
try:
    from aoss_client.client import Client
except ImportError:
    try:
        from petrel_client.client import Client
    except ImportError:
        Client = None
from glob import glob
from xtuner.registry import BUILDER
from xtuner.dataset.utils import expand2square
from src.datasets.utils import crop2square, encode_fn
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from src.datasets.understanding.caption_prompts import dense_prompts, short_prompts
from typing import List, Dict, Any, Optional, Callable, Tuple


@BUILDER.register_module()
class CaptionDataset(Dataset):
    def __init__(self,
                 data_path,
                 local_folder,
                 image_size,
                 ceph_folder=None,
                 ceph_config=None,
                 tokenizer=None,
                 template_map_fn=None,
                 max_length=2048,
                 min_image_size=80,
                 image_length=256,
                 pad_image=True,
                 brief=False,
                 cap_folder=None,
                 cap_source='caption',
                 ):
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.local_folder = local_folder
        self.cap_folder = local_folder if cap_folder is None else cap_folder
        self.cap_source = cap_source

        self.image_size = image_size

        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = template_map_fn['template']
        self.template_map_fn = BUILDER.build(template_map_fn)
        self.max_length = max_length
        self.image_length = image_length
        self.pad_image = pad_image
        self.min_image_size = min_image_size

        self.FILE_CLIENT = None
        self.ceph_folder = ceph_folder
        self.ceph_config = ceph_config
        self.use_ceph = ((Client is not None) and (ceph_folder is not None)
                         and (ceph_config is not None) and os.path.exists(ceph_config))

        self.brief = brief
        self.caption_prompts = short_prompts if self.brief else dense_prompts

    def _load_data(self, data_path: str):  # image path and annotation path are saved in a json file
        if data_path.endswith('.json'):
            with open(data_path, 'r') as f:
                self.data_list = json.load(f)
        else:
            json_files = glob(f"{data_path}/*.json")
            data_list = []
            for json_file in json_files:
                with open(json_file, 'r') as f:
                    data_list += json.load(f)

            self.data_list = data_list

        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def __len__(self):
        return len(self.data_list)

    def _read_ceph(self, ceph_path):
        if self.FILE_CLIENT is None:
            self.FILE_CLIENT = Client(self.ceph_config)
        data_bytes = self.FILE_CLIENT.get(ceph_path)

        return io.BytesIO(data_bytes)

    def _read_image(self, image_file):
        if self.use_ceph:
            image = Image.open(
                self._read_ceph(
                    os.path.join(self.ceph_folder, image_file)
                )
            )
        else:
            image = Image.open(
                os.path.join(self.local_folder, image_file)
            )
        assert image.width > self.min_image_size and image.height > self.min_image_size, f"Image: {image.size}"
        assert image.width / image.height > 0.1, f"Image: {image.size}"
        assert image.width / image.height < 10, f"Image: {image.size}"
        return image.convert('RGB')

    def _read_json(self, annotation_file):
        if self.use_ceph:
            annotation = json.load(
                self._read_ceph(
                    os.path.join(self.ceph_folder, annotation_file)
                )
            )
        else:
            with open(os.path.join(self.local_folder, annotation_file), 'r') as f:
                annotation = json.load(f)

        return annotation

    def _process_image(self, image):
        data = dict()
        if self.pad_image:
            image = expand2square(image, (127, 127, 127))
        else:
            image = crop2square(image)

        image = image.resize(size=(self.image_size, self.image_size))
        pixel_values = torch.from_numpy(np.array(image)).float()
        pixel_values = pixel_values / 255
        pixel_values = 2 * pixel_values - 1
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')

        data.update(pixel_values=pixel_values)
        return data

    def _process_text(self, text):
        assert DEFAULT_IMAGE_TOKEN not in text, text
        data_dict = dict(conversation=[{'input': f"{DEFAULT_IMAGE_TOKEN}\n{random.choice(self.caption_prompts)}",
                                        'output': text.strip()}])
        data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(encode_fn(data_dict, self.tokenizer, self.max_length,
                                   self.image_length, True, True))

        assert (torch.tensor(data_dict['input_ids']).long() == IMAGE_TOKEN_INDEX).sum() == self.image_length, \
            "Error in image format"

        data_dict['type'] = 'image2text'
        return data_dict

    def _retry(self):
        return self.__getitem__(random.choice(range(self.__len__())))

    def __getitem__(self, idx):
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample['image']).convert('RGB')
            data = self._process_image(image)
            del image
            with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r') as f:
                caption = json.load(f)[self.cap_source]
            data.update(self._process_text(caption))

            data.update(image_dir=self.local_folder, image_file=data_sample['image'])

            return data

        except Exception as e:
            print(f"Error when reading {self.data_path}:{data_sample['image']}: {e}", flush=True)
            return self._retry()


@BUILDER.register_module()
class VqaDataset(Dataset):
    """Generic VQA / multimodal conversation dataset with robust IO & validation."""

    # ---------- init ----------
    def __init__(
        self,
        data_path: str,
        tokenizer,                  # required argument, listed first
        template_map_fn: Callable,  # required argument, listed first
        img_prefix: Optional[str] = None,
        image_size: int = 512,
        max_length: int = 2048,
        image_length: int = 1089,
        pad_image: bool = True,
        min_image_size: int = 80,
        image_token_patterns: Tuple[str, ...] = ('<image>', '[image]', '<img>'),
        max_retry: int = 5,
    ):
        super().__init__()

        self.img_prefix = img_prefix.rstrip("/") if img_prefix else None
        self.image_size = image_size
        self.max_length = max_length
        self.image_length = image_length
        self.pad_image = pad_image
        self.min_image_size = min_image_size
        self.image_token_patterns = list(image_token_patterns)
        self.max_retry = max_retry

        # build tokenizer and template
        self.tokenizer = BUILDER.build(tokenizer)
        self.template_map_fn = BUILDER.build(template_map_fn) if template_map_fn else None

        # register the <image> token so self.image_token_idx is defined
        # (mirrors the other datasets in this upload; the original omitted it)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")

        # load jsonl file / directory
        self.data_list = self._load_jsonl_list(data_path)
        print(f"Loaded {len(self.data_list)} samples from {data_path}")

    # ---------- data loading helpers ----------
    @staticmethod
    def _load_jsonl_list(path: str) -> List[Dict[str, Any]]:
        data: List[Dict[str, Any]] = []
        if path.endswith(".jsonl"):
            files = [path]
        else:
            files = sorted(glob(os.path.join(path, "**/*.jsonl"), recursive=True))

        for file in files:
            with open(file, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        data.append(json.loads(line))
        return data

    # ---------- basic interface ----------
    def __len__(self) -> int:
        return len(self.data_list)

    # ---------- image processing ----------
    def _get_image_path(self, img_file: str) -> str:
        """Keep absolute paths unchanged; otherwise prepend the prefix."""
        return img_file if os.path.isabs(img_file) else os.path.join(self.img_prefix, img_file)

    def _read_image(self, img_file: str) -> Image.Image:
        img_path = self._get_image_path(img_file)
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            raise FileNotFoundError(f"Cannot open image: {img_path} ({e})")

        w, h = image.size
        if w < self.min_image_size or h < self.min_image_size:
            raise ValueError(f"Image too small: {img_path} ({w}x{h})")
        ratio = w / h
        if not (0.1 < ratio < 10):
            raise ValueError(f"Odd aspect ratio ({ratio:.3f}) for {img_path}")

        # pad / crop
        image = expand2square(image, (127, 127, 127)) if self.pad_image else crop2square(image)
        image = image.resize((self.image_size, self.image_size), resample=Image.BICUBIC)

        px = torch.from_numpy(np.asarray(image)).float() / 255.0
        px = 2 * px - 1.0
        px = rearrange(px, "h w c -> c h w")  # CHW
        return px

    # ---------- conversation processing ----------
    def _replace_image_tokens(self, txt: str) -> str:
        for pat in self.image_token_patterns:
            if pat in txt:
                txt = txt.replace(pat, str(self.image_token_idx))
        return txt

    def _format_conversation(self, turns: List[Dict[str, str]]) -> Dict[str, Any]:
        """
        Merge human/gpt turns into {'input': ..., 'output': ...} pairs.
        Convention: each human → gpt sequence forms one pair; pairs with
        a missing reply are skipped.
        """
        pairs = []

        for i in range(0, len(turns), 2):  # one pair per two turns: human and gpt
            if i + 1 < len(turns):  # make sure the gpt turn exists
                human_turn = turns[i]
                gpt_turn = turns[i + 1]

                human_content = human_turn.get("value", "").strip()
                gpt_content = gpt_turn.get("value", "").strip()

                if not human_content.lstrip().startswith("<image>"):
                    human_content = f"<image>\n{human_content}"

                if not human_content or not gpt_content:  # skip the pair if either side is empty
                    continue

                # only insert the image token into the human turn
                # human_content = self._replace_image_tokens(human_content)  # swap in image_token_idx

                pairs.append({"input": human_content, "output": gpt_content})

        data_dict = {"conversation": pairs}
        data_dict_ori = data_dict
        if self.template_map_fn:
            data_dict = self.template_map_fn(data_dict)

        # encode the input
        data_dict = encode_fn(
            data_dict,
            self.tokenizer,
            self.max_length,
            self.image_length,
            input_ids_with_output=True,
            with_image_token=True,
            # additionally pass image_token_idx through
            image_token_idx=self.image_token_idx
        )

        # dynamic check: make sure the image token actually appears
        img_tokens = (torch.tensor(data_dict["input_ids"]) == self.image_token_idx).sum().item()

        print(f"[check] input_ids length: {len(data_dict['input_ids'])}, image token count: {img_tokens}\n")
        # print(f"[check] input_ids: {data_dict.get('input_ids', 'unset')}\n")
        if img_tokens != 1088:
            print(f"[abnormal conversation]: {data_dict_ori}")

        data_dict["type"] = "image2text"  # mark the sample as image2text
        return data_dict

    # ---------- main interface ----------
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        for attempt in range(self.max_retry):
            try:
                sample = self.data_list[idx]
                img_tensor = self._read_image(sample["image"])
                text_data = self._format_conversation(sample.get("conversations", []))
                return {
                    **text_data,
                    "pixel_values": img_tensor,
                    "image_file": sample["image"],
                }
            except Exception as e:
                print(f"[Retry {attempt+1}/{self.max_retry}] idx={idx} error: {e}")
                idx = random.randint(0, len(self) - 1)

        # raise if all retries fail
        raise RuntimeError(f"Failed to fetch valid sample after {self.max_retry} retries.")
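For reference, a record shape VqaDataset.__getitem__ can consume, inferred from _format_conversation (the field contents are made up; the "from" keys are conventional and not actually read by the code, which pairs turns positionally):

# one line of the VQA jsonl; turns pair up positionally as (human, gpt)
sample = {
    "image": "images/0001.jpg",  # resolved against img_prefix unless absolute
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is in the picture?"},
        {"from": "gpt", "value": "A cat sitting on a windowsill."},
    ],
}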
src/datasets/understanding/caption_prompts.py
ADDED
@@ -0,0 +1,28 @@
dense_prompts = [
    "Describe the image in detail.",
    "Provide a comprehensive description of everything you see in the picture.",
    "Explain the scene depicted in the image as if you were describing it to someone who cannot see it.",
    "List all the objects and activities taking place in this image.",
    "What is the story being told by this image? Describe in detail.",
    "Imagine you are giving a detailed tour of the image's scene. How would you describe it?",
    "Describe the foreground, background, and any notable features of the image.",
    "How would you describe this image to build a replica of the scene?",
    "Write a paragraph detailing the setting, characters, and actions visible in this image.",
    "Describe every aspect of the image, including the environment, objects, and any people present.",
    "Provide a detailed analysis of the composition and elements of the image.",
    "What are the main focal points of this image? Describe them in detail.",
    "Catalog all visible elements in the image and describe their significance to the overall scene."
]


short_prompts = [
    "Briefly describe the image.",
    "Summarize the key elements of the image in one sentence.",
    "Give a concise description of the scene.",
    "Briefly, what is happening in this image?",
    "What is the most noticeable feature of this image?",
    "Summarize the image in a sentence.",
    "What activity is being depicted in the image?",
    "Describe the setting of the image in a few words.",
    "Caption this image for a social media post."
]
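These prompt pools are presumably sampled when caption training samples are built, to diversify the instruction wording. A minimal usage sketch (the random.choice selection is an assumption; the datasets in this commit may pick prompts differently):

import random
from src.datasets.understanding.caption_prompts import dense_prompts, short_prompts

want_detailed = True  # hypothetical flag: dense vs. short captioning
pool = dense_prompts if want_detailed else short_prompts
instruction = random.choice(pool)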
src/datasets/understanding/vlm_datasets_sig.py
ADDED
@@ -0,0 +1,168 @@
from torch.utils.data import Dataset
from PIL import Image
import os
import json
import random
import torch
import numpy as np
from einops import rearrange
from xtuner.registry import BUILDER
from xtuner.dataset.utils import expand2square
from src.datasets.utils import crop2square, encode_fn, load_jsonl
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from transformers import AutoImageProcessor


class VLMDataset(Dataset):
    def __init__(
        self,
        data_path,
        image_size,
        tokenizer=None,
        template_map_fn=None,
        max_length=2048,
        min_image_size=80,
        pad_image=True,
        local_folder="",
        key_value="conversations",
    ):
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.image_size = image_size

        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = template_map_fn["template"]
        self.template_map_fn = BUILDER.build(template_map_fn)
        self.max_length = max_length
        self.pad_image = pad_image
        self.min_image_size = min_image_size
        self.key_value = key_value
        self.processor = AutoImageProcessor.from_pretrained(
            "checkpoint/siglip2-so400m-patch16-512"
        )
        self.metainfo = {'task': 'unified'}
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64

        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
        print(f"Registered <image> token at index {self.image_token_idx}")

    def _load_data(
        self, data_path: str
    ):  # image path and annotation path are saved in a json file
        self.data_list = load_jsonl(data_path)
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        image = Image.open(image_file)
        assert (
            image.width > self.min_image_size and image.height > self.min_image_size
        ), f"Image: {image.size}"
        assert image.width / image.height > 0.1, f"Image: {image.size}"
        assert image.width / image.height < 10, f"Image: {image.size}"
        return image.convert("RGB")

    # def _process_image(self, image):
    #     data = dict()
    #     # if self.pad_image:
    #     #     image = expand2square(image, (127, 127, 127))
    #     # else:
    #     #     image = crop2square(image)

    #     # image = image.resize(size=(self.image_size, self.image_size))
    #     # pixel_values = torch.from_numpy(np.array(image)).float()
    #     # pixel_values = pixel_values / 255
    #     # pixel_values = 2 * pixel_values - 1
    #     # pixel_values = rearrange(pixel_values, "h w c -> c h w")
    #     image = image.resize((self.image_size, self.image_size))
    #     inputs = self.processor(images=image, return_tensors="pt")
    #     pixel_values = inputs["pixel_values"].squeeze(0)

    #     data.update(pixel_values=pixel_values)
    #     return data


    def _process_image(self, image: Image.Image):
        # 1) optionally crop/pad to a square
        if self.pad_image:
            image = crop2square(image)
        # 2) resize to the target resolution
        image = image.resize((self.image_size, self.image_size))
        # 3) to tensor & normalize
        arr = np.array(image).astype(np.float32) / 255.0  # HWC
        arr = 2 * arr - 1  # [-1, 1]
        tensor = torch.from_numpy(arr)  # HWC
        tensor = rearrange(tensor, "h w c -> c h w")  # CHW
        return {"pixel_values": tensor}

    def _process_text(self, question, answer):
        data_dict = dict(
            conversation=[
                {
                    "input": f"{self.DEFAULT_IMAGE_TOKEN}\n{question}",
                    "output": answer,
                }
            ]
        )
        data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(
            encode_fn(
                example=data_dict,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                image_length=self.image_token_repeat,
                input_ids_with_output=True,
                with_image_token=True,
                truncation='right',
                image_token_idx=self.image_token_idx,
                image_token_str=self.DEFAULT_IMAGE_TOKEN,
            )
        )

        # assert (
        #     torch.tensor(data_dict["input_ids"]).long() == self.image_token_idx
        # ).sum() == self.image_length, "Error in image format"

        data_dict["type"] = "image2text"
        return data_dict

    def _retry(self):
        return self.__getitem__(random.choice(range(self.__len__())))

    def __getitem__(self, idx):
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample["image"]).convert("RGB")
            data = self._process_image(image)
            del image
            question = (
                data_sample[self.key_value][0]["value"]
                .replace("<image>", "")
                .strip()
            )
            answer = (
                data_sample[self.key_value][1]["value"]
                .replace("<image>", "")
                .strip()
            )

            data.update(self._process_text(question, answer))

            data.update(image_file=data_sample["image"])

            return data

        except Exception as e:
            print(
                f"Error when reading data_sample:{data_sample},{self.data_path}:{data_sample['image']}: {e}",
                flush=True,
            )
            return self._retry()
src/datasets/utils.py
ADDED
@@ -0,0 +1,303 @@
import copy
import random
from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
import json


# def crop2square(pil_img):
#     width, height = pil_img.width, pil_img.height
#
#     if width > height:
#         y0, y1 = 0, height
#         x0 = random.randint(0, width - height)    # [0, w - h]
#         x1 = x0 + height                          # [h, w]
#     else:
#         x0, x1 = 0, width
#         y0 = random.randint(0, height - width)    # [0, h - w]
#         y1 = y0 + width                           # [w, h]
#
#     return pil_img.crop(box=(x0, y0, x1, y1))

def crop2square(pil_img):
    # Deterministic center crop to a square of side min(width, height).
    width, height = pil_img.width, pil_img.height
    short = min(width, height)
    left = (width - short) // 2
    upper = (height - short) // 2
    return pil_img.crop((left, upper, left + short, upper + short))


def load_jsonl(json_file):
    with open(json_file) as f:
        lines = f.readlines()
    data = []
    for line in lines:
        data.append(json.loads(line))
    return data


def encode_fn_original(example,
                       tokenizer,
                       max_length=None,
                       image_length=1,
                       input_ids_with_output=True,
                       with_image_token=False,
                       truncation='right',
                       image_token_idx=None,
                       image_token_str="<image>"):
    """We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    if image_token_idx is None:  # fall back to the tokenizer's own id if not provided
        image_token_idx = tokenizer.convert_tokens_to_ids("<image>")

    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        input = single_turn_conversation['input']
        if image_token_str in input and with_image_token:
            chunk_encode = [
                tokenizer.encode(chunk, add_special_tokens=False)
                for chunk in input.split(image_token_str)
            ]
            assert len(chunk_encode) == 2
            input_encode = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                input_encode.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    # input_encode.append(IMAGE_TOKEN_INDEX)
                    input_encode += [image_token_idx] * image_length
        else:
            input_encode = tokenizer.encode(input, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)
        if input_ids_with_output and 'output' in single_turn_conversation:
            # Add output
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']

            if image_token_str in output and with_image_token:
                chunk_encode = [
                    tokenizer.encode(chunk, add_special_tokens=False)
                    for chunk in output.split(image_token_str)
                ]
                assert len(chunk_encode) == 2
                output_encode = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    output_encode.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        output_encode += [image_token_idx] * image_length
            else:
                output_encode = tokenizer.encode(output, add_special_tokens=False)
            # output_encode = tokenizer.encode(output, add_special_tokens=False)
            input_ids += output_encode
            if output_with_loss:
                labels += copy.deepcopy(output_encode)
            else:
                labels += [IGNORE_INDEX] * len(output_encode)
            # Add EOS_TOKEN (with loss)
            if single_turn_conversation.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += copy.deepcopy(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False
            # Add SEP (without loss)
            sep = single_turn_conversation.get('sep', '')
            if sep != '':
                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_encode
                labels += [IGNORE_INDEX] * len(sep_encode)

    if max_length is not None and len(input_ids) > max_length:
        if truncation == 'right':
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == 'left':
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        else:
            assert truncation is None
    return {'input_ids': input_ids, 'labels': labels}



def encode_fn(
    example,
    tokenizer,
    prompt_template=None,
    max_length=None,
    image_length=1,
    input_ids_with_output=True,
    with_image_token=True,
    truncation='right',
    image_token_idx=None,
    image_token_str="<image>",
):
    """
    A versatile encoding function for both image-to-text (conversation) and text-to-image/image-editing tasks.

    - Image-to-Text: example = {"conversation": [...]}, outputs input_ids + labels.
    - Text-to-Image/Editing: example = str (raw_text prompt), outputs input_ids + labels (with IGNORE_INDEX).
    """
    # assert image_token_idx is not None, "Must pass image_token_idx explicitly"
    # print(f"[DEBUG] image_token_idx = {image_token_idx}")
    if image_token_idx is None:
        tokenizer.add_tokens([image_token_str], special_tokens=True)
        image_token_idx = tokenizer.convert_tokens_to_ids(image_token_str)

    if isinstance(example, str):
        assert prompt_template is not None, \
            "prompt_template must be provided for text2image/image-editing"

        # 1) Build the prompt: one <image> token is prepended (at the token
        #    level, below), then the raw text follows.
        prompt = f"{example.strip()}"
        # wrap with the instruction template
        prompt = prompt_template["INSTRUCTION"].format(input=prompt)

        # 2) Encode with the tokenizer. Do not let it split <image> as plain
        #    characters: encode the text without special tokens, then splice
        #    the image ids in manually.
        text_ids = tokenizer.encode(
            prompt,
            add_special_tokens=False,
            truncation=True,
            max_length=(max_length - image_length) if max_length else None
        )
        # insert the <image> token ids at the very front (or wherever needed)
        input_ids = [image_token_idx] * image_length + text_ids

        # 3) Hard-truncate if still over length
        if max_length is not None and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]

        # 4) attention_mask
        attention_mask = [1] * len(input_ids)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    # --- Image-to-text task: multi-turn conversation structure ---
    assert isinstance(example, dict) and "conversation" in example
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn = len(example["conversation"]) > 1
    if is_multi_turn:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True

    for single_turn in example["conversation"]:
        input_text = single_turn["input"]

        # ==== Encode input ====
        if with_image_token and image_token_str in input_text:
            chunks = input_text.split(image_token_str)
            chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
            assert len(chunk_encoded) >= 2
            input_encode = []
            for i, chunk in enumerate(chunk_encoded):
                input_encode.extend(chunk)
                if i < len(chunk_encoded) - 1:
                    input_encode.extend([image_token_idx] * image_length)
        else:
            input_encode = tokenizer.encode(input_text, add_special_tokens=False)

        if next_needs_bos_token:
            input_ids.extend(bos_token_id)
            labels.extend([IGNORE_INDEX] * len(bos_token_id))

        input_ids.extend(input_encode)
        labels.extend([IGNORE_INDEX] * len(input_encode))

        # ==== Encode output ====
        if input_ids_with_output and "output" in single_turn:
            output = single_turn["output"]
            output_with_loss = single_turn.get("output_with_loss", True)

            if with_image_token and image_token_str in output:
                chunks = output.split(image_token_str)
                chunk_encoded = [tokenizer.encode(c, add_special_tokens=False) for c in chunks]
                assert len(chunk_encoded) >= 2
                output_encode = []
                for i, chunk in enumerate(chunk_encoded):
                    output_encode.extend(chunk)
                    if i < len(chunk_encoded) - 1:
                        output_encode.extend([image_token_idx] * image_length)
            else:
                output_encode = tokenizer.encode(output, add_special_tokens=False)

            input_ids.extend(output_encode)
            if output_with_loss:
                labels.extend(output_encode.copy())
            else:
                labels.extend([IGNORE_INDEX] * len(output_encode))

            # ==== Append EOS ====
            # (kept inside the output branch, as in encode_fn_original, so
            # output_with_loss is always defined here)
            if single_turn.get("need_eos_token", True):
                next_needs_bos_token = True
                input_ids.extend(eos_token_id)
                if output_with_loss:
                    labels.extend(eos_token_id.copy())
                else:
                    labels.extend([IGNORE_INDEX] * len(eos_token_id))
            else:
                next_needs_bos_token = False

            # ==== Append separator ====
            sep = single_turn.get("sep", "")
            if sep:
                sep_encoded = tokenizer.encode(sep, add_special_tokens=False)
                input_ids.extend(sep_encoded)
                labels.extend([IGNORE_INDEX] * len(sep_encoded))

    # ==== Truncation ====
    if max_length is not None and len(input_ids) > max_length:
        if truncation == "right":
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == "left":
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        elif truncation is not None:
            # truncation=None keeps the full sequence
            raise ValueError("truncation must be 'left', 'right', or None")

    return {"input_ids": input_ids, "labels": labels}
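Note that crop2square above is now a deterministic center crop; the commented-out variant was a random square crop. Below is a minimal smoke test for the conversation branch of encode_fn. The tokenizer checkpoint is a placeholder assumption, and the sketch presumes xtuner's get_bos_eos_token_ids supports the chosen tokenizer family:

from transformers import AutoTokenizer
from src.datasets.utils import encode_fn

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint
tok.add_tokens(["<image>"], special_tokens=True)
img_id = tok.convert_tokens_to_ids("<image>")

example = {"conversation": [{"input": "<image>\nWhat is in the photo?",
                             "output": "A cat on a sofa."}]}
out = encode_fn(example, tok, max_length=256, image_length=4, image_token_idx=img_id)
assert len(out["input_ids"]) == len(out["labels"])
assert out["input_ids"].count(img_id) == 4  # one <image> expanded to image_length ids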
src/models/mar/decoder.py
ADDED
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.vision_transformer import Block
from functools import partial


class MARDecoder(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, vae_stride=16,
                 patch_size=1,
                 # encoder_embed_dim=1024,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4.,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 grad_checkpointing=False,
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE
        self.img_size = img_size
        self.vae_stride = vae_stride

        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w

        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.buffer_size = buffer_size
        # self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.mask_token, std=.02)

        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask):

        # x = self.decoder_embed(x)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # pad mask tokens
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # decoder position embedding
        x = x_after_pad + self.decoder_pos_embed_learned

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x

    def gradient_checkpointing_enable(self):
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.grad_checkpointing = False
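A shape-level sketch of the decoder contract (with the defaults above, seq_len = (256 / 16)^2 = 256 plus a 64-token buffer; an all-zero mask means every position is visible, so the input must already contain buffer + sequence tokens):

import torch
from src.models.mar.decoder import MARDecoder

dec = MARDecoder(img_size=256, vae_stride=16, decoder_depth=2)  # shallow, demo only
B, D = 2, 1024
mask = torch.zeros(B, dec.seq_len)   # 1 = masked/to predict, 0 = visible
x = torch.randn(B, dec.buffer_size + dec.seq_len, D)
out = dec(x, mask)
assert out.shape == (B, dec.seq_len, D)  # buffer stripped, diffusion pos-emb added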
src/models/mar/diffloss.py
ADDED
@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import math

from src.models.mar.diffusion import create_diffusion


class DiffLoss(nn.Module):
    """Diffusion Loss"""
    def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
        super(DiffLoss, self).__init__()
        self.in_channels = target_channels
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2,  # for vlb loss
            z_channels=z_channels,
            num_res_blocks=depth,
            grad_checkpointing=grad_checkpointing
        )

        self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
        self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")

    def forward(self, target, z, mask=None):
        t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
        model_kwargs = dict(c=z)
        loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
        loss = loss_dict["loss"]
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum()
        return loss.mean()

    def sample(self, z, temperature=1.0, cfg=1.0):
        # diffusion loss sampling
        if not cfg == 1.0:
            noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
            noise = torch.cat([noise, noise], dim=0)
            model_kwargs = dict(c=z, cfg_scale=cfg)
            sample_fn = self.net.forward_with_cfg
        else:
            noise = torch.randn(z.shape[0], self.in_channels).cuda()
            model_kwargs = dict(c=z)
            sample_fn = self.net.forward

        sampled_token_latent = self.gen_diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
            temperature=temperature
        )

        return sampled_token_latent


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.data.dtype))
        return t_emb


class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    """

    def __init__(
        self,
        channels
    ):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h


class FinalLayer(nn.Module):
    """
    The final layer adopted from DiT.
    """
    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks per downsample.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        res_blocks = []
        for i in range(num_res_blocks):
            res_blocks.append(ResBlock(
                model_channels,
            ))

        self.res_blocks = nn.ModuleList(res_blocks)
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        # import pdb; pdb.set_trace()
        x = self.input_proj(x.to(self.input_proj.weight.data.dtype))
        t = self.time_embed(t)
        c = self.cond_embed(c.to(self.cond_embed.weight.data.dtype))

        y = t + c

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.res_blocks:
                x = checkpoint(block, x, y)
        else:
            for block in self.res_blocks:
                x = block(x, y)

        return self.final_layer(x, y)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
src/models/mar/diffusion/__init__.py
ADDED
@@ -0,0 +1,47 @@
# Adopted from DiT, which is modified from OpenAI's diffusion repos
# DiT: https://github.com/facebookresearch/DiT/diffusion
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="highres_cosine",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type
        # rescale_timesteps=rescale_timesteps,
    )
src/models/mar/diffusion/diffusion_utils.py
ADDED
@@ -0,0 +1,73 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import torch as th
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
src/models/mar/diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,884 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch as th
|
11 |
+
import enum
|
12 |
+
|
13 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
14 |
+
|
15 |
+
|
16 |
+
def mean_flat(tensor):
|
17 |
+
"""
|
18 |
+
Take the mean over all non-batch dimensions.
|
19 |
+
"""
|
20 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
21 |
+
|
22 |
+
|
23 |
+
class ModelMeanType(enum.Enum):
|
24 |
+
"""
|
25 |
+
Which type of output the model predicts.
|
26 |
+
"""
|
27 |
+
|
28 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
29 |
+
START_X = enum.auto() # the model predicts x_0
|
30 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
31 |
+
|
32 |
+
|
33 |
+
class ModelVarType(enum.Enum):
|
34 |
+
"""
|
35 |
+
What is used as the model's output variance.
|
36 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
37 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
38 |
+
"""
|
39 |
+
|
40 |
+
LEARNED = enum.auto()
|
41 |
+
FIXED_SMALL = enum.auto()
|
42 |
+
FIXED_LARGE = enum.auto()
|
43 |
+
LEARNED_RANGE = enum.auto()
|
44 |
+
|
45 |
+
|
46 |
+
class LossType(enum.Enum):
|
47 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
48 |
+
RESCALED_MSE = (
|
49 |
+
enum.auto()
|
50 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
51 |
+
KL = enum.auto() # use the variational lower-bound
|
52 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
53 |
+
|
54 |
+
def is_vb(self):
|
55 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
56 |
+
|
57 |
+
|
58 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
59 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
60 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
61 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
62 |
+
return betas
|
63 |
+
|
64 |
+
|
65 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
66 |
+
"""
|
67 |
+
This is the deprecated API for creating beta schedules.
|
68 |
+
See get_named_beta_schedule() for the new library of schedules.
|
69 |
+
"""
|
70 |
+
if beta_schedule == "quad":
|
71 |
+
betas = (
|
72 |
+
np.linspace(
|
73 |
+
beta_start ** 0.5,
|
74 |
+
beta_end ** 0.5,
|
75 |
+
num_diffusion_timesteps,
|
76 |
+
dtype=np.float64,
|
77 |
+
)
|
78 |
+
** 2
|
79 |
+
)
|
80 |
+
elif beta_schedule == "linear":
|
81 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
82 |
+
elif beta_schedule == "warmup10":
|
83 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
84 |
+
elif beta_schedule == "warmup50":
|
85 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
86 |
+
elif beta_schedule == "const":
|
87 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
88 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
89 |
+
betas = 1.0 / np.linspace(
|
90 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
raise NotImplementedError(beta_schedule)
|
94 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
95 |
+
return betas
|
96 |
+
|
97 |
+
|
98 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
99 |
+
"""
|
100 |
+
Get a pre-defined beta schedule for the given name.
|
101 |
+
The beta schedule library consists of beta schedules which remain similar
|
102 |
+
in the limit of num_diffusion_timesteps.
|
103 |
+
Beta schedules may be added, but should not be removed or changed once
|
104 |
+
they are committed to maintain backwards compatibility.
|
105 |
+
"""
|
106 |
+
if schedule_name == "linear":
|
107 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
108 |
+
# diffusion steps.
|
109 |
+
scale = 1000 / num_diffusion_timesteps
|
110 |
+
return get_beta_schedule(
|
111 |
+
"linear",
|
112 |
+
beta_start=scale * 0.0001,
|
113 |
+
beta_end=scale * 0.02,
|
114 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
115 |
+
)
|
116 |
+
elif schedule_name == "cosine":
|
117 |
+
return betas_for_alpha_bar(
|
118 |
+
num_diffusion_timesteps,
|
119 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
120 |
+
)
|
121 |
+
elif schedule_name == "highres_cosine":
|
122 |
+
# Custom smoother cosine schedule for high-resolution diffusion
|
123 |
+
return betas_for_alpha_bar(
|
124 |
+
num_diffusion_timesteps,
|
125 |
+
lambda t: math.cos((t + 0.005) / 1.005 * math.pi / 2) ** 2,
|
126 |
+
max_beta=0.2, # conservative to avoid over-noising at high-res
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
130 |
+
|
131 |
+
|
132 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.2):
|
133 |
+
"""
|
134 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
135 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
136 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
137 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
138 |
+
produces the cumulative product of (1-beta) up to that
|
139 |
+
part of the diffusion process.
|
140 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
141 |
+
prevent singularities.
|
142 |
+
"""
|
143 |
+
betas = []
|
144 |
+
for i in range(num_diffusion_timesteps):
|
145 |
+
t1 = i / num_diffusion_timesteps
|
146 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
147 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
148 |
+
return np.array(betas)
|
149 |
+
|
150 |
+
|
151 |
+
class GaussianDiffusion:
|
152 |
+
"""
|
153 |
+
Utilities for training and sampling diffusion models.
|
154 |
+
Original ported from this codebase:
|
155 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
156 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
157 |
+
starting at T and going to 1.
|
158 |
+
"""
|
159 |
+
|
160 |
+
def __init__(
|
161 |
+
self,
|
162 |
+
*,
|
163 |
+
betas,
|
164 |
+
model_mean_type,
|
165 |
+
model_var_type,
|
166 |
+
loss_type
|
167 |
+
):
|
168 |
+
|
169 |
+
self.model_mean_type = model_mean_type
|
170 |
+
self.model_var_type = model_var_type
|
171 |
+
self.loss_type = loss_type
|
172 |
+
|
173 |
+
# Use float64 for accuracy.
|
174 |
+
betas = np.array(betas, dtype=np.float64)
|
175 |
+
self.betas = betas
|
176 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
177 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
178 |
+
|
179 |
+
self.num_timesteps = int(betas.shape[0])
|
180 |
+
|
181 |
+
alphas = 1.0 - betas
|
182 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
183 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
184 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
185 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
186 |
+
|
187 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
188 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
189 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
190 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
191 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
192 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
193 |
+
|
194 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
195 |
+
self.posterior_variance = (
|
196 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
197 |
+
)
|
198 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
199 |
+
self.posterior_log_variance_clipped = np.log(
|
200 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
201 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
202 |
+
|
203 |
+
self.posterior_mean_coef1 = (
|
204 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
205 |
+
)
|
206 |
+
self.posterior_mean_coef2 = (
|
207 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
208 |
+
)
|
209 |
+
|
210 |
+
def q_mean_variance(self, x_start, t):
|
211 |
+
"""
|
212 |
+
Get the distribution q(x_t | x_0).
|
213 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
214 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
215 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
216 |
+
"""
|
217 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
218 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
219 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
220 |
+
return mean, variance, log_variance
|
221 |
+
|
222 |
+
def q_sample(self, x_start, t, noise=None):
|
223 |
+
"""
|
224 |
+
Diffuse the data for a given number of diffusion steps.
|
225 |
+
In other words, sample from q(x_t | x_0).
|
226 |
+
:param x_start: the initial data batch.
|
227 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
228 |
+
:param noise: if specified, the split-out normal noise.
|
229 |
+
:return: A noisy version of x_start.
|
230 |
+
"""
|
231 |
+
if noise is None:
|
232 |
+
noise = th.randn_like(x_start)
|
233 |
+
assert noise.shape == x_start.shape
|
234 |
+
return (
|
235 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
236 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
237 |
+
)
|
238 |
+
|
239 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
240 |
+
"""
|
241 |
+
Compute the mean and variance of the diffusion posterior:
|
242 |
+
q(x_{t-1} | x_t, x_0)
|
243 |
+
"""
|
244 |
+
assert x_start.shape == x_t.shape
|
245 |
+
posterior_mean = (
|
246 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
247 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
248 |
+
)
|
249 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
250 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
251 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
252 |
+
)
|
253 |
+
assert (
|
254 |
+
posterior_mean.shape[0]
|
255 |
+
== posterior_variance.shape[0]
|
256 |
+
== posterior_log_variance_clipped.shape[0]
|
257 |
+
== x_start.shape[0]
|
258 |
+
)
|
259 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
260 |
+
|
261 |
+
    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.START_X:
            pred_xstart = process_xstart(model_output)
        else:
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }
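
Editor's note: in the LEARNED_RANGE branch above, the network emits a value in [-1, 1] per pixel that is mapped to a log-variance between two fixed endpoints. A toy sketch with made-up endpoint values (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file; the endpoint values are made up.
import torch

min_log = torch.tensor(-9.0)   # stand-in for the log posterior variance (the floor)
max_log = torch.tensor(-4.0)   # stand-in for log beta_t (the ceiling)
v = torch.tensor(0.5)          # raw per-pixel network output in [-1, 1]
frac = (v + 1) / 2
log_var = frac * max_log + (1 - frac) * min_log
var = log_var.exp()
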
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **model_kwargs)
        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.
        See condition_mean() for details on cond_fn.
        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al. (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out
    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        temperature=1.0,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param temperature: temperature scaling during Diff Loss sampling.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # scale the noise by temperature
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
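
Editor's note: temperature rescales only the injected noise, not the predicted mean, so temperature=0 collapses the step to a deterministic update. A toy sketch (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import torch

mean = torch.zeros(4)
log_variance = torch.full((4,), -2.0)
noise = torch.randn(4)
for temperature in (0.0, 0.7, 1.0):
    sample = mean + (0.5 * log_variance).exp() * noise * temperature
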
    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temperature=1.0,
    ):
        """
        Generate samples from the model.
        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param temperature: temperature scaling during Diff Loss sampling.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            temperature=temperature,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temperature=1.0,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.
        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        assert isinstance(shape, (tuple, list))
        # honor the documented `device` argument instead of hard-coding .cuda()
        if device is None:
            device = next(model.parameters()).device
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=img.device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    temperature=temperature,
                )
                yield out
                img = out["sample"]
    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.
        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
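
Editor's note: eta interpolates between the deterministic DDIM update (eta=0, so sigma=0) and an ancestral, DDPM-like step (eta=1). A toy sketch of the sigma formula above (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file; alpha values are made up.
import torch

alpha_bar = torch.tensor(0.5)
alpha_bar_prev = torch.tensor(0.7)
for eta in (0.0, 0.5, 1.0):
    sigma = (eta
             * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
             * torch.sqrt(1 - alpha_bar / alpha_bar_prev))  # 0 when eta == 0
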
    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Generate samples from the model using DDIM.
        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.
        Same usage as p_sample_loop_progressive().
        """
        assert isinstance(shape, (tuple, list))
        # honor the documented `device` argument instead of hard-coding .cuda()
        if device is None:
            device = next(model.parameters()).device
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=img.device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]
    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.
        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.
        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}
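
Editor's note: dividing a KL measured in nats by ln(2), as above, converts it to bits; that is what the "bpd" (bits-per-dim) naming refers to. A one-liner (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import numpy as np
kl_nats = np.log(4.0)               # a KL of ln(4) nats ...
kl_bits = kl_nats / np.log(2.0)     # ... is exactly 2 bits
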
    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, t, **model_kwargs)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                terms["loss"] = terms["mse"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms
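
Editor's note: a minimal training-step sketch, not part of the uploaded file; it assumes the keyword-only GaussianDiffusion constructor defined earlier in this file and uses a constant stand-in network instead of a real epsilon predictor.

# Illustrative sketch, not part of the uploaded file. Constructor signature is an
# assumption based on the ADM-style __init__ earlier in this file.
import numpy as np
import torch as th

diffusion = GaussianDiffusion(
    betas=np.linspace(1e-4, 0.02, 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)
net = lambda x, t: th.zeros_like(x)   # constant stand-in for an epsilon predictor
x_start = th.randn(8, 16)
t = th.randint(0, diffusion.num_timesteps, (8,))
loss = diffusion.training_losses(net, x_start, t)["loss"].mean()
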
    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.
    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)
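
Editor's note: _extract_into_tensor gathers one scalar per batch element and right-pads singleton dims so the result broadcasts over x. An equivalent standalone sketch (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import numpy as np
import torch

arr = np.linspace(0.0, 1.0, 1000)
t = torch.tensor([3, 999])
vals = torch.from_numpy(arr)[t].float().view(-1, 1, 1, 1)  # shape [2, 1, 1, 1]
x = torch.randn(2, 3, 8, 8)
scaled = vals * x  # broadcasts across channels and spatial dims
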
src/models/mar/diffusion/respace.py
ADDED
@@ -0,0 +1,129 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion

def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
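
Editor's note: two representative calls (not part of the uploaded file) showing both respacing modes for a 1000-step process:

# Illustrative sketch, not part of the uploaded file.
space_timesteps(1000, [10])      # 10 evenly spread steps: {0, 111, 222, ..., 999}
space_timesteps(1000, "ddim25")  # fixed DDIM stride of 40: {0, 40, 80, ..., 960}
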
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
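
Editor's note: the respaced betas are chosen so the cumulative alpha product at each retained step matches the base process exactly; the ratios telescope. A sketch verifying this invariant (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import numpy as np

base_betas = np.linspace(1e-4, 0.02, 1000)
base_acp = np.cumprod(1.0 - base_betas)
keep = sorted(space_timesteps(1000, "ddim25"))
last, new_betas = 1.0, []
for i in keep:
    new_betas.append(1 - base_acp[i] / last)
    last = base_acp[i]
assert np.allclose(np.cumprod(1.0 - np.array(new_betas)), base_acp[keep])
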
class _WrappedModel:
    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        # self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        # if self.rescale_timesteps:
        #     new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
src/models/mar/engine_mar.py
ADDED
@@ -0,0 +1,99 @@
import torch
import src.models.mar.misc as misc
import torch_fidelity
import shutil
import cv2
import numpy as np
import os
import time

def torch_evaluate(model, args):
    model.eval()
    num_steps = args.num_images // (args.batch_size * misc.get_world_size()) + 1
    save_folder = os.path.join(args.output_dir, "ariter{}-temp{}-{}cfg{}-image{}".format(
        args.num_iter, args.temperature, args.cfg_schedule, args.cfg, args.num_images))

    print("Save to:", save_folder)
    if misc.get_rank() == 0:
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

    class_num = args.class_num
    assert args.num_images % class_num == 0  # number of images per class must be the same
    class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
    class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
    world_size = misc.get_world_size()
    local_rank = misc.get_rank()
    used_time = 0
    gen_img_cnt = 0

    for i in range(num_steps):
        print("Generation step {}/{}".format(i, num_steps))

        labels_gen = class_label_gen_world[world_size * args.batch_size * i + local_rank * args.batch_size:
                                           world_size * args.batch_size * i + (local_rank + 1) * args.batch_size]
        labels_gen = torch.Tensor(labels_gen).long().cuda()

        torch.cuda.synchronize()
        start_time = time.time()

        # generation
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                # sampled_images = model.sample_official(bsz=args.batch_size, num_iter=args.num_iter, cfg=args.cfg,
                #                                        cfg_schedule=args.cfg_schedule, labels=labels_gen,
                #                                        temperature=args.temperature)
                if args.cfg != 1.0:
                    labels_gen = torch.cat([
                        labels_gen, torch.full_like(labels_gen, fill_value=-1)])
                sampled_images = model.sample(labels_gen,
                                              num_iter=args.num_iter, cfg=args.cfg, cfg_schedule=args.cfg_schedule,
                                              temperature=args.temperature, progress=False)

        # measure speed after the first generation batch
        if i >= 1:
            torch.cuda.synchronize()
            used_time += time.time() - start_time
            gen_img_cnt += args.batch_size
            print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))

        torch.distributed.barrier()
        sampled_images = sampled_images.detach().cpu()
        sampled_images = (sampled_images + 1) / 2

        # distributed save
        for b_id in range(sampled_images.size(0)):
            img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
            if img_id >= args.num_images:
                break
            gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
            gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
            cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)

    torch.distributed.barrier()
    time.sleep(10)
    if misc.get_rank() == 0:
        input2 = None
        fid_statistics_file = 'fid_stats/adm_in256_stats.npz'
        metrics_dict = torch_fidelity.calculate_metrics(
            input1=save_folder,
            input2=input2,
            fid_statistics_file=fid_statistics_file,
            cuda=True,
            isc=True,
            fid=True,
            kid=False,
            prc=False,
            verbose=True,
        )
        fid = metrics_dict['frechet_inception_distance']
        inception_score = metrics_dict['inception_score_mean']
        print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
        # remove temporary saving folder
        shutil.rmtree(save_folder)

    torch.distributed.barrier()
    time.sleep(10)
src/models/mar/mar.py
ADDED
@@ -0,0 +1,477 @@
from functools import partial

import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from timm.models.vision_transformer import Block

from .diffloss import DiffLoss

def mask_by_order(mask_len, order, bsz, seq_len):
    masking = torch.zeros(bsz, seq_len).to(order.device)
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
                            src=torch.ones(bsz, seq_len).to(order.device)).bool()
    return masking
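
Editor's note: mask_by_order marks the first mask_len positions of each random generation order as masked (True). A tiny worked example (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import torch

order = torch.tensor([[2, 0, 3, 1]])
masking = torch.zeros(1, 4)
masking = torch.scatter(masking, -1, order[:, :2], torch.ones(1, 4)).bool()
# masking -> tensor([[ True, False,  True, False]]): positions 2 and 0 are masked
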
class MAR(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim

        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2
        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # Class Embedding
        self.num_classes = class_num
        self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
        self.label_drop_prob = label_drop_prob
        # Fake class embedding for CFG's unconditional generation
        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))

        # --------------------------------------------------------------------------
        # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

        # --------------------------------------------------------------------------
        # MAR encoder specifics
        self.encoder_embed_dim = encoder_embed_dim
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.buffer_size = buffer_size
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))

        self.encoder_blocks = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            grad_checkpointing=self.grad_checkpointing
        )
        self.diffusion_batch_mul = diffusion_batch_mul
    def get_encoder_pos_embed(self, h, w):
        if h == self.seq_h and w == self.seq_w:
            return self.encoder_pos_embed_learned
        buffer_pe, image_pe = self.encoder_pos_embed_learned.split(
            [self.buffer_size, self.seq_len], dim=1)
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')

        return torch.cat([buffer_pe, image_pe], dim=1)

    def get_decoder_pos_embed(self, h, w):
        if h == self.seq_h and w == self.seq_w:
            return self.decoder_pos_embed_learned
        buffer_pe, image_pe = self.decoder_pos_embed_learned.split(
            [self.buffer_size, self.seq_len], dim=1)
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')

        return torch.cat([buffer_pe, image_pe], dim=1)

    def get_diffusion_pos_embed(self, h, w):
        if h == self.seq_h and w == self.seq_w:
            return self.diffusion_pos_embed_learned
        image_pe = self.diffusion_pos_embed_learned
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')

        return image_pe
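
Editor's note: the three getters above resize the learned positional-embedding grid when the token resolution differs from the trained one. A standalone sketch of that resize (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import torch
import torch.nn.functional as F
from einops import rearrange

pe = torch.randn(1, 16 * 16, 1024)                       # trained at a 16x16 token grid
pe = rearrange(pe, 'b (h w) c -> b c h w', h=16, w=16)
pe = F.interpolate(pe, size=(24, 24), mode='bilinear')   # target 24x24 token grid
pe = rearrange(pe, 'b c h w -> b (h w) c')               # [1, 576, 1024]
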
    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.class_emb.weight, std=.02)
        torch.nn.init.normal_(self.fake_latent, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @property
    def device(self):
        return self.fake_latent.data.device

    @property
    def dtype(self):
        return self.fake_latent.data.dtype
    def patchify(self, x):
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]
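
Editor's note: with patch_size=1 the pair above is a lossless reshape from [n, c, h, w] latents to [n, h*w, c] tokens and back. A standalone round-trip check (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import torch

z = torch.randn(2, 16, 16, 16)                                     # [n, c, h, w]
tokens = torch.einsum('nchpwq->nhwcpq',
                      z.reshape(2, 16, 16, 1, 16, 1)).reshape(2, 256, 16)
back = torch.einsum('nhwcpq->nchpwq',
                    tokens.reshape(2, 16, 16, 16, 1, 1)).reshape(2, 16, 16, 16)
assert torch.equal(back, z)
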
    def sample_orders(self, bsz, seq_len=None):
        if seq_len is None:
            seq_len = self.seq_len
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        orders = torch.Tensor(np.array(orders)).to(self.device).long()
        return orders

    def random_masking(self, x, orders):
        # generate token mask
        bsz, seq_len, embed_dim = x.shape
        assert seq_len == orders.shape[1]
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask
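
Editor's note: the training mask ratio is drawn from a Gaussian centered at 1.0 (std 0.25), truncated on the left at mask_ratio_min, matching the generator built in __init__. A standalone sketch (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import scipy.stats as stats

gen = stats.truncnorm((0.7 - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
ratios = gen.rvs(5)  # five samples, each guaranteed to lie in [0.7, 1.0]
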
    def forward_mae_encoder(self, x, mask, class_embedding, image_shape=None):
        x = x.to(self.dtype)
        x = self.z_proj(x)
        bsz, seq_len, embed_dim = x.shape

        # concat buffer
        x = torch.cat([x.new_zeros(bsz, self.buffer_size, embed_dim), x], dim=1)
        mask_with_buffer = torch.cat([mask.new_zeros(x.size(0), self.buffer_size), mask], dim=1)

        # random drop class embedding during training
        # if self.training:
        #     drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
        #     drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(self.device).to(x.dtype)
        #     class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding

        x[:, :self.buffer_size] = class_embedding.view(bsz, -1, embed_dim)

        # encoder position embedding
        if image_shape is None:
            x = x + self.encoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_encoder_pos_embed(h=h, w=w)
        x = self.z_proj_ln(x)

        # dropping
        x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.encoder_blocks:
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.encoder_blocks:
                x = block(x)
        x = self.encoder_norm(x)

        return x
    def forward_mae_decoder(self, x, mask, image_shape=None, x_con=None):
        bsz, seq_len = mask.shape

        x = self.decoder_embed(x)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # pad mask tokens
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)

        if x_con is not None:
            x_after_pad = self.decoder_embed(x_con)
        else:
            x_after_pad = mask_tokens.clone()
            x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # decoder position embedding
        if image_shape is None:
            x = x_after_pad + self.decoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x_after_pad + self.get_decoder_pos_embed(h=h, w=w)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        x = x[:, self.buffer_size:]
        if image_shape is None:
            x = x + self.diffusion_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_diffusion_pos_embed(h=h, w=w)
        return x
    def mae_decoder_prepare(self, x, mask):
        x = self.decoder_embed(x)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # pad mask tokens
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # decoder position embedding
        x = x_after_pad + self.decoder_pos_embed_learned

        return x

    def mae_decoder_forward(self, x):
        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x
    def forward_loss(self, z, target, mask):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss

    def forward(self, imgs, labels):

        # class embed
        class_embedding = self.class_emb(labels)

        # patchify and mask (drop) tokens
        x = self.patchify(imgs)
        gt_latents = x.clone().detach()
        orders = self.sample_orders(bsz=x.size(0))
        mask = self.random_masking(x, orders)

        # mae encoder
        x = self.forward_mae_encoder(x, mask, class_embedding)

        # mae decoder
        z = self.forward_mae_decoder(x, mask)

        # diffloss
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

        return loss
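
Editor's note: diffusion_batch_mul repeats every token so each latent is trained against several sampled diffusion timesteps per update, cheaply enlarging the effective diff-loss batch. The reshape it performs (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import torch

bsz, seq_len, dim, mul = 2, 4, 8, 4
z = torch.randn(bsz, seq_len, dim)
z_rep = z.reshape(bsz * seq_len, -1).repeat(mul, 1)  # [2*4*4, 8] = [32, 8]
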
    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        # init and sample generation orders
        mask = torch.ones(bsz, self.seq_len).to(self.device)
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(self.device)
        orders = self.sample_orders(bsz)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)
        # generate latents
        for step in indices:
            cur_tokens = tokens.clone()

            # class embedding and CFG
            if labels is not None:
                class_embedding = self.class_emb(labels)
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
            if not cfg == 1.0:
                tokens = torch.cat([tokens, tokens], dim=0)
                class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
                mask = torch.cat([mask, mask], dim=0)

            # mae encoder
            x = self.forward_mae_encoder(tokens, mask.to(self.dtype), class_embedding)

            # mae decoder
            z = self.forward_mae_decoder(x, mask.to(self.dtype))

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(self.device)
            # masks out at least one for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
            # get masking for next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
            # sample token latents for this step
            z = z[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
            if not cfg == 1.0:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()

        # unpatchify
        tokens = self.unpatchify(tokens)
        return tokens

    def gradient_checkpointing_enable(self):
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.grad_checkpointing = False
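
Editor's note: the cosine schedule above (following MaskGIT/MAGE) keeps most tokens masked early and unmasks quickly toward the end. A standalone sketch of the per-step mask budget (not part of the uploaded file):

# Illustrative sketch, not part of the uploaded file.
import math
import numpy as np

seq_len, num_iter = 256, 8
for step in range(num_iter):
    mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
    mask_len = int(np.floor(seq_len * mask_ratio))  # 251, 236, ..., 0
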
def mar_base(**kwargs):
    model = MAR(
        encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
        decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mar_large(**kwargs):
    model = MAR(
        encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
        decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mar_huge(**kwargs):
    model = MAR(
        encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
        decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mar_max(**kwargs):
    model = MAR(
        encoder_embed_dim=1536, encoder_depth=24, encoder_num_heads=16,
        decoder_embed_dim=1536, decoder_depth=24, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
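
Editor's note: a minimal usage sketch of the size presets, not part of the uploaded file; it assumes the default config (256px images, stride-16 VAE with 16 latent channels, i.e. a 16x16 token grid) and random stand-in latents.

# Illustrative sketch, not part of the uploaded file.
import torch

model = mar_base(img_size=256, class_num=1000)
latents = torch.randn(2, 16, 16, 16)      # [n, c, h, w] VAE latents, random stand-ins
labels = torch.randint(0, 1000, (2,))
loss = model(latents, labels)             # one training forward pass
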
src/models/mar/misc.py
ADDED
@@ -0,0 +1,340 @@
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path

import torch
import torch.distributed as dist

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import inf
else:
    from torch import inf
import copy


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
            no_decay.append(param)  # no weight decay on bias, norm and diffloss
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None, epoch_name=None):
    if epoch_name is None:
        epoch_name = str(epoch)
    output_dir = Path(args.output_dir)
    checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)

    # ema
    if ema_params is not None:
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
    else:
        ema_state_dict = None

    to_save = {
        'model': model_without_ddp.state_dict(),
        'model_ema': ema_state_dict,
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(to_save, checkpoint_path)


def all_reduce_mean(x):
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x
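A minimal sketch of how these helpers compose in a single training step; the toy model, learning rate, and clip value are placeholders, not settings from this repo. GradScaler disables itself (with a warning) when CUDA is absent, so the sketch also runs on CPU:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 1))
# Biases, norm weights, and any 'diffloss' params get weight_decay=0.
param_groups = add_weight_decay(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
loss_scaler = NativeScalerWithGradNormCount()

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y)
# Scales the loss, backprops, unscales, clips, steps, and returns the grad norm.
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters(), update_grad=True)
optimizer.zero_grad()
print(f"grad norm: {float(grad_norm):.4f}")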
src/models/mar/vae.py
ADDED
@@ -0,0 +1,525 @@
# Adopted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion
import torch
import torch.nn as nn

import numpy as np


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
        z_channels=16,
        double_z=True,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        attn_resolutions=(),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
        z_channels=16,
        give_pre_end=False,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


class AutoencoderKL(nn.Module):
    def __init__(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None):
        super().__init__()
        self.encoder = Encoder(ch_mult=ch_mult, z_channels=embed_dim)
        self.decoder = Decoder(ch_mult=ch_mult, z_channels=embed_dim)
        self.use_variational = use_variational
        mult = 2 if self.use_variational else 1
        self.quant_conv = torch.nn.Conv2d(2 * embed_dim, mult * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, embed_dim, 1)
        self.embed_dim = embed_dim
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def init_from_ckpt(self, path):
        sd = torch.load(path, map_location="cpu")["model"]
        msg = self.load_state_dict(sd, strict=False)
        print("Loading pre-trained KL-VAE")
        print("Missing keys:")
        print(msg.missing_keys)
        print("Unexpected keys:")
        print(msg.unexpected_keys)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        if not self.use_variational:
            moments = torch.cat((moments, torch.ones_like(moments)), 1)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, inputs, disable=True, train=True, optimizer_idx=0):
        # Note: training_step / validation_step are not defined in this file;
        # they are expected to be provided by a subclass or training wrapper.
        if train:
            return self.training_step(inputs, disable, optimizer_idx)
        else:
            return self.validation_step(inputs, disable)


if __name__ == "__main__":
    from PIL import Image
    import numpy as np
    import torch.nn.functional as F

    vae = AutoencoderKL(
        embed_dim=16, ch_mult=(1, 1, 2, 2, 4),
        ckpt_path='checkpoints/kl16.ckpt')

    image = Image.open('data/ILSVRC2012_val_00023344.JPEG')
    image = torch.from_numpy(np.array(image))
    image = image.permute(2, 0, 1).float() / 255
    image = 2 * image - 1

    x = F.interpolate(image[None], size=(256, 256), mode='bilinear', align_corners=True)

    print(x.shape)

    with torch.no_grad():
        z = vae.encode(x).sample()
        print(z.shape)
        x_rec = vae.decode(z)[0]

    x_rec = (x_rec + 1.0) * 255 / 2
    x_rec = torch.clamp(x_rec, min=0, max=255)
    x_rec = x_rec.to(torch.uint8)

    x_rec = x_rec.permute(1, 2, 0)

    x_rec = Image.fromarray(x_rec.numpy())

    x_rec.show()
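The DiagonalGaussianDistribution above is what AutoencoderKL.encode returns. A short sketch of its interface, assuming an encoder output with mean and log-variance stacked along the channel dimension (the shapes are illustrative):

import torch

moments = torch.randn(2, 32, 16, 16)        # 2 * embed_dim channels for embed_dim=16
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                       # mean + std * eps -> [2, 16, 16, 16]
kl = posterior.kl()                          # KL to a standard normal, one value per image
nll = posterior.nll(z)                       # Gaussian negative log-likelihood of a sample
print(z.shape, kl.shape, nll.shape)          # [2, 16, 16, 16], [2], [2]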
src/models/skywork_unipic_dev.py
ADDED
@@ -0,0 +1,645 @@
import torch
import torch.nn.functional as F
from torch.nn.modules.module import T
from mmengine.model import BaseModel
from torch.autograd.function import Function
from mmengine.logging import print_log
from xtuner.model.utils import guess_load_checkpoint
import os
# from .skywork_unipic import SkyworkUnipic
from .skywork_unipic_siglip import SkyworkUnipic
from xtuner.utils import IMAGE_TOKEN_INDEX
import torch.distributed as dist
import json
from einops import rearrange


def _load_state_dict_with_ds(module_to_load, state_dict, start_prefix="", strict=True):
    try:
        import deepspeed
    except ImportError:
        raise ImportError("deepspeed is not installed. Please install deepspeed to use this feature.")

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []
    missing_keys = []
    unexpected_keys = []

    def load(module: torch.nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            # In sharded models, each shard has only part of the full state_dict, so only gather
            # parameters that are in the current state_dict.
            named_parameters = dict(
                module.named_parameters(prefix=prefix[:-1], recurse=False)
            )
            params_to_gather = [
                named_parameters[k]
                for k in state_dict.keys()
                if k in named_parameters
            ]
            if len(params_to_gather) > 0:
                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(
                    params_to_gather, modifier_rank=0
                ):
                    if deepspeed.comm.get_rank() == 0:
                        module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(module_to_load, state_dict, start_prefix)
    if len(missing_keys) > 0:
        print_log(f"[WARNING] Missing keys: {missing_keys}")
    if len(unexpected_keys) > 0:
        print_log(f"[WARNING] Unexpected keys: {unexpected_keys}")
    if error_msgs:
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(
                module_to_load.__class__.__name__, "\n\t".join(error_msgs)
            )
        )


class _ScaleGradient(Function):
    @staticmethod
    def forward(ctx, input, scale):
        ctx.scale = scale
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None
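# A quick standalone check of the gradient-scaling behaviour above (a sketch
# for illustration, not part of the training path; `0.1` mirrors the default
# `grad_scale` used below):
#
#     x = torch.ones(3, requires_grad=True)
#     y = _ScaleGradient.apply(x, 0.1)   # forward is the identity
#     y.sum().backward()
#     print(x.grad)                      # tensor([0.1000, 0.1000, 0.1000])
#
# Forward leaves activations untouched; backward multiplies incoming gradients
# by `scale`, damping how strongly a given loss branch updates shared modules.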
89 |
+
class SkyworkUnipicDev(SkyworkUnipic, BaseModel):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
grad_scale=0.1,
|
93 |
+
loss_weights=None,
|
94 |
+
pretrained_pth=None,
|
95 |
+
mar_path=None,
|
96 |
+
siglip_proj_path=None,
|
97 |
+
freeze_llm=False,
|
98 |
+
freeze_mar=False,
|
99 |
+
freeze_mar_decoder=False,
|
100 |
+
freeze_siglip_proj=False,
|
101 |
+
gradient_checkpointing=True,
|
102 |
+
**kwargs,
|
103 |
+
):
|
104 |
+
if loss_weights is None:
|
105 |
+
loss_weights = {
|
106 |
+
"image2text": 0.01,
|
107 |
+
"text2image": 1.0,
|
108 |
+
"image_edit": 1.0,
|
109 |
+
"contrastive": 0.1,
|
110 |
+
}
|
111 |
+
super().__init__(**kwargs)
|
112 |
+
|
113 |
+
self.grad_scale = grad_scale
|
114 |
+
self.loss_weights = loss_weights
|
115 |
+
self.pretrained_pth = pretrained_pth
|
116 |
+
self.mar_path = mar_path
|
117 |
+
self.siglip_proj_path = siglip_proj_path
|
118 |
+
|
119 |
+
# 判断分布式 rank
|
120 |
+
rank = dist.get_rank() if dist.is_initialized() else 0
|
121 |
+
|
122 |
+
# === 加载预训练权重 ===
|
123 |
+
if pretrained_pth:
|
124 |
+
self.load_hf_weights(
|
125 |
+
skywork_unipic_ckpt=pretrained_pth,
|
126 |
+
siglip_proj_path=siglip_proj_path,
|
127 |
+
mar_path=mar_path
|
128 |
+
)
|
129 |
+
|
130 |
+
# === 冻结模块 ===
|
131 |
+
if freeze_llm:
|
132 |
+
self.llm.requires_grad_(False)
|
133 |
+
if freeze_mar:
|
134 |
+
self.mar.requires_grad_(False)
|
135 |
+
if freeze_mar_decoder:
|
136 |
+
# 仅冻结 MAR 解码器部件
|
137 |
+
for param in self.mar.decoder_embed.parameters():
|
138 |
+
param.requires_grad = False
|
139 |
+
for block in self.mar.decoder_blocks:
|
140 |
+
for param in block.parameters():
|
141 |
+
param.requires_grad = False
|
142 |
+
for param in self.mar.decoder_norm.parameters():
|
143 |
+
param.requires_grad = False
|
144 |
+
if isinstance(self.mar.decoder_pos_embed_learned, torch.nn.Parameter):
|
145 |
+
self.mar.decoder_pos_embed_learned.requires_grad = False
|
146 |
+
if isinstance(self.mar.diffusion_pos_embed_learned, torch.nn.Parameter):
|
147 |
+
self.mar.diffusion_pos_embed_learned.requires_grad = False
|
148 |
+
if freeze_siglip_proj:
|
149 |
+
self.siglip2_proj.requires_grad_(False)
|
150 |
+
|
151 |
+
# === 梯度检查点 ===
|
152 |
+
if gradient_checkpointing:
|
153 |
+
self.gradient_checkpointing_enable()
|
154 |
+
else:
|
155 |
+
self.gradient_checkpointing_disable()
|
156 |
+
|
157 |
+
|
158 |
+
def load_hf_weights(self,
|
159 |
+
skywork_unipic_ckpt: str = None,
|
160 |
+
siglip_proj_path: str = None,
|
161 |
+
mar_path: str = None):
|
162 |
+
"""统一加载 SkyworkUnipic(可选) + SigLIP2 + MAR"""
|
163 |
+
device = "cpu"
|
164 |
+
state_dict = {}
|
165 |
+
|
166 |
+
def _print_load_result(module_name, missing, unexpected):
|
167 |
+
print_log(f"[INFO] Loaded {module_name}. missing={len(missing)}, unexpected={len(unexpected)}")
|
168 |
+
|
169 |
+
# === SkyworkUnipic 主模型(可选) ===
|
170 |
+
if skywork_unipic_ckpt:
|
171 |
+
print_log(f"[INFO] Loading SkyworkUnipic checkpoint from: {skywork_unipic_ckpt}")
|
172 |
+
# 加载 checkpoint(支持文件或目录)
|
173 |
+
if os.path.isfile(skywork_unipic_ckpt):
|
174 |
+
skywork_unipic_state = torch.load(skywork_unipic_ckpt, map_location=device)
|
175 |
+
else:
|
176 |
+
idx = os.path.join(skywork_unipic_ckpt, "pytorch_model.bin.index.json")
|
177 |
+
if os.path.exists(idx):
|
178 |
+
with open(idx, 'r') as f:
|
179 |
+
index = json.load(f)
|
180 |
+
skywork_unipic_state = {}
|
181 |
+
for shard in sorted(set(index["weight_map"].values())):
|
182 |
+
shard_path = os.path.join(skywork_unipic_ckpt, shard)
|
183 |
+
skywork_unipic_state.update(torch.load(shard_path, map_location=device))
|
184 |
+
else:
|
185 |
+
bin_path = os.path.join(skywork_unipic_ckpt, "pytorch_model.bin")
|
186 |
+
skywork_unipic_state = torch.load(bin_path, map_location=device)
|
187 |
+
|
188 |
+
# 删除 SkyworkUnipic checkpoint 中可能带的 MAR pos_embed,避免覆盖
|
189 |
+
|
190 |
+
# for key in [
|
191 |
+
# "mar.encoder_pos_embed_learned",
|
192 |
+
# "mar.decoder_pos_embed_learned",
|
193 |
+
# "mar.diffusion_pos_embed_learned"
|
194 |
+
# ]:
|
195 |
+
# if key in skywork_unipic_state:
|
196 |
+
# print_log(f"[INFO] Dropping `{key}` from SkyworkUnipic checkpoint")
|
197 |
+
# del skywork_unipic_state[key]
|
198 |
+
model_dict = self.state_dict()
|
199 |
+
|
200 |
+
filtered_checkpoint = {}
|
201 |
+
shape_mismatch_keys = []
|
202 |
+
|
203 |
+
for k, v in skywork_unipic_state.items():
|
204 |
+
if k in model_dict:
|
205 |
+
if v.shape == model_dict[k].shape:
|
206 |
+
filtered_checkpoint[k] = v
|
207 |
+
else:
|
208 |
+
shape_mismatch_keys.append((k, v.shape, model_dict[k].shape))
|
209 |
+
|
210 |
+
missing, unexpected = self.load_state_dict(filtered_checkpoint, strict=False)
|
211 |
+
# 打印不匹配的 key 及其形状
|
212 |
+
if shape_mismatch_keys:
|
213 |
+
print("以下 key 因形状不匹配被跳过:")
|
214 |
+
for k, checkpoint_shape, model_shape in shape_mismatch_keys:
|
215 |
+
print(f" - {k}:")
|
216 |
+
print(f" checkpoint 中的形状: {checkpoint_shape}")
|
217 |
+
print(f" 当前模型的形状: {model_shape}")
|
218 |
+
else:
|
219 |
+
print("所有 key 形状匹配,未跳过任何参数")
|
220 |
+
|
221 |
+
# missing, unexpected = self.load_state_dict(skywork_unipic_state, strict=False)
|
222 |
+
_print_load_result("SkyworkUnipic", missing, unexpected)
|
223 |
+
else:
|
224 |
+
print_log("[INFO] Skipping SkyworkUnipic checkpoint loading")
|
225 |
+
|
226 |
+
# === SigLIP2 权重 ===
|
227 |
+
if siglip_proj_path:
|
228 |
+
print_log(f"[INFO] Loading SigLIP2 weights from: {siglip_proj_path}")
|
229 |
+
siglip_state = torch.load(
|
230 |
+
siglip_proj_path, map_location="cpu", weights_only=False
|
231 |
+
)
|
232 |
+
# 如果 checkpoint 是 {"model": {...}}
|
233 |
+
if isinstance(siglip_state, dict) and "model" in siglip_state:
|
234 |
+
siglip_state = siglip_state["model"]
|
235 |
+
missing, unexpected = self.siglip2_proj.load_state_dict(
|
236 |
+
siglip_state, strict=False
|
237 |
+
)
|
238 |
+
_print_load_result("SigLIP2", missing, unexpected)
|
239 |
+
else:
|
240 |
+
print_log("[INFO] No SigLIP2 checkpoint provided, skipping")
|
241 |
+
|
242 |
+
# === MAR 权重 ===
|
243 |
+
if mar_path:
|
244 |
+
print_log(f"[INFO] Loading MAR weights from: {mar_path}")
|
245 |
+
mar_state = torch.load(mar_path, map_location="cpu", weights_only=False)
|
246 |
+
# 兼容 model_ema or model dict
|
247 |
+
|
248 |
+
if isinstance(mar_state, dict) and "model_ema" in mar_state:
|
249 |
+
mar_state = mar_state["model_ema"]
|
250 |
+
|
251 |
+
elif isinstance(mar_state, dict) and "model" in mar_state:
|
252 |
+
mar_state = mar_state["model"]
|
253 |
+
|
254 |
+
|
255 |
+
# 如果 key 带有 “mar.” 前缀,批量去掉
|
256 |
+
if any(k.startswith("mar.") for k in mar_state):
|
257 |
+
filtered_mar = {
|
258 |
+
k.replace("mar.", "", 1): v
|
259 |
+
for k, v in mar_state.items()
|
260 |
+
if k.startswith("mar.")
|
261 |
+
}
|
262 |
+
else:
|
263 |
+
filtered_mar = mar_state
|
264 |
+
|
265 |
+
missing, unexpected = self.mar.load_state_dict(
|
266 |
+
filtered_mar, strict=False
|
267 |
+
)
|
268 |
+
_print_load_result("MAR", missing, unexpected)
|
269 |
+
else:
|
270 |
+
print_log("[INFO] No MAR checkpoint provided, skipping")
|
271 |
+
|
272 |
+
return state_dict
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
def gradient_checkpointing_disable(self):
|
277 |
+
self.llm.gradient_checkpointing_disable()
|
278 |
+
self.mar.gradient_checkpointing_disable()
|
279 |
+
|
280 |
+
def gradient_checkpointing_enable(self):
|
281 |
+
self.llm.gradient_checkpointing_enable()
|
282 |
+
self.mar.gradient_checkpointing_enable()
|
283 |
+
|
284 |
+
def state_dict(self, *args, **kwargs):
|
285 |
+
state_dict = super().state_dict(*args, **kwargs)
|
286 |
+
state_dict = {k: v for k, v in state_dict.items()
|
287 |
+
if 'vae.' not in k}
|
288 |
+
|
289 |
+
return state_dict
|
290 |
+
|
291 |
+
def train(self: T, mode: bool = True) -> T:
|
292 |
+
super().train(mode=mode)
|
293 |
+
self.vae.train(mode=False)
|
294 |
+
return self
|
295 |
+
|
296 |
+
def text2image_loss(self, data_dict):
|
297 |
+
x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
|
298 |
+
x = self.encode(x) # b m n c
|
299 |
+
b, m, n, _ = x.shape
|
300 |
+
gt_latents = x.clone().detach().view(b, m*n, -1)
|
301 |
+
|
302 |
+
orders = self.mar.sample_orders(bsz=b, seq_len=m*n)
|
303 |
+
mask = self.mar.random_masking(x.flatten(1, 2), orders)
|
304 |
+
|
305 |
+
input_ids = data_dict['input_ids'].to(self.device)
|
306 |
+
attention_mask = data_dict['attention_mask'].to(self.device)
|
307 |
+
x_enc = self.forward_mae_encoder(x, mask, input_ids=input_ids,
|
308 |
+
attention_mask=attention_mask)
|
309 |
+
z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n))
|
310 |
+
|
311 |
+
loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
|
312 |
+
|
313 |
+
return loss
|
314 |
+
|
315 |
+
def image2text_loss(self, data_dict):
|
316 |
+
input_ids = data_dict['input_ids'].to(self.device)
|
317 |
+
attention_mask = data_dict['attention_mask'].to(self.device)
|
318 |
+
labels = data_dict['labels'].to(self.device)
|
319 |
+
|
320 |
+
pixel_values = data_dict.get('pixel_values', None)
|
321 |
+
# print("pixel_values batch:", pixel_values.shape)
|
322 |
+
# print("input_ids batch:", input_ids.shape)
|
323 |
+
if pixel_values is None:
|
324 |
+
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
325 |
+
_, z_null = self.extract_visual_feature(
|
326 |
+
torch.zeros(1, 16, 16, self.token_embed_dim,
|
327 |
+
dtype=self.dtype, device=self.device)
|
328 |
+
)
|
329 |
+
loss_null = z_null.mean() * 0.0
|
330 |
+
print(f"No image found in this batch!", flush=True)
|
331 |
+
else:
|
332 |
+
x = pixel_values.to(dtype=self.dtype, device=self.device)
|
333 |
+
x = self.encode(x) # b m n c
|
334 |
+
_, z_enc = self.extract_visual_feature(x)
|
335 |
+
|
336 |
+
if self.grad_scale is not None:
|
337 |
+
z_enc = _ScaleGradient.apply(z_enc, self.grad_scale)
|
338 |
+
|
339 |
+
inputs_embeds = z_enc.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
|
340 |
+
|
341 |
+
self.tokenizer.add_tokens(["<image>"], special_tokens=True)
|
342 |
+
IMAGE_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
|
343 |
+
# print(f"IMAGE_TOKEN_INDEX: {IMAGE_TOKEN_INDEX}")
|
344 |
+
img_tokens = (torch.tensor(input_ids) == IMAGE_TOKEN_INDEX).sum().item()
|
345 |
+
# print(f"[校验日志] input_ids长度: {len('input_ids')}, 图像token出现次数: {img_tokens}\n")
|
346 |
+
|
347 |
+
inputs_embeds[input_ids == IMAGE_TOKEN_INDEX] = z_enc.flatten(0, 1)
|
348 |
+
inputs_embeds[input_ids != IMAGE_TOKEN_INDEX] = self.llm.get_input_embeddings()(
|
349 |
+
input_ids[input_ids != IMAGE_TOKEN_INDEX])
|
350 |
+
loss_null = 0.0
|
351 |
+
|
352 |
+
output = self.llm_model(inputs_embeds=inputs_embeds,
|
353 |
+
attention_mask=attention_mask,
|
354 |
+
return_dict=True)
|
355 |
+
|
356 |
+
last_hidden_state = output.last_hidden_state[:, :-1]
|
357 |
+
labels = labels[:, 1:]
|
358 |
+
last_hidden_state = last_hidden_state[labels >= 0]
|
359 |
+
labels = labels[labels >= 0]
|
360 |
+
logits = self.llm.get_output_embeddings()(last_hidden_state)
|
361 |
+
|
362 |
+
loss_i2t = F.cross_entropy(input=logits, target=labels)
|
363 |
+
|
364 |
+
return loss_i2t + loss_null
|
365 |
+
|
366 |
+
# def image_edit_loss(self, data_dict):
|
367 |
+
# # 1. 图像前向:拼 batch 并编码到视觉特征
|
368 |
+
# x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device) # 源图像批次,shape=[b_src, C, H, W]
|
369 |
+
# x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device) # 编辑图像批次,shape=[b_edit, C, H, W]
|
370 |
+
# print_log(f"[DEBUG image_edit_loss] x_src.shape = {x_src.shape}, x.shape = {x.shape}", level="WARNING")
|
371 |
+
|
372 |
+
# # b_edit 应该 >= b_src
|
373 |
+
# assert x.shape[0] >= x_src.shape[0], \
|
374 |
+
# f"编辑批次大小 ({x.shape[0]}) 必须 >= 源图像批次大小 ({x_src.shape[0]})"
|
375 |
+
|
376 |
+
# # 拼接并一次性编码
|
377 |
+
# x_all = torch.cat([x_src, x], dim=0) # shape=[b_src + b_edit, C, H, W]
|
378 |
+
# x_all = self.encode(x_all) # shape=[b_src + b_edit, m, n, c]
|
379 |
+
# # 分割回源/编辑两部分
|
380 |
+
# x_src_enc, x_enc = x_all.split([x_src.shape[0], x.shape[0]], dim=0)
|
381 |
+
# # x_src_enc.shape=[b_src, m, n, c], x_enc.shape=[b_edit, m, n, c]
|
382 |
+
|
383 |
+
# # 2. 提取视觉特征:x_con 用于 decoder 条件,z_src 用于填充文本中的 <image> token
|
384 |
+
# x_con, z_src = self.extract_visual_feature(x_src_enc)
|
385 |
+
# if self.grad_scale is not None:
|
386 |
+
# x_con = _ScaleGradient.apply(x_con, self.grad_scale)
|
387 |
+
# z_src = _ScaleGradient.apply(z_src, self.grad_scale)
|
388 |
+
# # z_src.shape = [b_src, m*n, C]
|
389 |
+
|
390 |
+
# # 3. 文本条件分支:构造 inputs_embeds
|
391 |
+
# attention_mask = data_dict['attention_mask'].to(self.device) # shape=[b_edit, seq_len]
|
392 |
+
# input_ids = data_dict['input_ids'].to(self.device) # shape=[b_edit, seq_len]
|
393 |
+
# b_edit, seq_len = input_ids.shape
|
394 |
+
# hidden_size = self.llm.config.hidden_size
|
395 |
+
|
396 |
+
# # 先准备一个全 0 的 inputs_embeds
|
397 |
+
# inputs_embeds = z_src.new_zeros(b_edit, seq_len, hidden_size) # shape=[b_edit, seq_len, hidden_size]
|
398 |
+
|
399 |
+
# # 找到所有 <image> token 位置的 mask
|
400 |
+
# mask_imgpos = (input_ids == IMAGE_TOKEN_INDEX) # bool tensor [b_edit, seq_len]
|
401 |
+
|
402 |
+
# # 需要将单个 z_src 展开成 b_edit 份,再按 mask_imgpos 填入
|
403 |
+
# # 1) expand:把 z_src 从 [b_src, m*n, C] → [b_edit, m*n, C]
|
404 |
+
# # (一般 b_src=1,所以就是复制那一份)
|
405 |
+
# z_src_rep = z_src.expand(b_edit, -1, -1) # [b_edit, m*n, C]
|
406 |
+
# # 2) flatten:将二维展开到一维,对应 mask_imgpos.sum() 个位置
|
407 |
+
# flat_z = z_src_rep.flatten(0, 1) # [b_edit*m*n, C]
|
408 |
+
|
409 |
+
# # **重要检查**:保证 mask_imgpos 中 True 的数量 == flat_z.shape[0]
|
410 |
+
# img_tokens_count = mask_imgpos.sum().item()
|
411 |
+
# assert img_tokens_count == flat_z.shape[0], \
|
412 |
+
# f"<image> token 数 ({img_tokens_count}) 不等于视觉特征数 ({flat_z.shape[0]})"
|
413 |
+
|
414 |
+
# # 填充视觉 token 对应位置
|
415 |
+
# inputs_embeds[mask_imgpos] = flat_z
|
416 |
+
|
417 |
+
# # 剩下的位置用文本 embedding
|
418 |
+
# txt_pos = ~mask_imgpos
|
419 |
+
# txt_embeddings = self.llm.get_input_embeddings()(input_ids[txt_pos])
|
420 |
+
# inputs_embeds[txt_pos] = txt_embeddings
|
421 |
+
|
422 |
+
# # 4. MAE-style 重建分支:在 decoder 前注入 inputs_embeds 与 attention_mask
|
423 |
+
# b, m, n, c = x_enc.shape
|
424 |
+
# gt = x_enc.view(b, m*n, c) # 作为重建目标
|
425 |
+
# orders = self.mar.sample_orders(bsz=b, seq_len=m*n)
|
426 |
+
# mask = self.mar.random_masking(x_enc.flatten(1, 2), orders)
|
427 |
+
|
428 |
+
    #     # Conditional encoder forward
    #     x_enc_out = self.forward_mae_encoder(
    #         x_enc,
    #         mask,
    #         inputs_embeds=inputs_embeds,
    #         attention_mask=attention_mask
    #     )
    #     # Decoder reconstruction
    #     z_dec = self.mar.forward_mae_decoder(
    #         x_enc_out,
    #         mask,
    #         image_shape=(m, n),
    #         x_con=x_con
    #     )
    #     # Compute the loss
    #     loss = self.mar.forward_loss(z=z_dec, target=gt, mask=mask)
    #     return loss

    # def image_edit_loss_vae(self, data_dict):
    #     """
    #     Compute the loss for the image-editing task.
    #     Features of the reference image (x_src) are fed to the decoder directly as the
    #     condition (x_con) and do not take part in encoder reconstruction.
    #     The encoder performs masked reconstruction only on the target image (x_tgt),
    #     while receiving text and reference-image context.
    #     """
    #     # === Step 1: read inputs ===
    #     x_src = data_dict['pixel_values_src'].to(self.device).to(self.dtype)
    #     x_tgt = data_dict['pixel_values'].to(self.device).to(self.dtype)
    #     attention_mask = data_dict['attention_mask'].to(self.device)
    #     input_ids = data_dict['input_ids'].to(self.device)
    #     # IMG_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
    #     B = x_tgt.shape[0]

    #     # === Step 2: process the reference image ===
    #     # VAE encoding, no gradients
    #     with torch.no_grad():
    #         z_src_latent = self.encode(x_src)  # [B, m, n, token_dim]

    #     # Turn the VAE latent into the decoder condition (x_con) and the LLM input (z_src_buf)
    #     # This step realises the direct path "reference-image latent -> decoder"
    #     x_con, z_src_buf = self.vae_latent_to_decoder_feature(z_src_latent)
    #     # x_con: [B, 4096, enc_dim] -> for the decoder
    #     # z_src_buf: [B, 4160, llm_dim] -> for the LLM

    #     # === Step 3: build the LLM input (inputs_embeds) ===
    #     # Combine the text instruction (input_ids) with the reference-image features (z_src_buf)
    #     _, T = input_ids.shape
    #     H_llm = self.llm.config.hidden_size
    #     inputs_embeds = torch.zeros(B, T, H_llm, device=self.device, dtype=z_src_buf.dtype)

    #     # Fill in the embeddings of the <image> tokens and the text tokens
    #     inputs_embeds[input_ids == IMG_TOKEN_INDEX] = z_src_buf.flatten(0, 1)
    #     # input_ids has 33280 <image> positions, but z_src_buf.flatten(0, 1) has 33792 rows;
    #     # why 512 more than input_ids?
    #     inputs_embeds[input_ids != IMG_TOKEN_INDEX] = self.llm.get_input_embeddings()(
    #         input_ids[input_ids != IMG_TOKEN_INDEX]
    #     )

    #     # === Step 4: process the target image and run the encoder forward pass ===
    #     # VAE-encode the target image, no gradients
    #     with torch.no_grad():
    #         z_tgt_latent = self.encode(x_tgt)  # [B, m, n, token_dim]

    #     # Build a mask over the target-image latent for MAE reconstruction
    #     B, m, n, token_dim = z_tgt_latent.shape
    #     patch_tokens_tgt = z_tgt_latent.view(B, m * n, token_dim)  # reconstruction target
    #     orders = self.mar.sample_orders(bsz=B, seq_len=m * n)
    #     mask = self.mar.random_masking(patch_tokens_tgt, orders)

    #     # **Core**: the encoder only processes the visible part of the target image
    #     # (z_tgt_latent) and receives the LLM context
    #     x_enc = self.forward_mae_encoder(
    #         z_tgt_latent,  # target-image latent
    #         mask,
    #         detach=False,
    #         inputs_embeds=inputs_embeds,  # context carrying text and reference-image information
    #         attention_mask=attention_mask
    #     )

    #     # === Step 5: decoder reconstruction ===
    #     # The decoder combines the encoder output (x_enc) with the reference-image
    #     # features (x_con) to reconstruct the full latent representation
    #     z_pred = self.mar.forward_mae_decoder(
    #         x_enc,
    #         mask,
    #         image_shape=(m, n),
    #         x_con=x_con  # ★ the reference-image features act directly here
    #     )

    #     # === Step 6: compute the loss ===
    #     loss = self.mar.forward_loss(
    #         z=z_pred,
    #         target=patch_tokens_tgt,
    #         mask=mask
    #     )
    #     return loss

    def image_edit_loss_contrastive(self, data_dict):
        # Step 1: obtain image features
        x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        assert len(x_src) >= len(x)
        x_src, x = self.encode(torch.cat([x_src, x])).split([len(x_src), len(x)], dim=0)

        # Step 2: text inputs
        attention_mask = data_dict['attention_mask'].to(self.device)
        input_ids = data_dict['input_ids'].to(self.device)

        x_con, z_src = self.extract_visual_feature(x_src)
        if self.grad_scale is not None:
            z_src = _ScaleGradient.apply(z_src, self.grad_scale)
            x_con = _ScaleGradient.apply(x_con, self.grad_scale)

        inputs_embeds = z_src.new_zeros(*input_ids.shape, self.llm.config.hidden_size)
        IMAGE_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
        inputs_embeds[input_ids == IMAGE_TOKEN_INDEX] = z_src.flatten(0, 1)
        inputs_embeds[input_ids != IMAGE_TOKEN_INDEX] = self.llm.get_input_embeddings()(
            input_ids[input_ids != IMAGE_TOKEN_INDEX]
        )

        # Step 3: reconstruction loss
        b, m, n, _ = x.shape
        gt_latents = x.clone().detach().view(b, m * n, -1)
        orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
        mask = self.mar.random_masking(x.flatten(1, 2), orders)
        x_enc = self.forward_mae_encoder(x, mask,
                                         inputs_embeds=inputs_embeds,
                                         attention_mask=attention_mask)
        z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n), x_con=x_con)
        rec_loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)

        # Step 4: contrastive loss between repeat and edit samples
        # Assumes an even batch arranged as (repeat, edit) pairs
        z_src_flat = z_src.mean(dim=1)  # [B, D] global pooling
        z_src_flat = F.normalize(z_src_flat, dim=-1)

        repeat_z = z_src_flat[::2]  # even indices
        edit_z = z_src_flat[1::2]   # odd indices

        logits = torch.matmul(edit_z, repeat_z.T) / 0.07  # [B/2, B/2]
        labels = torch.arange(logits.size(0), device=logits.device)
        contrastive_loss = F.cross_entropy(logits, labels)

        return rec_loss + self.loss_weights.get("contrastive", 1.0) * contrastive_loss

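The contrastive term above is a standard InfoNCE over pooled source features and assumes the batch interleaves (repeat, edit) pairs. A minimal standalone sketch of that pairing logic, on toy tensors only (B and D are illustrative, not model dimensions):

    import torch
    import torch.nn.functional as F

    B, D = 8, 16                                # even batch of interleaved (repeat, edit) pairs
    z = F.normalize(torch.randn(B, D), dim=-1)  # stands in for the pooled z_src features

    repeat_z, edit_z = z[::2], z[1::2]          # split the interleaved pairs
    logits = edit_z @ repeat_z.T / 0.07         # [B/2, B/2] cosine similarity / temperature
    labels = torch.arange(logits.size(0))       # matching pairs sit on the diagonal
    loss = F.cross_entropy(logits, labels)
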
    def image_edit_loss(self, data_dict):
        # Multi-turn editing is also supported
        x_src = data_dict['pixel_values_src'].to(dtype=self.dtype, device=self.device)
        x = data_dict['pixel_values'].to(dtype=self.dtype, device=self.device)
        # print_log(f"[DEBUG] x_src.shape = {x_src.shape}, x.shape = {x.shape}")

        # assert len(x_src) >= len(x)
        x_cat = torch.cat([x_src, x], dim=0)
        x_src, x = self.encode(x_cat).split([len(x_src), len(x)], dim=0)

        # Prepare context, including source images and instructions
        attention_mask = data_dict['attention_mask'].to(self.device)
        input_ids = data_dict['input_ids'].to(self.device)

        x_con, z_src = self.extract_visual_feature(x_src)
        if self.grad_scale is not None:
            z_src = _ScaleGradient.apply(z_src, self.grad_scale)
            x_con = _ScaleGradient.apply(x_con, self.grad_scale)

        inputs_embeds = z_src.new_zeros(*input_ids.shape, self.llm.config.hidden_size)

        self.tokenizer.add_tokens(["<image>"], special_tokens=True)

        IMAGE_TOKEN_INDEX = self.tokenizer.convert_tokens_to_ids("<image>")
        # print("tokenizer idx in skywork_unipic_dev=", self.tokenizer.convert_tokens_to_ids("<image>"))

        inputs_embeds[input_ids == IMAGE_TOKEN_INDEX] = z_src.flatten(0, 1)
        inputs_embeds[input_ids != IMAGE_TOKEN_INDEX] = self.llm.get_input_embeddings()(
            input_ids[input_ids != IMAGE_TOKEN_INDEX]
        )

        # --------------------------------------------------
        # 3. MAE-style reconstruction
        # --------------------------------------------------

        b, m, n, _ = x.shape
        gt_latents = x.clone().detach().view(b, m * n, -1)
        orders = self.mar.sample_orders(bsz=b, seq_len=m * n)
        mask = self.mar.random_masking(x.flatten(1, 2), orders)
        x_enc = self.forward_mae_encoder(x, mask,
                                         inputs_embeds=inputs_embeds,
                                         attention_mask=attention_mask)
        z = self.mar.forward_mae_decoder(x_enc, mask, image_shape=(m, n), x_con=x_con)

        loss = self.mar.forward_loss(z=z, target=gt_latents, mask=mask)
        return loss

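`image_edit_loss` reads a fixed set of keys from its batch dict. A hedged sketch of that contract (all sizes and the filler token id are illustrative; in a real batch the `<image>` positions of `input_ids` must match `z_src.flatten(0, 1)` row-for-row):

    import torch

    data_dict = {
        'pixel_values_src': torch.randn(2, 3, 448, 448),              # reference images
        'pixel_values':     torch.randn(2, 3, 448, 448),              # edited targets
        'input_ids':        torch.full((2, 32), 9, dtype=torch.long), # instruction + <image> tokens
        'attention_mask':   torch.ones(2, 32, dtype=torch.bool),
    }
    # loss = model.image_edit_loss(data_dict)  # scalar MAE reconstruction loss
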
    def forward(self, data, data_samples=None, mode='loss'):
        if mode == 'loss':
            return self.compute_loss(data_dict=data)
        else:
            raise NotImplementedError

    def compute_loss(self, data_dict):
        losses = {}
        for data_type, batch_data in data_dict.items():
            if 'text2image' in data_type:
                loss = self.text2image_loss(batch_data)
            elif 'image2text' in data_type:
                loss = self.image2text_loss(batch_data)
            elif 'image_edit' in data_type:
                loss = self.image_edit_loss(batch_data)
            else:
                raise NotImplementedError(f"Unknown data_type: {data_type}")
            weight = self.loss_weights.get(data_type, 1.0)
            losses[f'loss_{data_type}'] = loss * weight
        return losses

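`compute_loss` dispatches on substrings of each data_type key and scales every task loss by `loss_weights`. A runnable toy of just that weighting contract (the weights and loss values are hypothetical):

    import torch

    loss_weights = {'text2image': 1.0, 'image_edit': 2.0}                 # hypothetical weights
    raw = {'text2image': torch.tensor(0.8), 'image_edit': torch.tensor(0.3)}

    losses = {f'loss_{k}': v * loss_weights.get(k, 1.0) for k, v in raw.items()}
    total = sum(losses.values())  # what the runner ultimately backpropagates
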
src/models/skywork_unipic_ori.py
ADDED
@@ -0,0 +1,350 @@
import torch
import math
import numpy as np
import torch.nn as nn
import contextlib
from einops import rearrange
from transformers.cache_utils import DynamicCache
from src.builder import BUILDER
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from transformers.integrations.deepspeed import (
    is_deepspeed_zero3_enabled,
    set_hf_deepspeed_config,
    unset_hf_deepspeed_config,
    deepspeed_config
)


@contextlib.contextmanager
def temporarily_disable_deepspeed_zero3():
    if is_deepspeed_zero3_enabled():
        config = deepspeed_config()
        print(f'[DEBUG] ds config={config}')
        unset_hf_deepspeed_config()
        yield
        set_hf_deepspeed_config(config)
    else:
        yield


def build_mlp(hidden_size, projector_dim, z_dim):
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim))


def mask_by_order(mask_len, order, bsz, seq_len):
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
                            src=torch.ones(bsz, seq_len, device=order.device)).bool()
    return masking


class SkyworkUnipic(nn.Module):
    def __init__(self,
                 vae,
                 vae_scale,
                 llm,
                 mar,
                 tokenizer,
                 prompt_template):
        super().__init__()
        with temporarily_disable_deepspeed_zero3():
            # VAE
            self.vae = BUILDER.build(vae)
            self.vae.requires_grad_(False)
            self.vae_scale = vae_scale

            # LLM
            self.llm = BUILDER.build(llm)
            self.tokenizer = BUILDER.build(tokenizer)
            self.prompt_template = prompt_template

            self.tokenizer.add_tokens(["<image>"], special_tokens=True)
            image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
            print(f"Registered <image> token at index {image_token_idx}")

            # MAR
            self.mar = BUILDER.build(mar)
            # projection layers
            self.proj_in = build_mlp(hidden_size=self.mar.encoder_embed_dim,
                                     projector_dim=self.llm.config.hidden_size,
                                     z_dim=self.llm.config.hidden_size)
            self.proj_out = build_mlp(hidden_size=self.llm.config.hidden_size,
                                      projector_dim=self.llm.config.hidden_size,
                                      z_dim=self.mar.encoder_embed_dim)

    @property
    def llm_model(self):
        return self.llm.model

    @property
    def device(self):
        return self.llm.device

    @property
    def dtype(self):
        return self.llm.dtype

    @property
    def gen_seq_len(self):
        return self.mar.seq_len

    @property
    def token_embed_dim(self):
        return self.vae.embed_dim * (self.mar.patch_size ** 2)

    @torch.no_grad()
    def encode(self, x):
        posterior = self.vae.encode(x)
        z = posterior.mode().mul_(self.vae_scale)
        z = rearrange(z, 'b c (m p) (n q) -> b m n (c p q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)

        return z

    @torch.no_grad()
    def decode(self, z):
        z /= self.vae_scale
        z = rearrange(z, 'b m n (c p q) -> b c (m p) (n q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)

        x = self.vae.decode(z)
        return x

    def prepare_forward_input(self,
                              x,
                              inputs_embeds=None,
                              input_ids=None,
                              attention_mask=None,
                              past_key_values=None):
        b, l, _ = x.shape
        attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
        attention_mask = torch.cat([
            attention_mask, attention_mask.new_ones(b, l)
        ], dim=1)
        position_ids = torch.cumsum(attention_mask, dim=1) - 1
        position_ids[position_ids < 0] = 0

        # import pdb; pdb.set_trace()

        # prepare context
        if past_key_values is not None:
            inputs_embeds = x
            position_ids = position_ids[:, -l:]
        else:
            if inputs_embeds is None:
                input_ids = input_ids.to(self.device)
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([inputs_embeds, x], dim=1)

        return dict(inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values)

    def extract_visual_feature(self, x, mask=None, detach=False):
        b, m, n, _ = x.shape
        x = x.view(b, m * n, -1)
        # x: b mn c
        if mask is None:
            mask = torch.zeros_like(x[..., 0])
        null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
        x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))

        z_enc = self.proj_in(x_enc)
        # Move buffers to the end of the image sequence
        z_enc = torch.cat([
            z_enc[:, self.mar.buffer_size:],
            z_enc[:, :self.mar.buffer_size]], dim=1)

        if detach:
            x_enc = x_enc.detach()
            z_enc = z_enc.detach()

        return x_enc, z_enc

    def vae_latent_to_decoder_feature(self, z_src_latent):
        """
        Returns:
            x_con     [B, buf_sz + m*n, enc_dim] for the MAE decoder
            z_src_buf [B, buf_sz + m*n, llm_dim] to scatter into <image> tokens
        """
        B, m, n, token_dim = z_src_latent.shape
        num_patches = m * n
        enc_dim = self.mar.encoder_embed_dim   # e.g. 1280
        llm_dim = self.llm.config.hidden_size  # e.g. 1536
        buf_sz = self.mar.buffer_size          # e.g. 64

        # 1) flatten patches -> [B, 4096, token_dim]
        patch_tokens = z_src_latent.view(B, num_patches, token_dim)

        # 2) project to encoder dim -> [B, 4096, enc_dim]
        z_enc = self.mar.z_proj(patch_tokens)
        z_enc = self.mar.z_proj_ln(z_enc)

        # (optional) add encoder pos embed for the image part only
        full_pos = self.mar.get_encoder_pos_embed(h=m, w=n)  # [1, buf_sz + 4096, enc_dim]
        pos_img = full_pos[:, buf_sz:]                       # [1, 4096, enc_dim]
        z_enc = z_enc + pos_img

        # 3) build x_con for the MAE decoder: **one** buffer pad + image tokens
        buf_enc = torch.zeros(B, buf_sz, enc_dim,
                              device=z_enc.device, dtype=z_enc.dtype)
        x_con = torch.cat([buf_enc, z_enc], dim=1)  # [B, 4160, enc_dim]

        # 4) build z_src_buf for the LLM: **project the exact same** x_con, then rotate buffer -> end
        z_proj_llm = self.proj_in(x_con)  # [B, 4160, llm_dim]
        # rotate: take the image portion, then the buffer portion
        z_src_buf = torch.cat([
            z_proj_llm[:, buf_sz:],   # [B, 4096, llm_dim]
            z_proj_llm[:, :buf_sz]    # [B,   64, llm_dim]
        ], dim=1)                     # [B, 4160, llm_dim]

        return x_con, z_src_buf

    def forward_mae_encoder(self, x, mask, detach=False, **context):
        b, m, n, _ = x.shape
        x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
        inputs = self.prepare_forward_input(x=z_enc, **context)
        output = self.llm_model(**inputs, return_dict=True)

        z_llm = output.last_hidden_state[:, -z_enc.shape[1]:]

        # move buffers back to the start of the image sequence
        z_llm = torch.cat([
            z_llm[:, -self.mar.buffer_size:],
            z_llm[:, :-self.mar.buffer_size]], dim=1)

        # residual learning
        x_enc = x_enc + self.proj_out(z_llm)

        return x_enc

    @staticmethod
    def curtail_cache(past_key_values, cur_len):
        for past_key_values_ in past_key_values:
            keys, values = past_key_values_
            keys.data = keys.data[:, :, :cur_len]
            values.data = values.data[:, :, :cur_len]

    @torch.no_grad()
    def prepare_text_conditions(self, prompt, cfg_prompt='Generate an image.'):
        all_prompts = [self.prompt_template['INSTRUCTION'].format(input=prompt),
                       self.prompt_template['INSTRUCTION'].format(input=cfg_prompt)]

        input_ids = [self.tokenizer.encode(p, add_special_tokens=True, return_tensors='pt')[0]
                     for p in all_prompts]
        valid_lens = [len(input_ids_) for input_ids_ in input_ids]
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.eos_token_id)
        attention_mask = torch.zeros_like(input_ids).bool()
        for i in range(len(input_ids)):
            attention_mask[i, :valid_lens[i]] = True

        return dict(input_ids=input_ids.to(self.device),
                    attention_mask=attention_mask.to(self.device))

    @torch.no_grad()
    def sample(self,
               input_ids=None, inputs_embeds=None,
               attention_mask=None, num_iter=64, cfg=1.0, cfg_schedule="constant", temperature=1.0,
               progress=False, mask=None, past_key_values=None, image_shape=None, x_con=None, **kwargs):
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        bsz = attention_mask.shape[0]
        if cfg != 1.0:
            assert bsz % 2 == 0

        if image_shape is None:
            m = n = int(self.gen_seq_len ** 0.5)
        else:
            m, n = image_shape

        if mask is None:
            mask = torch.ones(bsz, m * n, device=self.device, dtype=self.dtype)
        else:
            mask = mask.view(bsz, m * n)
        tokens = torch.zeros(bsz, m * n, self.token_embed_dim,
                             device=self.device, dtype=self.dtype)
        orders = self.mar.sample_orders(bsz, seq_len=m * n)
        if cfg != 1.0:
            orders[bsz // 2:] = orders[:bsz // 2]

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        # past key values can be prepared outside (usually in multi-turn editing)
        if past_key_values is None:
            output = self.llm_model(inputs_embeds=inputs_embeds,
                                    attention_mask=None,
                                    position_ids=None,
                                    past_key_values=DynamicCache.from_legacy_cache(),
                                    return_dict=True,
                                    use_cache=True)
            past_key_values = output.past_key_values

        # generate latents
        for step in indices:
            cur_tokens = tokens.clone()
            x_enc = self.forward_mae_encoder(tokens.view(bsz, m, n, -1),
                                             mask.to(self.dtype),
                                             past_key_values=past_key_values,
                                             # inputs_embeds=inputs_embeds,
                                             attention_mask=attention_mask)
            # import pdb; pdb.set_trace()
            self.curtail_cache(past_key_values, inputs_embeds.shape[1])

            z = self.mar.forward_mae_decoder(x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con)

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(m * n * mask_ratio)]).to(self.device)

            # masks out at least one token for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            # get the masking for the next iteration and the locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, m * n).to(self.device)
            if cfg != 1.0:
                mask_next[bsz // 2:] = mask_next[:bsz // 2]
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next
            # if not cfg == 1.0:
            #     mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

            # sample token latents for this step
            z = z[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule following Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (m * n - mask_len[0]) / (m * n)
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.mar.diffloss.sample(z, temperature, cfg_iter).to(self.dtype)
            # if not cfg == 1.0:
            #     sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
            #     mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            if cfg != 1.0:
                cur_tokens[bsz // 2:] = cur_tokens[:bsz // 2]
            tokens = cur_tokens.clone()

        pred = self.decode(tokens.view(bsz, m, n, -1))

        if cfg != 1.0:
            pred = pred[:bsz // 2]
        return pred
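At inference time, `prepare_text_conditions` pairs the prompt with a CFG null prompt so that `sample` can run classifier-free guidance on a doubled batch and return only the conditional half. A hedged usage sketch (assumes `model` is an already-initialised SkyworkUnipic; the prompt and guidance scale are illustrative):

    cond = model.prepare_text_conditions('A red bicycle leaning on a brick wall')
    pred = model.sample(input_ids=cond['input_ids'],
                        attention_mask=cond['attention_mask'],
                        num_iter=64,
                        cfg=3.0,                 # >1.0 requires the paired batch built above
                        cfg_schedule='constant',
                        temperature=1.0,
                        progress=True)           # pred: [1, 3, H, W] decoded image tensor
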
src/models/skywork_unipic_siglip.py
ADDED
@@ -0,0 +1,342 @@
import torch
import math
import numpy as np
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import DynamicCache
from src.builder import BUILDER
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence


def build_mlp(hidden_size, projector_dim, z_dim):
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )


def mask_by_order(mask_len, order, bsz, seq_len):
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking = torch.scatter(
        masking,
        dim=-1,
        index=order[:, : mask_len.long()],
        src=torch.ones(bsz, seq_len, device=order.device),
    ).bool()
    return masking


class SkyworkUnipic(nn.Module):
    def __init__(self, vae, vae_scale, llm, mar, tokenizer, prompt_template, siglip2):
        super().__init__()
        # VAE
        self.vae = BUILDER.build(vae)
        self.vae.requires_grad_(False)
        self.vae_scale = vae_scale

        # LLM
        self.llm = BUILDER.build(llm)
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")

        self.prompt_template = prompt_template

        # MAR
        self.mar = BUILDER.build(mar)
        # projection layers
        self.proj_in = build_mlp(
            hidden_size=self.mar.encoder_embed_dim,
            projector_dim=self.llm.config.hidden_size,
            z_dim=self.llm.config.hidden_size,
        )
        self.proj_out = build_mlp(
            hidden_size=self.llm.config.hidden_size,
            projector_dim=self.llm.config.hidden_size,
            z_dim=self.mar.encoder_embed_dim,
        )

        # siglip
        self.siglip2 = BUILDER.build(siglip2)
        self.siglip2_proj = build_mlp(
            hidden_size=1152,
            projector_dim=self.llm.config.hidden_size,
            z_dim=self.llm.config.hidden_size,
        )

    @property
    def llm_model(self):
        return self.llm.model

    @property
    def device(self):
        return self.llm.device

    @property
    def dtype(self):
        return self.llm.dtype

    @property
    def gen_seq_len(self):
        return self.mar.seq_len

    @property
    def token_embed_dim(self):
        return self.vae.embed_dim * (self.mar.patch_size**2)

    @torch.no_grad()
    def encode(self, x):
        posterior = self.vae.encode(x)
        z = posterior.mode().mul_(self.vae_scale)
        z = rearrange(
            z,
            "b c (m p) (n q) -> b m n (c p q)",
            p=self.mar.patch_size,
            q=self.mar.patch_size,
        )

        return z

    @torch.no_grad()
    def decode(self, z):
        z /= self.vae_scale
        z = rearrange(
            z,
            "b m n (c p q) -> b c (m p) (n q)",
            p=self.mar.patch_size,
            q=self.mar.patch_size,
        )

        x = self.vae.decode(z)
        return x

    def prepare_forward_input(
        self,
        x,
        inputs_embeds=None,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
    ):
        b, l, _ = x.shape
        attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones(b, l)], dim=1
        )
        position_ids = torch.cumsum(attention_mask, dim=1) - 1
        position_ids[position_ids < 0] = 0

        # import pdb; pdb.set_trace()

        # prepare context
        if past_key_values is not None:
            inputs_embeds = x
            position_ids = position_ids[:, -l:]
        else:
            if inputs_embeds is None:
                input_ids = input_ids.to(self.device)
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([inputs_embeds, x], dim=1)

        return dict(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

    def extract_visual_feature(self, x, mask=None, detach=False):
        b, m, n, _ = x.shape
        x = x.view(b, m * n, -1)
        # x: b mn c
        if mask is None:
            mask = torch.zeros_like(x[..., 0])
        null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
        x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))

        z_enc = self.proj_in(x_enc)
        # Move buffers to the end of the image sequence
        z_enc = torch.cat(
            [z_enc[:, self.mar.buffer_size :], z_enc[:, : self.mar.buffer_size]], dim=1
        )

        if detach:
            x_enc = x_enc.detach()
            z_enc = z_enc.detach()

        return x_enc, z_enc

    def forward_mae_encoder(self, x, mask, detach=False, **context):
        b, m, n, _ = x.shape
        x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
        inputs = self.prepare_forward_input(x=z_enc, **context)
        output = self.llm_model(**inputs, return_dict=True)

        z_llm = output.last_hidden_state[:, -z_enc.shape[1] :]

        # move buffers back to the start of the image sequence
        z_llm = torch.cat(
            [z_llm[:, -self.mar.buffer_size :], z_llm[:, : -self.mar.buffer_size]],
            dim=1,
        )

        # residual learning
        x_enc = x_enc + self.proj_out(z_llm)

        return x_enc

    @staticmethod
    def curtail_cache(past_key_values, cur_len):
        for past_key_values_ in past_key_values:
            keys, values = past_key_values_
            keys.data = keys.data[:, :, :cur_len]
            values.data = values.data[:, :, :cur_len]

    @torch.no_grad()
    def prepare_text_conditions(self, prompt, cfg_prompt="Generate an image."):
        all_prompts = [
            self.prompt_template["INSTRUCTION"].format(input=prompt),
            self.prompt_template["INSTRUCTION"].format(input=cfg_prompt),
        ]

        input_ids = [
            self.tokenizer.encode(p, add_special_tokens=True, return_tensors="pt")[0]
            for p in all_prompts
        ]
        valid_lens = [len(input_ids_) for input_ids_ in input_ids]
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.eos_token_id
        )
        attention_mask = torch.zeros_like(input_ids).bool()
        for i in range(len(input_ids)):
            attention_mask[i, : valid_lens[i]] = True

        return dict(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
        )

    @torch.no_grad()
    def sample(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        num_iter=64,
        cfg=1.0,
        cfg_schedule="constant",
        temperature=1.0,
        progress=False,
        mask=None,
        past_key_values=None,
        image_shape=None,
        x_con=None,
        **kwargs,
    ):
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        bsz = attention_mask.shape[0]
        if cfg != 1.0:
            assert bsz % 2 == 0

        if image_shape is None:
            m = n = int(self.gen_seq_len**0.5)
        else:
            m, n = image_shape

        if mask is None:
            mask = torch.ones(bsz, m * n, device=self.device, dtype=self.dtype)
        else:
            mask = mask.view(bsz, m * n)
        tokens = torch.zeros(
            bsz, m * n, self.token_embed_dim, device=self.device, dtype=self.dtype
        )
        orders = self.mar.sample_orders(bsz, seq_len=m * n)
        if cfg != 1.0:
            orders[bsz // 2 :] = orders[: bsz // 2]

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        # past key values can be prepared outside (usually in multi-turn editing)
        if past_key_values is None:
            output = self.llm_model(
                inputs_embeds=inputs_embeds,
                attention_mask=None,
                position_ids=None,
                past_key_values=DynamicCache.from_legacy_cache(),
                return_dict=True,
                use_cache=True,
            )
            past_key_values = output.past_key_values

        # generate latents
        for step in indices:
            cur_tokens = tokens.clone()
            x_enc = self.forward_mae_encoder(
                tokens.view(bsz, m, n, -1),
                mask.to(self.dtype),
                past_key_values=past_key_values,
                # inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
            # import pdb; pdb.set_trace()
            self.curtail_cache(past_key_values, inputs_embeds.shape[1])
            # import pdb; pdb.set_trace()

            z = self.mar.forward_mae_decoder(
                x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con
            )

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(m * n * mask_ratio)]).to(self.device)

            # masks out at least one token for the next iteration
            mask_len = torch.maximum(
                torch.Tensor([1]).to(self.device),
                torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len),
            )

            # get the masking for the next iteration and the locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, m * n).to(self.device)
            if cfg != 1.0:
                mask_next[bsz // 2 :] = mask_next[: bsz // 2]
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next
            # if not cfg == 1.0:
            #     mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

            # sample token latents for this step
            z = z[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule following Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (m * n - mask_len[0]) / (m * n)
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.mar.diffloss.sample(
                z, temperature, cfg_iter
            ).to(self.dtype)
            # if not cfg == 1.0:
            #     sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
            #     mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            if cfg != 1.0:
                cur_tokens[bsz // 2 :] = cur_tokens[: bsz // 2]
            tokens = cur_tokens.clone()

        pred = self.decode(tokens.view(bsz, m, n, -1))

        if cfg != 1.0:
            pred = pred[: bsz // 2]
        return pred
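Both `sample` implementations shrink the masked set with the MaskGIT-style cosine schedule: slow unmasking early, fast at the end. A standalone sketch of how `mask_len` decays over iterations (16 tokens over 8 steps, both illustrative):

    import math
    import numpy as np

    seq_len, num_iter = 16, 8
    for step in range(num_iter):
        mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
        mask_len = int(np.floor(seq_len * mask_ratio))  # tokens still masked next round
        print(step, mask_len)  # 15, 14, 13, 11, 8, 6, 3, 0
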
src/optimisers/constructor.py
ADDED
@@ -0,0 +1,64 @@
import inspect
import torch.nn as nn
from typing import List, Optional, Union
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                               OPTIMIZERS)


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
            no_decay.append(param)  # no weight decay on bias, norm and diffloss
        else:
            decay.append(param)

    num_decay_params = sum(p.numel() for p in decay)
    num_nodecay_params = sum(p.numel() for p in no_decay)
    print(f"num decayed parameter tensors: {len(decay)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(no_decay)}, with {num_nodecay_params:,} parameters")

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]


class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
    def __call__(self, model: nn.Module) -> OptimWrapper:
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cls = self.optimizer_cfg['type']
        # Optimizers like HybridAdam in colossalai require the argument name
        # `model_params` rather than `params`. Here we get the first argument
        # name and fill it with the model parameters.
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))
        # import pdb; pdb.set_trace()
        param_groups = add_weight_decay(model, optimizer_cfg.pop('weight_decay', 0))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        # # if no paramwise option is specified, just use the global setting
        # if not self.paramwise_cfg:
        #     optimizer_cfg[first_arg_name] = model.parameters()
        #     optimizer = OPTIMIZERS.build(optimizer_cfg)
        # else:
        #     # set param-wise lr and weight decay recursively
        #     params: List = []
        #     self.add_params(params, model)
        #     optimizer_cfg[first_arg_name] = params
        #     optimizer = OPTIMIZERS.build(optimizer_cfg)
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
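`add_weight_decay` puts 1-D tensors, biases, and anything whose name contains `diffloss` into a no-decay group, and everything else into the decay group. A quick standalone check on a toy module (assumes the function above is importable):

    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
    groups = add_weight_decay(toy, weight_decay=0.02)

    assert groups[0]['weight_decay'] == 0.   # LayerNorm weight/bias + Linear bias (all 1-D)
    assert len(groups[1]['params']) == 1     # only the 2-D Linear weight decays at 0.02
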
src/optimisers/custom_adamw.py
ADDED
@@ -0,0 +1,32 @@
import inspect
from torch.optim import AdamW


class CustomAdamW(AdamW):
    def __init__(self, params, weight_decay, *args, **kwargs):
        if isinstance(params, dict):
            params = [p for p in params.values() if p.requires_grad]
        else:
            params = [p for p in params if p.requires_grad]

        # Create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create the AdamW optimizer and use the fused version if it is available
        # fused_available = 'fused' in inspect.signature(AdamW).parameters
        # extra_args = dict(fused=True) if fused_available else dict()
        # print(f"using fused AdamW: {fused_available}")

        # kwargs.update(extra_args)

        super().__init__(optim_groups, *args, **kwargs)
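`CustomAdamW` rebuilds its param groups by dimensionality, so callers pass a flat parameter iterable plus a single `weight_decay`. A hedged usage sketch on a toy module:

    import torch.nn as nn

    toy = nn.Linear(8, 8)
    opt = CustomAdamW(toy.parameters(), weight_decay=0.05, lr=1e-4)
    # -> two groups: the 2-D weight decays at 0.05, the 1-D bias at 0.0
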
src/runners/custom_runner.py
ADDED
@@ -0,0 +1,202 @@
import copy
import logging
import inspect

from torch.utils.data import DataLoader
from functools import partial
from typing import Callable, Dict, List, Optional, Union

from mmengine.logging import print_log
from mmengine.dist import get_rank
from mmengine.dataset import worker_init_fn as default_worker_init_fn
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.runner import FlexibleRunner
from mmengine.registry import (
    DATA_SAMPLERS,
    DATASETS,
    FUNCTIONS,
)
from xtuner.registry import BUILDER


def clean_concatdataset_fields(cfg):
    """
    Recursively remove illegal fields (e.g. image_size) from every ConcatDataset config.
    """
    if isinstance(cfg, dict):
        # If this level is a ConcatDataset config, drop the illegal fields
        if cfg.get('type') == "ConcatDataset":
            for key in ['image_size']:
                if key in cfg:
                    del cfg[key]

        # Recurse into child fields
        for k, v in cfg.items():
            clean_concatdataset_fields(v)

    elif isinstance(cfg, list):
        for item in cfg:
            clean_concatdataset_fields(item)

    return cfg


class CustomRunner(FlexibleRunner):
    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @staticmethod
    def build_dataloader(
        dataloader: Union[DataLoader, Dict],
        seed: Optional[int] = None,
        diff_rank_seed: bool = False,
    ) -> DataLoader:
        """Build dataloader.

        The method builds three components:

        - Dataset
        - Sampler
        - Dataloader

        An example of ``dataloader``::

            dataloader = dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=9
            )

        Args:
            dataloader (DataLoader or dict): A Dataloader object or a dict to
                build Dataloader object. If ``dataloader`` is a Dataloader
                object, just returns itself.
            seed (int, optional): Random seed. Defaults to None.
            diff_rank_seed (bool): Whether or not set different seeds to
                different ranks. If True, the seed passed to sampler is set
                to None, in order to synchronize the seeds used in samplers
                across different ranks. Defaults to False.

        Returns:
            Dataloader: DataLoader build from ``dataloader_cfg``.
        """
        if isinstance(dataloader, DataLoader):
            return dataloader

        dataloader_cfg = copy.deepcopy(dataloader)

        clean_concatdataset_fields(dataloader_cfg)

        # build dataset
        dataset_cfg = dataloader_cfg.pop('dataset')
        if isinstance(dataset_cfg, dict):
            dataset = DATASETS.build(dataset_cfg)
            if hasattr(dataset, 'full_init'):
                dataset.full_init()
        else:
            # fallback to raise error in dataloader
            # if `dataset_cfg` is not a valid type
            dataset = dataset_cfg

        # build sampler
        sampler_cfg = dataloader_cfg.pop('sampler')
        if isinstance(sampler_cfg, dict):
            sampler_seed = None if diff_rank_seed else seed
            sampler = DATA_SAMPLERS.build(
                sampler_cfg,
                default_args=dict(dataset=dataset, seed=sampler_seed))
        else:
            # fallback to raise error in dataloader
            # if `sampler_cfg` is not a valid type
            sampler = sampler_cfg

        # build batch sampler
        batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
        if batch_sampler_cfg is None:
            batch_sampler = None
        elif isinstance(batch_sampler_cfg, dict):
            batch_sampler = DATA_SAMPLERS.build(
                batch_sampler_cfg,
                default_args=dict(
                    dataset=dataset,
                    sampler=sampler,
                    batch_size=dataloader_cfg.pop('batch_size')))
        else:
            # fallback to raise error in dataloader
            # if `batch_sampler_cfg` is not a valid type
            batch_sampler = batch_sampler_cfg

        # build dataloader
        init_fn: Optional[partial]
        if 'worker_init_fn' in dataloader_cfg:
            worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
            worker_init_fn_type = worker_init_fn_cfg.pop('type')
            worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
            assert callable(worker_init_fn)
            init_fn = partial(worker_init_fn,
                              **worker_init_fn_cfg)  # type: ignore
        else:
            if seed is not None:
                disable_subprocess_warning = dataloader_cfg.pop(
                    'disable_subprocess_warning', False)
                assert isinstance(disable_subprocess_warning, bool), (
                    'disable_subprocess_warning should be a bool, but got '
                    f'{type(disable_subprocess_warning)}')
                init_fn = partial(
                    default_worker_init_fn,
                    num_workers=dataloader_cfg.get('num_workers'),
                    rank=get_rank(),
                    seed=seed,
                    disable_subprocess_warning=disable_subprocess_warning)
            else:
                init_fn = None

        # `persistent_workers` requires pytorch version >= 1.7
        if ('persistent_workers' in dataloader_cfg
                and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
            print_log(
                '`persistent_workers` is only available when '
                'pytorch version >= 1.7',
                logger='current',
                level=logging.WARNING)
            dataloader_cfg.pop('persistent_workers')

        # The default behavior of `collate_fn` in dataloader is to
        # merge a list of samples to form a mini-batch of Tensor(s).
        # However, in mmengine, if `collate_fn` is not defined in
        # dataloader_cfg, `pseudo_collate` will only convert the list of
        # samples into a dict without stacking the batch tensor.
        collate_fn_cfg = dataloader_cfg.pop('collate_fn',
                                            dict(type='pseudo_collate'))
        if isinstance(collate_fn_cfg, dict):
            collate_fn_type = collate_fn_cfg.pop('type')
            if isinstance(collate_fn_type, str):
                collate_fn = FUNCTIONS.get(collate_fn_type)
            elif inspect.isclass(collate_fn_type):
                collate_fn_cfg['type'] = collate_fn_type
                collate_fn = BUILDER.build(collate_fn_cfg)
            else:
                collate_fn = collate_fn_type
            if not inspect.isclass(collate_fn_type):
                collate_fn = partial(collate_fn, **collate_fn_cfg)  # type: ignore
        elif callable(collate_fn_cfg):
            collate_fn = collate_fn_cfg
        else:
            raise TypeError(
                'collate_fn should be a dict or callable object, but got '
                f'{collate_fn_cfg}')
        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **dataloader_cfg)

        return data_loader
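The `collate_fn` branch above extends mmengine's default behaviour: a string `type` is fetched from FUNCTIONS and partially applied, while a class `type` is built through xtuner's BUILDER with the remaining keys as constructor arguments. A hedged config sketch ('ToyDataset' and MyCollate are hypothetical registry entries, not names from this repo):

    collate_fn = dict(type='pseudo_collate')               # string: FUNCTIONS lookup + partial
    # collate_fn = dict(type=MyCollate, pad_index=151645)  # class: built via BUILDER

    dataloader = dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=True),
        batch_size=2,
        num_workers=0,
        collate_fn=collate_fn,
    )
    # loader = CustomRunner.build_dataloader(dataloader, seed=42)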